Skip to content

Commit 5231653

Browse files
committed
brewing
1 parent 7936025 commit 5231653

File tree

8 files changed

+602
-143
lines changed

8 files changed

+602
-143
lines changed

py-src/data_formulator/app.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,10 @@
3636
# blueprints
3737
from data_formulator.tables_routes import tables_bp
3838
from data_formulator.agent_routes import agent_bp
39+
from data_formulator.sse_routes import sse_bp
3940

41+
import queue
42+
from typing import Dict, Any
4043

4144
app = Flask(__name__, static_url_path='', static_folder=os.path.join(APP_ROOT, "dist"))
4245
app.secret_key = secrets.token_hex(16) # Generate a random secret key for sessions
@@ -65,6 +68,7 @@ def default(self, obj):
6568
# register blueprints
6669
app.register_blueprint(tables_bp)
6770
app.register_blueprint(agent_bp)
71+
app.register_blueprint(sse_bp)
6872

6973
print(APP_ROOT)
7074

py-src/data_formulator/sse_routes.py

Lines changed: 247 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,247 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
import json
5+
import queue
6+
import time
7+
import logging
8+
from typing import Dict, Any, Generator
9+
from flask import Blueprint, Response, session, request, jsonify, current_app
10+
import threading
11+
12+
# Get logger for this module
13+
logger = logging.getLogger(__name__)
14+
15+
# Create blueprint for SSE routes
16+
sse_bp = Blueprint('sse', __name__, url_prefix='/api/sse')
17+
18+
# Global dictionary to store SSE connections
19+
sse_connections: Dict[str, queue.Queue] = {}
20+
21+
@sse_bp.route('/connect')
22+
def sse_connect():
23+
"""
24+
SSE endpoint for clients to establish a connection and receive real-time messages
25+
"""
26+
session_id = session.get('session_id')
27+
if not session_id:
28+
return Response("No session ID found", status=401)
29+
30+
logger.info(f"SSE connection established for session: {session_id}")
31+
32+
# Create a queue for this session if it doesn't exist
33+
if session_id not in sse_connections:
34+
sse_connections[session_id] = queue.Queue()
35+
36+
def event_stream() -> Generator[str, None, None]:
37+
"""Generator function that yields SSE formatted messages"""
38+
client_queue = sse_connections[session_id]
39+
40+
# Send initial connection confirmation
41+
yield format_sse_message({
42+
"type": "notification",
43+
"text": "SSE connection established successfully",
44+
"timestamp": time.time()
45+
})
46+
47+
try:
48+
while True:
49+
try:
50+
# Wait for messages with a timeout to allow periodic heartbeat
51+
message = client_queue.get(timeout=30) # 30 second timeout
52+
yield format_sse_message(message)
53+
except queue.Empty:
54+
# Send heartbeat to keep connection alive
55+
yield format_sse_message({
56+
"type": "notification",
57+
"text": "Heartbeat",
58+
"timestamp": time.time()
59+
})
60+
except Exception as e:
61+
logger.error(f"Error in SSE stream for session {session_id}: {e}")
62+
break
63+
finally:
64+
# Clean up connection when client disconnects
65+
if session_id in sse_connections:
66+
del sse_connections[session_id]
67+
logger.info(f"SSE connection closed for session: {session_id}")
68+
69+
return Response(
70+
event_stream(),
71+
mimetype='text/event-stream',
72+
headers={
73+
'Cache-Control': 'no-cache',
74+
'Connection': 'keep-alive',
75+
'Access-Control-Allow-Origin': '*',
76+
'Access-Control-Allow-Headers': 'Cache-Control'
77+
}
78+
)
79+
80+
@sse_bp.route('/send-message', methods=['POST'])
81+
def send_message_to_session():
82+
"""
83+
Endpoint to send a message to a specific session via SSE
84+
Expected JSON payload: {
85+
"type": "action" | "notification" | "heartbeat",
86+
"message": {...},
87+
"data": {...}
88+
}
89+
"""
90+
data = request.get_json()
91+
if not data:
92+
return jsonify({"error": "No JSON data provided"}), 400
93+
94+
target_session_id = data.get('session_id')
95+
message = data.get('message', {})
96+
97+
if not target_session_id:
98+
return jsonify({"error": "session_id is required"}), 400
99+
100+
success = send_sse_message(target_session_id, message)
101+
102+
if success:
103+
return jsonify({"status": "Message sent successfully"})
104+
else:
105+
return jsonify({"error": "Session not found or not connected"}), 404
106+
107+
@sse_bp.route('/broadcast', methods=['POST'])
108+
def broadcast_message():
109+
"""
110+
Endpoint to broadcast a message to all connected SSE clients
111+
Expected JSON payload: {
112+
"message": {...}
113+
}
114+
"""
115+
data = request.get_json()
116+
if not data:
117+
return jsonify({"error": "No JSON data provided"}), 400
118+
119+
message = data.get('message', {})
120+
sent_count = broadcast_sse_message(message)
121+
122+
return jsonify({
123+
"status": f"Message broadcasted to {sent_count} connected clients"
124+
})
125+
126+
@sse_bp.route('/status')
127+
def get_sse_status():
128+
"""Get the current status of SSE connections"""
129+
130+
return jsonify({
131+
"connected_sessions": list(sse_connections.keys()),
132+
"total_connections": len(sse_connections)
133+
})
134+
135+
@sse_bp.route('/trigger_notification', methods=['POST'])
136+
def trigger_notification():
137+
"""
138+
Endpoint to trigger a notification to a specific session
139+
Expected JSON payload: {
140+
"type": "notification",
141+
"text": "Notification message",
142+
"data": {...} (optional)
143+
}
144+
"""
145+
data = request.get_json()
146+
if not data:
147+
return jsonify({"error": "No JSON data provided"}), 400
148+
149+
session_id = data.get('session_id')
150+
text = data.get('text')
151+
data = data.get('data', {})
152+
153+
# Validate required fields
154+
if not session_id:
155+
return jsonify({"error": "session_id is required"}), 400
156+
if not text:
157+
return jsonify({"error": "text is required"}), 400
158+
159+
# Extract any additional data
160+
additional_data = data.get('data', {})
161+
162+
# Send the notification
163+
success = send_notification(
164+
session_id=session_id,
165+
text=text,
166+
data=data
167+
)
168+
169+
if success:
170+
return jsonify({
171+
"status": "Notification sent successfully",
172+
"session_id": session_id,
173+
})
174+
else:
175+
return jsonify({"error": "Session not found or not connected"}), 404
176+
177+
# Utility functions
178+
179+
def format_sse_message(data: Dict[str, Any]) -> str:
180+
"""Format a message for SSE transmission"""
181+
# Add timestamp if not present
182+
if 'timestamp' not in data:
183+
data['timestamp'] = time.time()
184+
185+
json_data = json.dumps(data)
186+
return f"data: {json_data}\n\n"
187+
188+
def send_sse_message(session_id: str, message: Dict[str, Any]) -> bool:
189+
"""
190+
Send a message to a specific session via SSE
191+
192+
Args:
193+
session_id: Target session ID
194+
message: Message data to send
195+
196+
Returns:
197+
bool: True if message was sent successfully, False otherwise
198+
"""
199+
200+
if session_id not in sse_connections:
201+
logger.warning(f"Attempted to send message to non-existent session: {session_id}")
202+
return False
203+
204+
try:
205+
sse_connections[session_id].put(message)
206+
logger.info(f"Message sent to session {session_id}: {message.get('type', 'unknown')}")
207+
return True
208+
except Exception as e:
209+
logger.error(f"Failed to send message to session {session_id}: {e}")
210+
return False
211+
212+
def broadcast_sse_message(message: Dict[str, Any]) -> int:
213+
"""
214+
Broadcast a message to all connected SSE clients
215+
216+
Args:
217+
message: Message data to broadcast
218+
219+
Returns:
220+
int: Number of clients the message was sent to
221+
"""
222+
sent_count = 0
223+
224+
for session_id in list(sse_connections.keys()):
225+
if send_sse_message(session_id, message):
226+
sent_count += 1
227+
228+
logger.info(f"Broadcasted message to {sent_count} clients: {message.get('type', 'unknown')}")
229+
return sent_count
230+
231+
def send_notification(session_id: str, text: str, data: Dict[str, Any]):
232+
"""
233+
Send a notification message to a specific session
234+
235+
Args:
236+
session_id: Target session ID
237+
text: Notification text
238+
data: Additional data to include in the notification
239+
"""
240+
message = {
241+
"type": "notification",
242+
"text": text,
243+
"data": data
244+
}
245+
246+
return send_sse_message(session_id, message)
247+

