|
20 | 20 | from common.utils.utils_date import format_dates_in_messages |
21 | 21 | # Updated import for KernelArguments |
22 | 22 | from common.utils.utils_kernel import rai_success |
| 23 | +from common.utils.websocket_streaming import (websocket_streaming_endpoint, |
| 24 | + ws_manager) |
23 | 25 | # FastAPI imports |
24 | | -from fastapi import FastAPI, HTTPException, Query, Request |
| 26 | +from fastapi import FastAPI, HTTPException, Query, Request, WebSocket |
25 | 27 | from fastapi.middleware.cors import CORSMiddleware |
26 | 28 | from kernel_agents.agent_factory import AgentFactory |
27 | 29 | # Local imports |
|
81 | 83 | logging.info("Added health check middleware") |
82 | 84 |
|
83 | 85 |
|
| 86 | +# WebSocket streaming endpoint |
| 87 | +@app.websocket("/ws/streaming") |
| 88 | +async def websocket_endpoint(websocket: WebSocket): |
| 89 | + """WebSocket endpoint for real-time plan execution streaming""" |
| 90 | + await websocket_streaming_endpoint(websocket) |
| 91 | + |
| 92 | + |
84 | 93 | @app.post("/api/user_browser_language") |
85 | 94 | async def user_browser_language_endpoint(user_language: UserLanguage, request: Request): |
86 | 95 | """ |
@@ -587,12 +596,14 @@ async def approve_step_endpoint( |
587 | 596 |
|
588 | 597 | return {"status": "All steps approved"} |
589 | 598 |
|
| 599 | + |
590 | 600 | # Get plans is called in the initial side rendering of the frontend |
591 | 601 | @app.get("/api/plans") |
592 | 602 | async def get_plans( |
593 | 603 | request: Request, |
594 | 604 | session_id: Optional[str] = Query(None), |
595 | 605 | plan_id: Optional[str] = Query(None), |
| 606 | + team_id: Optional[str] = Query(None), |
596 | 607 | ): |
597 | 608 | """ |
598 | 609 | Retrieve plans for the current user. |
@@ -659,7 +670,7 @@ async def get_plans( |
659 | 670 | "UserIdNotFound", {"status_code": 400, "detail": "no user"} |
660 | 671 | ) |
661 | 672 | raise HTTPException(status_code=400, detail="no user") |
662 | | - |
| 673 | + |
663 | 674 | # Initialize agent team for this user session |
664 | 675 | await OrchestrationManager.get_current_orchestration(user_id=user_id) |
665 | 676 |
|
@@ -884,6 +895,82 @@ async def get_agent_tools(): |
884 | 895 | return [] |
885 | 896 |
|
886 | 897 |
|
| 898 | +@app.post("/api/test/streaming/{plan_id}") |
| 899 | +async def test_streaming_updates(plan_id: str): |
| 900 | + """ |
| 901 | + Test endpoint to simulate streaming updates for a plan. |
| 902 | + This is for testing the WebSocket streaming functionality. |
| 903 | + """ |
| 904 | + from common.utils.websocket_streaming import (send_agent_message, |
| 905 | + send_plan_update, |
| 906 | + send_step_update) |
| 907 | + |
| 908 | + try: |
| 909 | + # Simulate a series of streaming updates |
| 910 | + await send_agent_message( |
| 911 | + plan_id=plan_id, |
| 912 | + agent_name="Data Analyst", |
| 913 | + content="Starting analysis of the data...", |
| 914 | + message_type="thinking", |
| 915 | + ) |
| 916 | + |
| 917 | + await asyncio.sleep(1) |
| 918 | + |
| 919 | + await send_plan_update( |
| 920 | + plan_id=plan_id, |
| 921 | + step_id="step_1", |
| 922 | + agent_name="Data Analyst", |
| 923 | + content="Analyzing customer data patterns...", |
| 924 | + status="in_progress", |
| 925 | + message_type="action", |
| 926 | + ) |
| 927 | + |
| 928 | + await asyncio.sleep(2) |
| 929 | + |
| 930 | + await send_agent_message( |
| 931 | + plan_id=plan_id, |
| 932 | + agent_name="Data Analyst", |
| 933 | + content="Found 3 key insights in the customer data. Processing recommendations...", |
| 934 | + message_type="result", |
| 935 | + ) |
| 936 | + |
| 937 | + await asyncio.sleep(1) |
| 938 | + |
| 939 | + await send_step_update( |
| 940 | + plan_id=plan_id, |
| 941 | + step_id="step_1", |
| 942 | + status="completed", |
| 943 | + content="Data analysis completed successfully!", |
| 944 | + ) |
| 945 | + |
| 946 | + await send_agent_message( |
| 947 | + plan_id=plan_id, |
| 948 | + agent_name="Business Advisor", |
| 949 | + content="Reviewing the analysis results and preparing strategic recommendations...", |
| 950 | + message_type="thinking", |
| 951 | + ) |
| 952 | + |
| 953 | + await asyncio.sleep(2) |
| 954 | + |
| 955 | + await send_plan_update( |
| 956 | + plan_id=plan_id, |
| 957 | + step_id="step_2", |
| 958 | + agent_name="Business Advisor", |
| 959 | + content="Based on the data analysis, I recommend focusing on customer retention strategies for the identified high-value segments.", |
| 960 | + status="completed", |
| 961 | + message_type="result", |
| 962 | + ) |
| 963 | + |
| 964 | + return { |
| 965 | + "status": "success", |
| 966 | + "message": f"Test streaming updates sent for plan {plan_id}", |
| 967 | + } |
| 968 | + |
| 969 | + except Exception as e: |
| 970 | + logging.error(f"Error sending test streaming updates: {e}") |
| 971 | + raise HTTPException(status_code=500, detail=str(e)) |
| 972 | + |
| 973 | + |
887 | 974 | # Run the app |
888 | 975 | if __name__ == "__main__": |
889 | 976 | import uvicorn |
|
0 commit comments