Skip to content

Commit 3b98198

Browse files
committed
use threads so that SSE works with shared memory
1 parent 56d4151 commit 3b98198

File tree

14 files changed

+209
-336
lines changed

14 files changed

+209
-336
lines changed

local_server.sh

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,6 @@
55
# export http_proxy=http://127.0.0.1:7890
66
# export https_proxy=http://127.0.0.1:7890
77

8-
env FLASK_APP=py-src/data_formulator/app.py FLASK_RUN_PORT=5000 FLASK_RUN_HOST=0.0.0.0 flask run
8+
#env FLASK_APP=py-src/data_formulator/app.py FLASK_RUN_PORT=5000 FLASK_RUN_HOST=0.0.0.0 flask run
9+
export FLASK_RUN_PORT=5000
10+
python -m py-src.data_formulator.app --port ${FLASK_RUN_PORT} --dev

py-src/data_formulator/agent_routes.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -425,8 +425,6 @@ def request_code_expl():
425425
if request.is_json:
426426
logger.info("# request data: ")
427427
content = request.get_json()
428-
token = content["token"]
429-
430428
client = get_client(content['model'])
431429

432430
# each table is a dict with {"name": xxx, "rows": [...]}

py-src/data_formulator/app.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,8 @@ def parse_args() -> argparse.Namespace:
256256
help="Whether to execute python in subprocess, it makes the app more secure (reducing the chance for the model to access the local machine), but increases the time of response")
257257
parser.add_argument("-d", "--disable-display-keys", action='store_true', default=False,
258258
help="Whether disable displaying keys in the frontend UI, recommended to turn on if you host the app not just for yourself.")
259+
parser.add_argument("--dev", action='store_true', default=False,
260+
help="Launch the app in development mode (prevents the app from opening the browser automatically)")
259261
return parser.parse_args()
260262

261263

@@ -268,11 +270,12 @@ def run_app():
268270
'disable_display_keys': args.disable_display_keys
269271
}
270272

271-
url = "http://localhost:{0}".format(args.port)
272-
threading.Timer(2, lambda: webbrowser.open(url, new=2)).start()
273+
if not args.dev:
274+
url = "http://localhost:{0}".format(args.port)
275+
threading.Timer(2, lambda: webbrowser.open(url, new=2)).start()
273276

274277
app.run(host='0.0.0.0', port=args.port, threaded=True)
275-
278+
276279
if __name__ == '__main__':
277280
#app.run(debug=True, host='127.0.0.1', port=5000)
278281
#use 0.0.0.0 for public

py-src/data_formulator/sse_routes.py

Lines changed: 94 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -8,34 +8,47 @@
88
from typing import Dict, Any, Generator
99
from flask import Blueprint, Response, session, request, jsonify, current_app
1010
import threading
11+
import uuid
12+
from pprint import pprint
1113

1214
# Get logger for this module
1315
logger = logging.getLogger(__name__)
1416

1517
# Create blueprint for SSE routes
1618
sse_bp = Blueprint('sse', __name__, url_prefix='/api/sse')
1719

18-
# Global dictionary to store SSE connections
19-
sse_connections: Dict[str, queue.Queue] = {}
20+
# Add a lock for thread safety
21+
sse_connections_lock = threading.RLock()
22+
sse_connections: Dict[str, Dict[str, Any]] = {}
2023

2124
@sse_bp.route('/connect')
2225
def sse_connect():
2326
"""
2427
SSE endpoint for clients to establish a connection and receive real-time messages
2528
"""
2629
session_id = session.get('session_id')
30+
connection_id = f"conn_{uuid.uuid4().hex[:8]}"
31+
2732
if not session_id:
2833
return Response("No session ID found", status=401)
2934

30-
logger.info(f"SSE connection established for session: {session_id}")
35+
logger.info(f"[SSE Connect] Thread {threading.current_thread().name} accessing sse_connections")
36+
logger.info(f"[SSE Connect] sse_connections id: {id(sse_connections)}")
3137

32-
# Create a queue for this session if it doesn't exist
33-
if session_id not in sse_connections:
34-
sse_connections[session_id] = queue.Queue()
38+
# Thread-safe connection creation
39+
with sse_connections_lock:
40+
if session_id not in sse_connections:
41+
sse_connections[session_id] = {
42+
'queue': queue.Queue(),
43+
'connected_clients': []
44+
}
45+
sse_connections[session_id]['connected_clients'].append(connection_id)
46+
logger.info(f"[SSE Connect] sse_connections after creation: {sse_connections}")
3547

