|
3 | 3 | import logging |
4 | 4 | import os |
5 | 5 | import uuid |
| 6 | +import time |
6 | 7 | from typing import Dict, List, Optional |
7 | 8 |
|
8 | 9 | # Semantic Kernel imports |
|
17 | 18 | # FastAPI imports |
18 | 19 | from fastapi import FastAPI, HTTPException, Query, Request |
19 | 20 | from fastapi.middleware.cors import CORSMiddleware |
| 21 | +from fastapi.responses import StreamingResponse, Response |
20 | 22 | from kernel_agents.agent_factory import AgentFactory |
21 | 23 |
|
22 | 24 | # Local imports |
|
67 | 69 | # Initialize the FastAPI app |
68 | 70 | app = FastAPI() |
69 | 71 |
|
| 72 | +# Simple in-memory store tracking active streaming requests and their start times |
| 73 | +active_streams = {}  # maps "stream_{plan_id}_{user_id}" -> start timestamp from time.time() |
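|  | +# NOTE: this store lives in process memory only; with multiple workers or replicas, |
|  | +# duplicate-stream detection would need a shared store (for example Redis) instead. |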
| 74 | + |
70 | 75 | frontend_url = Config.FRONTEND_SITE_NAME |
71 | 76 |
|
72 | 77 | # Add this near the top of your app.py, after initializing the app |
@@ -316,6 +321,153 @@ async def create_plan_endpoint(input_task: InputTask, request: Request): |
316 | 321 | raise HTTPException(status_code=400, detail=f"Error creating plan: {e}") |
317 | 322 |
|
318 | 323 |
|
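|  | +# Explicit CORS preflight handler for the streaming route (app-level CORSMiddleware, if configured, may already cover this). |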
| 324 | +@app.options("/api/generate_plan/{plan_id}") |
| 325 | +async def generate_plan_options(plan_id: str): |
| 326 | + """Handle CORS preflight for generate_plan endpoint""" |
| 327 | + return Response( |
| 328 | + headers={ |
| 329 | + "Access-Control-Allow-Origin": "*", |
| 330 | + "Access-Control-Allow-Methods": "POST, OPTIONS", |
| 331 | + "Access-Control-Allow-Headers": "*", |
| 332 | + } |
| 333 | + ) |
| 334 | + |
| 335 | + |
| 336 | +@app.post("/api/generate_plan/{plan_id}") |
| 337 | +async def generate_plan_endpoint(plan_id: str, request: Request): |
| 338 | + """ |
| 339 | +    Generate a detailed plan with steps using a reasoning LLM and stream the process. |
| 340 | + |
| 341 | + --- |
| 342 | + tags: |
| 343 | + - Plans |
| 344 | + parameters: |
| 345 | + - name: plan_id |
| 346 | + in: path |
| 347 | + type: string |
| 348 | + required: true |
| 349 | + description: The ID of the plan to generate steps for |
| 350 | + - name: user_principal_id |
| 351 | + in: header |
| 352 | + type: string |
| 353 | + required: true |
| 354 | + description: User ID extracted from the authentication header |
| 355 | + responses: |
| 356 | + 200: |
| 357 | + description: Streaming response of the reasoning process |
| 358 | + content: |
| 359 | +          text/event-stream: |
| 360 | + schema: |
| 361 | + type: string |
| 362 | + description: Stream of reasoning process and final JSON |
| 363 | +      400: |
| 364 | +        description: Missing user ID or other error while generating the plan |
| 365 | +        schema: |
| 366 | +          type: object |
| 367 | +          properties: |
| 368 | +            detail: |
| 369 | +              type: string |
| 370 | +              description: Error message |
|  | +      404: |
|  | +        description: Plan not found |
|  | +      429: |
|  | +        description: A generation stream is already in progress for this plan and user |
| 371 | + """ |
| 372 | + # Get authenticated user first |
| 373 | + authenticated_user = get_authenticated_user_details(request_headers=request.headers) |
| 374 | + user_id = authenticated_user["user_principal_id"] |
| 375 | + |
| 376 | + if not user_id: |
| 377 | + track_event_if_configured( |
| 378 | + "UserIdNotFound", {"status_code": 400, "detail": "no user"} |
| 379 | + ) |
| 380 | + raise HTTPException(status_code=400, detail="no user") |
| 381 | + |
| 382 | + # Clean up stale streams (older than 5 minutes) |
| 383 | + current_time = time.time() |
| 384 | + stale_streams = [k for k, v in active_streams.items() if current_time - v > 300] |
| 385 | + for stale_key in stale_streams: |
| 386 | + active_streams.pop(stale_key, None) |
| 387 | + logging.info(f"Cleaned up stale stream: {stale_key}") |
| 388 | + |
| 389 | + # Check if there's already an active stream for this plan + user combination |
| 390 | + stream_key = f"stream_{plan_id}_{user_id}" |
| 391 | + logging.info(f"Received stream request for plan {plan_id} from user {user_id}, active streams: {list(active_streams.keys())}") |
| 392 | + if stream_key in active_streams: |
| 393 | + logging.warning(f"Duplicate stream request for plan {plan_id} from user {user_id}, rejecting. Active streams: {list(active_streams.keys())}") |
| 394 | + raise HTTPException(status_code=429, detail="Stream already in progress for this plan") |
| 395 | + |
| 396 | + try: |
| 397 | + # Add to active streams with timestamp |
| 398 | + active_streams[stream_key] = current_time |
| 399 | + logging.info(f"Added stream {stream_key} to active streams. Current active: {list(active_streams.keys())}") |
| 400 | + |
| 401 | +        # Initialize the kernel and memory store |
| 402 | + kernel, memory_store = await initialize_runtime_and_context("", user_id) |
| 403 | + |
| 404 | + # Get the existing plan |
| 405 | + plan = await memory_store.get_plan_by_plan_id(plan_id) |
| 406 | + if not plan: |
| 407 | + track_event_if_configured( |
| 408 | + "PlanNotFound", |
| 409 | + {"plan_id": plan_id, "error": "Plan not found"}, |
| 410 | + ) |
| 411 | + active_streams.pop(stream_key, None) # Remove from active streams |
| 412 | + logging.info(f"Plan {plan_id} not found, removed stream from active streams") |
| 413 | + raise HTTPException(status_code=404, detail="Plan not found") |
| 414 | + |
| 415 | + # Generate streaming response |
| 416 | + async def generate_reasoning_stream(): |
| 417 | + try: |
| 418 | + logging.info(f"Starting stream for plan {plan_id}") |
| 419 | + |
| 420 | + # Import the reasoning generation function |
| 421 | + from utils_kernel import generate_plan_with_reasoning_stream |
| 422 | + |
| 423 | +                # Stream the reasoning process; the final plan JSON arrives as the last chunk(s) |
| 424 | + async for chunk in generate_plan_with_reasoning_stream(plan.initial_goal, plan_id, memory_store): |
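|  | +                    # Server-Sent Events framing: each chunk becomes a "data:" line followed by a blank line |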
| 425 | + yield f"data: {chunk}\n\n" |
| 426 | + |
| 427 | + # Send completion signal |
| 428 | +                yield "data: [DONE]\n\n" |
| 429 | + logging.info(f"Completed stream for plan {plan_id}") |
| 430 | + |
| 431 | + except Exception as e: |
| 432 | + error_msg = f"Error during plan generation: {str(e)}" |
| 433 | + logging.error(error_msg) |
| 434 | + yield f"data: ERROR: {error_msg}\n\n" |
| 435 | + finally: |
| 436 | + # Always remove from active streams when done |
| 437 | + active_streams.pop(stream_key, None) |
| 438 | + logging.info(f"Removed stream {stream_key} from active streams. Remaining: {list(active_streams.keys())}") |
| 439 | + |
| 440 | + return StreamingResponse( |
| 441 | + generate_reasoning_stream(), |
| 442 | + media_type="text/event-stream", |
| 443 | + headers={ |
| 444 | + "Cache-Control": "no-cache", |
| 445 | + "Connection": "keep-alive", |
| 446 | + "Access-Control-Allow-Origin": "*", |
| 447 | + "Access-Control-Allow-Headers": "*", |
| 448 | + "Access-Control-Allow-Methods": "GET, POST, OPTIONS", |
| 449 | + } |
| 450 | + ) |
| 451 | + |
| 452 | + except HTTPException: |
| 453 | + # Remove from active streams on HTTP errors |
| 454 | + active_streams.pop(stream_key, None) |
| 455 | + logging.info(f"HTTP error, removed stream {stream_key} from active streams") |
| 456 | + raise |
| 457 | + except Exception as e: |
| 458 | + # Remove from active streams on other errors |
| 459 | + active_streams.pop(stream_key, None) |
| 460 | + logging.error(f"Error in generate_plan_endpoint: {e}, removed stream {stream_key} from active streams") |
| 461 | + track_event_if_configured( |
| 462 | + "GeneratePlanError", |
| 463 | + { |
| 464 | + "plan_id": plan_id, |
| 465 | + "error": str(e), |
| 466 | + }, |
| 467 | + ) |
| 468 | + raise HTTPException(status_code=400, detail=f"Error generating plan: {e}") |
| 469 | + |
| 470 | + |
319 | 471 | @app.post("/api/human_feedback") |
320 | 472 | async def human_feedback_endpoint(human_feedback: HumanFeedback, request: Request): |
321 | 473 | """ |
@@ -1098,6 +1250,27 @@ async def get_agent_tools(): |
1098 | 1250 | return [] |
1099 | 1251 |
|
1100 | 1252 |
|
| 1253 | +@app.get("/api/test_stream") |
| 1254 | +async def test_stream(): |
| 1255 | + """Simple test endpoint for streaming""" |
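|  | +    # NOTE: asyncio.sleep below assumes "import asyncio" is available at module level; |
|  | +    # add it to the imports at the top of app.py if it is not already there. |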
| 1256 | + async def generate_test_stream(): |
| 1257 | + for i in range(5): |
| 1258 | + yield f"data: Test message {i+1}\n\n" |
| 1259 | + await asyncio.sleep(0.5) |
| 1260 | +        yield "data: [DONE]\n\n" |
| 1261 | + |
| 1262 | + return StreamingResponse( |
| 1263 | + generate_test_stream(), |
| 1264 | + media_type="text/event-stream", |
| 1265 | + headers={ |
| 1266 | + "Cache-Control": "no-cache", |
| 1267 | + "Connection": "keep-alive", |
| 1268 | + "Access-Control-Allow-Origin": "*", |
| 1269 | + "Access-Control-Allow-Headers": "*", |
| 1270 | + } |
| 1271 | + ) |
| 1272 | + |
| 1273 | + |
1101 | 1274 | # Run the app |
1102 | 1275 | if __name__ == "__main__": |
1103 | 1276 | import uvicorn |
|
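For reference, here is a minimal sketch of how a client might consume the new streaming endpoint. It is not part of the diff: the base URL, the `user_principal_id` header name (taken from the endpoint docstring; the deployed auth setup may use a different header), and the use of the third-party `httpx` library are all assumptions.

```python
# Hypothetical client sketch -- not part of app.py.
import asyncio

import httpx


async def consume_plan_stream(plan_id: str, user_id: str) -> None:
    # Assumed local base URL and header name; adjust to the real deployment and auth headers.
    url = f"http://localhost:8000/api/generate_plan/{plan_id}"
    headers = {"user_principal_id": user_id}
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("POST", url, headers=headers) as response:
            response.raise_for_status()
            async for line in response.aiter_lines():
                if not line.startswith("data: "):
                    continue  # skip the blank separator lines between SSE frames
                payload = line[len("data: "):]
                if payload == "[DONE]":
                    break  # completion signal emitted by the endpoint
                print(payload)


if __name__ == "__main__":
    asyncio.run(consume_plan_stream("example-plan-id", "example-user-id"))
```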