|
| 1 | +# Copyright (c) Microsoft Corporation. |
| 2 | +# Licensed under the MIT License. |
| 3 | + |
| 4 | +import json |
| 5 | +import queue |
| 6 | +import time |
| 7 | +import logging |
| 8 | +from typing import Dict, Any, Generator |
| 9 | +from flask import Blueprint, Response, session, request, jsonify, current_app |
| 10 | +import threading |
| 11 | + |
| 12 | +# Get logger for this module |
| 13 | +logger = logging.getLogger(__name__) |
| 14 | + |
| 15 | +# Create blueprint for SSE routes |
| 16 | +sse_bp = Blueprint('sse', __name__, url_prefix='/api/sse') |
| 17 | + |
| 18 | +# Global dictionary to store SSE connections |
| 19 | +sse_connections: Dict[str, queue.Queue] = {} |
| 20 | + |
| 21 | +@sse_bp.route('/connect') |
| 22 | +def sse_connect(): |
| 23 | + """ |
| 24 | + SSE endpoint for clients to establish a connection and receive real-time messages |
| 25 | + """ |
| 26 | + session_id = session.get('session_id') |
| 27 | + if not session_id: |
| 28 | + return Response("No session ID found", status=401) |
| 29 | + |
| 30 | + logger.info(f"SSE connection established for session: {session_id}") |
| 31 | + |
| 32 | + # Create a queue for this session if it doesn't exist |
| 33 | + if session_id not in sse_connections: |
| 34 | + sse_connections[session_id] = queue.Queue() |
| 35 | + |
| 36 | + def event_stream() -> Generator[str, None, None]: |
| 37 | + """Generator function that yields SSE formatted messages""" |
| 38 | + client_queue = sse_connections[session_id] |
| 39 | + |
| 40 | + # Send initial connection confirmation |
| 41 | + yield format_sse_message({ |
| 42 | + "type": "notification", |
| 43 | + "text": "SSE connection established successfully", |
| 44 | + "timestamp": time.time() |
| 45 | + }) |
| 46 | + |
| 47 | + try: |
| 48 | + while True: |
| 49 | + try: |
| 50 | + # Wait for messages with a timeout to allow periodic heartbeat |
| 51 | + message = client_queue.get(timeout=30) # 30 second timeout |
| 52 | + yield format_sse_message(message) |
| 53 | + except queue.Empty: |
| 54 | + # Send heartbeat to keep connection alive |
| 55 | + yield format_sse_message({ |
| 56 | + "type": "notification", |
| 57 | + "text": "Heartbeat", |
| 58 | + "timestamp": time.time() |
| 59 | + }) |
| 60 | + except Exception as e: |
| 61 | + logger.error(f"Error in SSE stream for session {session_id}: {e}") |
| 62 | + break |
| 63 | + finally: |
| 64 | + # Clean up connection when client disconnects |
| 65 | + if session_id in sse_connections: |
| 66 | + del sse_connections[session_id] |
| 67 | + logger.info(f"SSE connection closed for session: {session_id}") |
| 68 | + |
| 69 | + return Response( |
| 70 | + event_stream(), |
| 71 | + mimetype='text/event-stream', |
| 72 | + headers={ |
| 73 | + 'Cache-Control': 'no-cache', |
| 74 | + 'Connection': 'keep-alive', |
| 75 | + 'Access-Control-Allow-Origin': '*', |
| 76 | + 'Access-Control-Allow-Headers': 'Cache-Control' |
| 77 | + } |
| 78 | + ) |
| 79 | + |
| 80 | +@sse_bp.route('/send-message', methods=['POST']) |
| 81 | +def send_message_to_session(): |
| 82 | + """ |
| 83 | + Endpoint to send a message to a specific session via SSE |
| 84 | + Expected JSON payload: { |
| 85 | + "type": "action" | "notification" | "heartbeat", |
| 86 | + "message": {...}, |
| 87 | + "data": {...} |
| 88 | + } |
| 89 | + """ |
| 90 | + data = request.get_json() |
| 91 | + if not data: |
| 92 | + return jsonify({"error": "No JSON data provided"}), 400 |
| 93 | + |
| 94 | + target_session_id = data.get('session_id') |
| 95 | + message = data.get('message', {}) |
| 96 | + |
| 97 | + if not target_session_id: |
| 98 | + return jsonify({"error": "session_id is required"}), 400 |
| 99 | + |
| 100 | + success = send_sse_message(target_session_id, message) |
| 101 | + |
| 102 | + if success: |
| 103 | + return jsonify({"status": "Message sent successfully"}) |
| 104 | + else: |
| 105 | + return jsonify({"error": "Session not found or not connected"}), 404 |
| 106 | + |
| 107 | +@sse_bp.route('/broadcast', methods=['POST']) |
| 108 | +def broadcast_message(): |
| 109 | + """ |
| 110 | + Endpoint to broadcast a message to all connected SSE clients |
| 111 | + Expected JSON payload: { |
| 112 | + "message": {...} |
| 113 | + } |
| 114 | + """ |
| 115 | + data = request.get_json() |
| 116 | + if not data: |
| 117 | + return jsonify({"error": "No JSON data provided"}), 400 |
| 118 | + |
| 119 | + message = data.get('message', {}) |
| 120 | + sent_count = broadcast_sse_message(message) |
| 121 | + |
| 122 | + return jsonify({ |
| 123 | + "status": f"Message broadcasted to {sent_count} connected clients" |
| 124 | + }) |
| 125 | + |
| 126 | +@sse_bp.route('/status') |
| 127 | +def get_sse_status(): |
| 128 | + """Get the current status of SSE connections""" |
| 129 | + |
| 130 | + return jsonify({ |
| 131 | + "connected_sessions": list(sse_connections.keys()), |
| 132 | + "total_connections": len(sse_connections) |
| 133 | + }) |
| 134 | + |
| 135 | +@sse_bp.route('/trigger_notification', methods=['POST']) |
| 136 | +def trigger_notification(): |
| 137 | + """ |
| 138 | + Endpoint to trigger a notification to a specific session |
| 139 | + Expected JSON payload: { |
| 140 | + "type": "notification", |
| 141 | + "text": "Notification message", |
| 142 | + "data": {...} (optional) |
| 143 | + } |
| 144 | + """ |
| 145 | + data = request.get_json() |
| 146 | + if not data: |
| 147 | + return jsonify({"error": "No JSON data provided"}), 400 |
| 148 | + |
| 149 | + session_id = data.get('session_id') |
| 150 | + text = data.get('text') |
| 151 | + data = data.get('data', {}) |
| 152 | + |
| 153 | + # Validate required fields |
| 154 | + if not session_id: |
| 155 | + return jsonify({"error": "session_id is required"}), 400 |
| 156 | + if not text: |
| 157 | + return jsonify({"error": "text is required"}), 400 |
| 158 | + |
| 159 | + # Extract any additional data |
| 160 | + additional_data = data.get('data', {}) |
| 161 | + |
| 162 | + # Send the notification |
| 163 | + success = send_notification( |
| 164 | + session_id=session_id, |
| 165 | + text=text, |
| 166 | + data=data |
| 167 | + ) |
| 168 | + |
| 169 | + if success: |
| 170 | + return jsonify({ |
| 171 | + "status": "Notification sent successfully", |
| 172 | + "session_id": session_id, |
| 173 | + }) |
| 174 | + else: |
| 175 | + return jsonify({"error": "Session not found or not connected"}), 404 |
| 176 | + |
| 177 | +# Utility functions |
| 178 | + |
| 179 | +def format_sse_message(data: Dict[str, Any]) -> str: |
| 180 | + """Format a message for SSE transmission""" |
| 181 | + # Add timestamp if not present |
| 182 | + if 'timestamp' not in data: |
| 183 | + data['timestamp'] = time.time() |
| 184 | + |
| 185 | + json_data = json.dumps(data) |
| 186 | + return f"data: {json_data}\n\n" |
| 187 | + |
| 188 | +def send_sse_message(session_id: str, message: Dict[str, Any]) -> bool: |
| 189 | + """ |
| 190 | + Send a message to a specific session via SSE |
| 191 | + |
| 192 | + Args: |
| 193 | + session_id: Target session ID |
| 194 | + message: Message data to send |
| 195 | + |
| 196 | + Returns: |
| 197 | + bool: True if message was sent successfully, False otherwise |
| 198 | + """ |
| 199 | + |
| 200 | + if session_id not in sse_connections: |
| 201 | + logger.warning(f"Attempted to send message to non-existent session: {session_id}") |
| 202 | + return False |
| 203 | + |
| 204 | + try: |
| 205 | + sse_connections[session_id].put(message) |
| 206 | + logger.info(f"Message sent to session {session_id}: {message.get('type', 'unknown')}") |
| 207 | + return True |
| 208 | + except Exception as e: |
| 209 | + logger.error(f"Failed to send message to session {session_id}: {e}") |
| 210 | + return False |
| 211 | + |
| 212 | +def broadcast_sse_message(message: Dict[str, Any]) -> int: |
| 213 | + """ |
| 214 | + Broadcast a message to all connected SSE clients |
| 215 | + |
| 216 | + Args: |
| 217 | + message: Message data to broadcast |
| 218 | + |
| 219 | + Returns: |
| 220 | + int: Number of clients the message was sent to |
| 221 | + """ |
| 222 | + sent_count = 0 |
| 223 | + |
| 224 | + for session_id in list(sse_connections.keys()): |
| 225 | + if send_sse_message(session_id, message): |
| 226 | + sent_count += 1 |
| 227 | + |
| 228 | + logger.info(f"Broadcasted message to {sent_count} clients: {message.get('type', 'unknown')}") |
| 229 | + return sent_count |
| 230 | + |
| 231 | +def send_notification(session_id: str, text: str, data: Dict[str, Any]): |
| 232 | + """ |
| 233 | + Send a notification message to a specific session |
| 234 | + |
| 235 | + Args: |
| 236 | + session_id: Target session ID |
| 237 | + text: Notification text |
| 238 | + data: Additional data to include in the notification |
| 239 | + """ |
| 240 | + message = { |
| 241 | + "type": "notification", |
| 242 | + "text": text, |
| 243 | + "data": data |
| 244 | + } |
| 245 | + |
| 246 | + return send_sse_message(session_id, message) |
| 247 | + |
0 commit comments