48+
3649
def event_stream() -> Generator[str, None, None]:
3750
"""Generator function that yields SSE formatted messages"""
38-
client_queue = sse_connections[session_id]
51+
client_queue = sse_connections[session_id]['queue']
3952

4053
# Send initial connection confirmation
4154
yield format_sse_message({
@@ -45,27 +58,39 @@ def event_stream() -> Generator[str, None, None]:
4558
})
4659

4760
try:
61+
logger.info(f"Starting event stream for connection {connection_id} for session {session_id}")
62+
last_heartbeat_time = time.time()
4863
while True:
4964
try:
50-
# Wait for messages with a timeout to allow periodic heartbeat
51-
message = client_queue.get(timeout=30) # 30 second timeout
65+
message = client_queue.get(timeout=1) # 1 second timeout
5266
yield format_sse_message(message)
5367
except queue.Empty:
5468
# Send heartbeat to keep connection alive
55-
yield format_sse_message({
56-
"type": "notification",
57-
"text": "Heartbeat",
58-
"timestamp": time.time()
59-
})
69+
if time.time() - last_heartbeat_time > 30:
70+
last_heartbeat_time = time.time()
71+
yield format_sse_message({
72+
"type": "notification",
73+
"text": "Heartbeat",
74+
"timestamp": time.time()
75+
})
76+
else:
77+
# lightweight heartbeat to keep connection alive (no data)
78+
yield ": heartbeat\n\n"
6079
except Exception as e:
6180
logger.error(f"Error in SSE stream for session {session_id}: {e}")
6281
break
6382
finally:
64-
# Clean up connection when client disconnects
65-
if session_id in sse_connections:
66-
del sse_connections[session_id]
67-
logger.info(f"SSE connection closed for session: {session_id}")
68-
83+
# Safe cleanup with reference counting
84+
with sse_connections_lock:
85+
logger.info(f"[SSE Connect] cleaning up connection {connection_id} for session {session_id}")
86+
logger.info(f"[SSE Connect] sse_connections before cleanup: {sse_connections}")
87+
if session_id in sse_connections:
88+
sse_connections[session_id]['connected_clients'].remove(connection_id)
89+
if len(sse_connections[session_id]['connected_clients']) == 0:
90+
del sse_connections[session_id]
91+
logger.info(f"Last SSE connection ({connection_id}) closed for session {session_id}")
92+
logger.info(f"[SSE Connect] sse_connections after cleanup: {sse_connections}")
93+
6994
return Response(
7095
event_stream(),
7196
mimetype='text/event-stream',
@@ -83,20 +108,27 @@ def send_message_to_session():
83108
Endpoint to send a message to a specific session via SSE
84109
Expected JSON payload: {
85110
"type": "action" | "notification" | "heartbeat",
86-
"message": {...},
87-
"data": {...}
111+
"text": "....",
112+
"data": {...} (optional)
88113
}
89114
"""
90-
data = request.get_json()
91-
if not data:
115+
content = request.get_json()
116+
if not content:
92117
return jsonify({"error": "No JSON data provided"}), 400
93118

94-
target_session_id = data.get('session_id')
95-
message = data.get('message', {})
119+
target_session_id = content.get('session_id')
120+
text = content.get('text')
121+
data = content.get('data', {})
96122

97123
if not target_session_id:
98124
return jsonify({"error": "session_id is required"}), 400
99125

126+
message = {
127+
"type": "action",
128+
"text": text,
129+
"data": data
130+
}
131+
100132
success = send_sse_message(target_session_id, message)
101133

102134
if success:
@@ -127,52 +159,29 @@ def broadcast_message():
127159
def get_sse_status():
128160
"""Get the current status of SSE connections"""
129161

162+
connection_info = {}
163+
logger.info(f"[SSE Status] Thread {threading.current_thread().name} accessing sse_connections")
164+
logger.info(f"[SSE Status] sse_connections id: {id(sse_connections)}")
165+
logger.info(f"[SSE Status] sse_connections: {sse_connections}")
166+
167+
with sse_connections_lock:
168+
for session_id, connection in sse_connections.items():
169+
connection_info[session_id] = {
170+
'connected_clients': connection['connected_clients'],
171+
'queue_size': connection['queue'].qsize()
172+
}
173+
130174
return jsonify({
131-
"connected_sessions": list(sse_connections.keys()),
132-
"total_connections": len(sse_connections)
175+
"connected_sessions": connection_info
133176
})
134177

135-
@sse_bp.route('/trigger_notification', methods=['POST'])
136-
def trigger_notification():
137-
"""
138-
Endpoint to trigger a notification to a specific session
139-
Expected JSON payload: {
140-
"type": "notification",
141-
"text": "Notification message",
142-
"data": {...} (optional)
143-
}
144-
"""
145-
data = request.get_json()
146-
if not data:
147-
return jsonify({"error": "No JSON data provided"}), 400
148-
149-
session_id = data.get('session_id')
150-
text = data.get('text')
151-
data = data.get('data', {})
152-
153-
# Validate required fields
154-
if not session_id:
155-
return jsonify({"error": "session_id is required"}), 400
156-
if not text:
157-
return jsonify({"error": "text is required"}), 400
158-
159-
# Extract any additional data
160-
additional_data = data.get('data', {})
161-
162-
# Send the notification
163-
success = send_notification(
164-
session_id=session_id,
165-
text=text,
166-
data=data
167-
)
168-
169-
if success:
170-
return jsonify({
171-
"status": "Notification sent successfully",
172-
"session_id": session_id,
173-
})
178+
@sse_bp.route('/sse-connection-check', methods=['POST'])
179+
def sse_connection_check():
180+
session_id = request.json.get('session_id')
181+
if session_id in sse_connections:
182+
return jsonify({"status": "connected"})
174183
else:
175-
return jsonify({"error": "Session not found or not connected"}), 404
184+
return jsonify({"status": "disconnected"}), 404
176185

177186
# Utility functions
178187

@@ -196,18 +205,18 @@ def send_sse_message(session_id: str, message: Dict[str, Any]) -> bool:
196205
Returns:
197206
bool: True if message was sent successfully, False otherwise
198207
"""
199-
200-
if session_id not in sse_connections:
201-
logger.warning(f"Attempted to send message to non-existent session: {session_id}")
202-
return False
203-
204-
try:
205-
sse_connections[session_id].put(message)
206-
logger.info(f"Message sent to session {session_id}: {message.get('type', 'unknown')}")
207-
return True
208-
except Exception as e:
209-
logger.error(f"Failed to send message to session {session_id}: {e}")
210-
return False
208+
with sse_connections_lock:
209+
if session_id not in sse_connections:
210+
logger.warning(f"Attempted to send message to non-existent session: {session_id}")
211+
return False
212+
213+
try:
214+
sse_connections[session_id]['queue'].put(message)
215+
logger.info(f"Message sent to session {session_id}: {message.get('type', 'unknown')}")
216+
return True
217+
except Exception as e:
218+
logger.error(f"Failed to send message to session {session_id}: {e}")
219+
return False
211220

212221
def broadcast_sse_message(message: Dict[str, Any]) -> int:
213222
"""
@@ -221,27 +230,13 @@ def broadcast_sse_message(message: Dict[str, Any]) -> int:
221230
"""
222231
sent_count = 0
223232

224-
for session_id in list(sse_connections.keys()):
233+
# Get a snapshot of session IDs to avoid iteration issues
234+
with sse_connections_lock:
235+
session_ids = list(sse_connections.keys())
236+
237+
for session_id in session_ids:
225238
if send_sse_message(session_id, message):
226239
sent_count += 1
227240

228241
logger.info(f"Broadcasted message to {sent_count} clients: {message.get('type', 'unknown')}")
229-
return sent_count
230-
231-
def send_notification(session_id: str, text: str, data: Dict[str, Any]):
232-
"""
233-
Send a notification message to a specific session
234-
235-
Args:
236-
session_id: Target session ID
237-
text: Notification text
238-
data: Additional data to include in the notification
239-
"""
240-
message = {
241-
"type": "notification",
242-
"text": text,
243-
"data": data
244-
}
245-
246-
return send_sse_message(session_id, message)
247-
242+
return sent_count

py-src/data_formulator/tables_routes.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ def list_tables():
5959
if database_name in ['system', 'temp']:
6060
continue
6161

62-
6362
print(f"table_metadata: {table_metadata}")
6463

6564
try:

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ dependencies = [
3535
"python-dotenv",
3636
"vega_datasets",
3737
"litellm",
38-
"duckdb"
38+
"duckdb",
3939
]
4040

4141
[project.urls]

0 commit comments

Comments
 (0)