8
8
from typing import Dict , Any , Generator
9
9
from flask import Blueprint , Response , session , request , jsonify , current_app
10
10
import threading
11
+ import uuid
12
+ from pprint import pprint
11
13
12
14
# Get logger for this module
13
15
logger = logging .getLogger (__name__ )
14
16
15
17
# Create blueprint for SSE routes
16
18
sse_bp = Blueprint ('sse' , __name__ , url_prefix = '/api/sse' )
17
19
18
- # Global dictionary to store SSE connections
19
- sse_connections : Dict [str , queue .Queue ] = {}
20
+ # Add a lock for thread safety
21
+ sse_connections_lock = threading .RLock ()
22
+ sse_connections : Dict [str , Dict [str , Any ]] = {}
20
23
21
24
@sse_bp .route ('/connect' )
22
25
def sse_connect ():
23
26
"""
24
27
SSE endpoint for clients to establish a connection and receive real-time messages
25
28
"""
26
29
session_id = session .get ('session_id' )
30
+ connection_id = f"conn_{ uuid .uuid4 ().hex [:8 ]} "
31
+
27
32
if not session_id :
28
33
return Response ("No session ID found" , status = 401 )
29
34
30
- logger .info (f"SSE connection established for session: { session_id } " )
35
+ logger .info (f"[SSE Connect] Thread { threading .current_thread ().name } accessing sse_connections" )
36
+ logger .info (f"[SSE Connect] sse_connections id: { id (sse_connections )} " )
31
37
32
- # Create a queue for this session if it doesn't exist
33
- if session_id not in sse_connections :
34
- sse_connections [session_id ] = queue .Queue ()
38
+ # Thread-safe connection creation
39
+ with sse_connections_lock :
40
+ if session_id not in sse_connections :
41
+ sse_connections [session_id ] = {
42
+ 'queue' : queue .Queue (),
43
+ 'connected_clients' : []
44
+ }
45
+ sse_connections [session_id ]['connected_clients' ].append (connection_id )
46
+ logger .info (f"[SSE Connect] sse_connections after creation: { sse_connections } " )
35
47
48
+
36
49
def event_stream () -> Generator [str , None , None ]:
37
50
"""Generator function that yields SSE formatted messages"""
38
- client_queue = sse_connections [session_id ]
51
+ client_queue = sse_connections [session_id ][ 'queue' ]
39
52
40
53
# Send initial connection confirmation
41
54
yield format_sse_message ({
@@ -45,27 +58,39 @@ def event_stream() -> Generator[str, None, None]:
45
58
})
46
59
47
60
try :
61
+ logger .info (f"Starting event stream for connection { connection_id } for session { session_id } " )
62
+ last_heartbeat_time = time .time ()
48
63
while True :
49
64
try :
50
- # Wait for messages with a timeout to allow periodic heartbeat
51
- message = client_queue .get (timeout = 30 ) # 30 second timeout
65
+ message = client_queue .get (timeout = 1 ) # 1 second timeout
52
66
yield format_sse_message (message )
53
67
except queue .Empty :
54
68
# Send heartbeat to keep connection alive
55
- yield format_sse_message ({
56
- "type" : "notification" ,
57
- "text" : "Heartbeat" ,
58
- "timestamp" : time .time ()
59
- })
69
+ if time .time () - last_heartbeat_time > 30 :
70
+ last_heartbeat_time = time .time ()
71
+ yield format_sse_message ({
72
+ "type" : "notification" ,
73
+ "text" : "Heartbeat" ,
74
+ "timestamp" : time .time ()
75
+ })
76
+ else :
77
+ # lightweight heartbeat to keep connection alive (no data)
78
+ yield ": heartbeat\n \n "
60
79
except Exception as e :
61
80
logger .error (f"Error in SSE stream for session { session_id } : { e } " )
62
81
break
63
82
finally :
64
- # Clean up connection when client disconnects
65
- if session_id in sse_connections :
66
- del sse_connections [session_id ]
67
- logger .info (f"SSE connection closed for session: { session_id } " )
68
-
83
+ # Safe cleanup with reference counting
84
+ with sse_connections_lock :
85
+ logger .info (f"[SSE Connect] cleaning up connection { connection_id } for session { session_id } " )
86
+ logger .info (f"[SSE Connect] sse_connections before cleanup: { sse_connections } " )
87
+ if session_id in sse_connections :
88
+ sse_connections [session_id ]['connected_clients' ].remove (connection_id )
89
+ if len (sse_connections [session_id ]['connected_clients' ]) == 0 :
90
+ del sse_connections [session_id ]
91
+ logger .info (f"Last SSE connection ({ connection_id } ) closed for session { session_id } " )
92
+ logger .info (f"[SSE Connect] sse_connections after cleanup: { sse_connections } " )
93
+
69
94
return Response (
70
95
event_stream (),
71
96
mimetype = 'text/event-stream' ,
@@ -83,20 +108,27 @@ def send_message_to_session():
83
108
Endpoint to send a message to a specific session via SSE
84
109
Expected JSON payload: {
85
110
"type": "action" | "notification" | "heartbeat",
86
- "message ": { ...} ,
87
- "data": {...}
111
+ "text ": " ...." ,
112
+ "data": {...} (optional)
88
113
}
89
114
"""
90
- data = request .get_json ()
91
- if not data :
115
+ content = request .get_json ()
116
+ if not content :
92
117
return jsonify ({"error" : "No JSON data provided" }), 400
93
118
94
- target_session_id = data .get ('session_id' )
95
- message = data .get ('message' , {})
119
+ target_session_id = content .get ('session_id' )
120
+ text = content .get ('text' )
121
+ data = content .get ('data' , {})
96
122
97
123
if not target_session_id :
98
124
return jsonify ({"error" : "session_id is required" }), 400
99
125
126
+ message = {
127
+ "type" : "action" ,
128
+ "text" : text ,
129
+ "data" : data
130
+ }
131
+
100
132
success = send_sse_message (target_session_id , message )
101
133
102
134
if success :
@@ -127,52 +159,29 @@ def broadcast_message():
127
159
def get_sse_status ():
128
160
"""Get the current status of SSE connections"""
129
161
162
+ connection_info = {}
163
+ logger .info (f"[SSE Status] Thread { threading .current_thread ().name } accessing sse_connections" )
164
+ logger .info (f"[SSE Status] sse_connections id: { id (sse_connections )} " )
165
+ logger .info (f"[SSE Status] sse_connections: { sse_connections } " )
166
+
167
+ with sse_connections_lock :
168
+ for session_id , connection in sse_connections .items ():
169
+ connection_info [session_id ] = {
170
+ 'connected_clients' : connection ['connected_clients' ],
171
+ 'queue_size' : connection ['queue' ].qsize ()
172
+ }
173
+
130
174
return jsonify ({
131
- "connected_sessions" : list (sse_connections .keys ()),
132
- "total_connections" : len (sse_connections )
175
+ "connected_sessions" : connection_info
133
176
})
134
177
135
- @sse_bp .route ('/trigger_notification' , methods = ['POST' ])
136
- def trigger_notification ():
137
- """
138
- Endpoint to trigger a notification to a specific session
139
- Expected JSON payload: {
140
- "type": "notification",
141
- "text": "Notification message",
142
- "data": {...} (optional)
143
- }
144
- """
145
- data = request .get_json ()
146
- if not data :
147
- return jsonify ({"error" : "No JSON data provided" }), 400
148
-
149
- session_id = data .get ('session_id' )
150
- text = data .get ('text' )
151
- data = data .get ('data' , {})
152
-
153
- # Validate required fields
154
- if not session_id :
155
- return jsonify ({"error" : "session_id is required" }), 400
156
- if not text :
157
- return jsonify ({"error" : "text is required" }), 400
158
-
159
- # Extract any additional data
160
- additional_data = data .get ('data' , {})
161
-
162
- # Send the notification
163
- success = send_notification (
164
- session_id = session_id ,
165
- text = text ,
166
- data = data
167
- )
168
-
169
- if success :
170
- return jsonify ({
171
- "status" : "Notification sent successfully" ,
172
- "session_id" : session_id ,
173
- })
178
+ @sse_bp .route ('/sse-connection-check' , methods = ['POST' ])
179
+ def sse_connection_check ():
180
+ session_id = request .json .get ('session_id' )
181
+ if session_id in sse_connections :
182
+ return jsonify ({"status" : "connected" })
174
183
else :
175
- return jsonify ({"error " : "Session not found or not connected " }), 404
184
+ return jsonify ({"status " : "disconnected " }), 404
176
185
177
186
# Utility functions
178
187
@@ -196,18 +205,18 @@ def send_sse_message(session_id: str, message: Dict[str, Any]) -> bool:
196
205
Returns:
197
206
bool: True if message was sent successfully, False otherwise
198
207
"""
199
-
200
- if session_id not in sse_connections :
201
- logger .warning (f"Attempted to send message to non-existent session: { session_id } " )
202
- return False
203
-
204
- try :
205
- sse_connections [session_id ].put (message )
206
- logger .info (f"Message sent to session { session_id } : { message .get ('type' , 'unknown' )} " )
207
- return True
208
- except Exception as e :
209
- logger .error (f"Failed to send message to session { session_id } : { e } " )
210
- return False
208
+ with sse_connections_lock :
209
+ if session_id not in sse_connections :
210
+ logger .warning (f"Attempted to send message to non-existent session: { session_id } " )
211
+ return False
212
+
213
+ try :
214
+ sse_connections [session_id ][ 'queue' ].put (message )
215
+ logger .info (f"Message sent to session { session_id } : { message .get ('type' , 'unknown' )} " )
216
+ return True
217
+ except Exception as e :
218
+ logger .error (f"Failed to send message to session { session_id } : { e } " )
219
+ return False
211
220
212
221
def broadcast_sse_message (message : Dict [str , Any ]) -> int :
213
222
"""
@@ -221,27 +230,13 @@ def broadcast_sse_message(message: Dict[str, Any]) -> int:
221
230
"""
222
231
sent_count = 0
223
232
224
- for session_id in list (sse_connections .keys ()):
233
+ # Get a snapshot of session IDs to avoid iteration issues
234
+ with sse_connections_lock :
235
+ session_ids = list (sse_connections .keys ())
236
+
237
+ for session_id in session_ids :
225
238
if send_sse_message (session_id , message ):
226
239
sent_count += 1
227
240
228
241
logger .info (f"Broadcasted message to { sent_count } clients: { message .get ('type' , 'unknown' )} " )
229
- return sent_count
230
-
231
- def send_notification (session_id : str , text : str , data : Dict [str , Any ]):
232
- """
233
- Send a notification message to a specific session
234
-
235
- Args:
236
- session_id: Target session ID
237
- text: Notification text
238
- data: Additional data to include in the notification
239
- """
240
- message = {
241
- "type" : "notification" ,
242
- "text" : text ,
243
- "data" : data
244
- }
245
-
246
- return send_sse_message (session_id , message )
247
-
242
+ return sent_count
0 commit comments