src/app/dfSlice.tsx

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,13 @@ export const generateFreshChart = (tableRef: string, chartType?: string) : Chart
2626
}
2727
}
2828

29+
export interface SSEMessage {
30+
type: "notification" | "action";
31+
text: string;
32+
data?: Record<string, any>;
33+
timestamp: number;
34+
}
35+
2936
export interface ModelConfig {
3037
id: string; // unique identifier for the model / client combination
3138
endpoint: string;
@@ -73,6 +80,8 @@ export interface DataFormulatorState {
7380
}
7481

7582
dataLoaderConnectParams: Record<string, Record<string, string>>; // {table_name: {param_name: param_value}}
83+
84+
lastSSEMessage: SSEMessage | undefined; // Store the last received SSE message
7685
}
7786

7887
// Define the initial state using that type
@@ -112,7 +121,9 @@ const initialState: DataFormulatorState = {
112121
defaultChartHeight: 300,
113122
},
114123

115-
dataLoaderConnectParams: {}
124+
dataLoaderConnectParams: {},
125+
126+
lastSSEMessage: undefined,
116127
}
117128

118129
let getUnrefedDerivedTableIds = (state: DataFormulatorState) => {
@@ -301,7 +312,7 @@ export const dataFormulatorSlice = createSlice({
301312

302313
state.conceptShelfItems = savedState.conceptShelfItems || [];
303314

304-
state.messages = [];
315+
state.messages = [];
305316
state.displayedMessageIdx = -1;
306317

307318
state.focusedTableId = savedState.focusedTableId || undefined;
@@ -755,6 +766,29 @@ export const dataFormulatorSlice = createSlice({
755766
deleteDataLoaderConnectParams: (state, action: PayloadAction<string>) => {
756767
let dataLoaderType = action.payload;
757768
delete state.dataLoaderConnectParams[dataLoaderType];
769+
},
770+
handleSSEMessage: (state, action: PayloadAction<SSEMessage>) => {
771+
state.lastSSEMessage = action.payload;
772+
if (action.payload.type == "notification") {
773+
console.log('SSE message stored in Redux:', action.payload);
774+
state.messages = [...state.messages, {
775+
component: "server",
776+
type: "info",
777+
timestamp: action.payload.timestamp,
778+
value: action.payload.text || "Unknown message"
779+
}];
780+
} else if (action.payload.type == "action") {
781+
console.log('SSE message stored in Redux:', action.payload);
782+
state.messages = [...state.messages, {
783+
component: "server",
784+
type: "info",
785+
timestamp: action.payload.timestamp,
786+
value: action.payload.text || "Unknown message"
787+
}];
788+
}
789+
},
790+
clearMessages: (state) => {
791+
state.messages = [];
758792
}
759793
},
760794
extraReducers: (builder) => {

0 commit comments

Comments
 (0)