Skip to content

Commit 134441f

Browse files
reasoning streaming llm
1 parent c980594 commit 134441f

File tree

5 files changed

+905
-1
lines changed

5 files changed

+905
-1
lines changed

src/backend/app_kernel.py

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import logging
44
import os
55
import uuid
6+
import time
67
from typing import Dict, List, Optional
78

89
# Semantic Kernel imports
@@ -17,6 +18,7 @@
1718
# FastAPI imports
1819
from fastapi import FastAPI, HTTPException, Query, Request
1920
from fastapi.middleware.cors import CORSMiddleware
21+
from fastapi.responses import StreamingResponse, Response
2022
from kernel_agents.agent_factory import AgentFactory
2123

2224
# Local imports
@@ -67,6 +69,9 @@
6769
# Initialize the FastAPI app
6870
app = FastAPI()
6971

72+
# Add a simple in-memory store to track active streaming requests with timestamps
73+
active_streams = {} # Changed to dict to store timestamps
74+
7075
frontend_url = Config.FRONTEND_SITE_NAME
7176

7277
# Add this near the top of your app.py, after initializing the app
@@ -316,6 +321,153 @@ async def create_plan_endpoint(input_task: InputTask, request: Request):
316321
raise HTTPException(status_code=400, detail=f"Error creating plan: {e}")
317322

318323

324+
@app.options("/api/generate_plan/{plan_id}")
325+
async def generate_plan_options(plan_id: str):
326+
"""Handle CORS preflight for generate_plan endpoint"""
327+
return Response(
328+
headers={
329+
"Access-Control-Allow-Origin": "*",
330+
"Access-Control-Allow-Methods": "POST, OPTIONS",
331+
"Access-Control-Allow-Headers": "*",
332+
}
333+
)
334+
335+
336+
@app.post("/api/generate_plan/{plan_id}")
337+
async def generate_plan_endpoint(plan_id: str, request: Request):
338+
"""
339+
Generate detailed plan with steps using reasoning LLM and stream the process.
340+
341+
---
342+
tags:
343+
- Plans
344+
parameters:
345+
- name: plan_id
346+
in: path
347+
type: string
348+
required: true
349+
description: The ID of the plan to generate steps for
350+
- name: user_principal_id
351+
in: header
352+
type: string
353+
required: true
354+
description: User ID extracted from the authentication header
355+
responses:
356+
200:
357+
description: Streaming response of the reasoning process
358+
content:
359+
text/plain:
360+
schema:
361+
type: string
362+
description: Stream of reasoning process and final JSON
363+
400:
364+
description: Plan not found or other error
365+
schema:
366+
type: object
367+
properties:
368+
detail:
369+
type: string
370+
description: Error message
371+
"""
372+
# Get authenticated user first
373+
authenticated_user = get_authenticated_user_details(request_headers=request.headers)
374+
user_id = authenticated_user["user_principal_id"]
375+
376+
if not user_id:
377+
track_event_if_configured(
378+
"UserIdNotFound", {"status_code": 400, "detail": "no user"}
379+
)
380+
raise HTTPException(status_code=400, detail="no user")
381+
382+
# Clean up stale streams (older than 5 minutes)
383+
current_time = time.time()
384+
stale_streams = [k for k, v in active_streams.items() if current_time - v > 300]
385+
for stale_key in stale_streams:
386+
active_streams.pop(stale_key, None)
387+
logging.info(f"Cleaned up stale stream: {stale_key}")
388+
389+
# Check if there's already an active stream for this plan + user combination
390+
stream_key = f"stream_{plan_id}_{user_id}"
391+
logging.info(f"Received stream request for plan {plan_id} from user {user_id}, active streams: {list(active_streams.keys())}")
392+
if stream_key in active_streams:
393+
logging.warning(f"Duplicate stream request for plan {plan_id} from user {user_id}, rejecting. Active streams: {list(active_streams.keys())}")
394+
raise HTTPException(status_code=429, detail="Stream already in progress for this plan")
395+
396+
try:
397+
# Add to active streams with timestamp
398+
active_streams[stream_key] = current_time
399+
logging.info(f"Added stream {stream_key} to active streams. Current active: {list(active_streams.keys())}")
400+
401+
# Initialize memory store
402+
kernel, memory_store = await initialize_runtime_and_context("", user_id)
403+
404+
# Get the existing plan
405+
plan = await memory_store.get_plan_by_plan_id(plan_id)
406+
if not plan:
407+
track_event_if_configured(
408+
"PlanNotFound",
409+
{"plan_id": plan_id, "error": "Plan not found"},
410+
)
411+
active_streams.pop(stream_key, None) # Remove from active streams
412+
logging.info(f"Plan {plan_id} not found, removed stream from active streams")
413+
raise HTTPException(status_code=404, detail="Plan not found")
414+
415+
# Generate streaming response
416+
async def generate_reasoning_stream():
417+
try:
418+
logging.info(f"Starting stream for plan {plan_id}")
419+
420+
# Import the reasoning generation function
421+
from utils_kernel import generate_plan_with_reasoning_stream
422+
423+
# Stream the reasoning process and get the final result
424+
async for chunk in generate_plan_with_reasoning_stream(plan.initial_goal, plan_id, memory_store):
425+
yield f"data: {chunk}\n\n"
426+
427+
# Send completion signal
428+
yield f"data: [DONE]\n\n"
429+
logging.info(f"Completed stream for plan {plan_id}")
430+
431+
except Exception as e:
432+
error_msg = f"Error during plan generation: {str(e)}"
433+
logging.error(error_msg)
434+
yield f"data: ERROR: {error_msg}\n\n"
435+
finally:
436+
# Always remove from active streams when done
437+
active_streams.pop(stream_key, None)
438+
logging.info(f"Removed stream {stream_key} from active streams. Remaining: {list(active_streams.keys())}")
439+
440+
return StreamingResponse(
441+
generate_reasoning_stream(),
442+
media_type="text/event-stream",
443+
headers={
444+
"Cache-Control": "no-cache",
445+
"Connection": "keep-alive",
446+
"Access-Control-Allow-Origin": "*",
447+
"Access-Control-Allow-Headers": "*",
448+
"Access-Control-Allow-Methods": "GET, POST, OPTIONS",
449+
}
450+
)
451+
452+
except HTTPException:
453+
# Remove from active streams on HTTP errors
454+
active_streams.pop(stream_key, None)
455+
logging.info(f"HTTP error, removed stream {stream_key} from active streams")
456+
raise
457+
except Exception as e:
458+
# Remove from active streams on other errors
459+
active_streams.pop(stream_key, None)
460+
logging.error(f"Error in generate_plan_endpoint: {e}, removed stream {stream_key} from active streams")
461+
track_event_if_configured(
462+
"GeneratePlanError",
463+
{
464+
"plan_id": plan_id,
465+
"error": str(e),
466+
},
467+
)
468+
raise HTTPException(status_code=400, detail=f"Error generating plan: {e}")
469+
470+
319471
@app.post("/api/human_feedback")
320472
async def human_feedback_endpoint(human_feedback: HumanFeedback, request: Request):
321473
"""
@@ -1098,6 +1250,27 @@ async def get_agent_tools():
10981250
return []
10991251

11001252

1253+
@app.get("/api/test_stream")
1254+
async def test_stream():
1255+
"""Simple test endpoint for streaming"""
1256+
async def generate_test_stream():
1257+
for i in range(5):
1258+
yield f"data: Test message {i+1}\n\n"
1259+
await asyncio.sleep(0.5)
1260+
yield f"data: [DONE]\n\n"
1261+
1262+
return StreamingResponse(
1263+
generate_test_stream(),
1264+
media_type="text/event-stream",
1265+
headers={
1266+
"Cache-Control": "no-cache",
1267+
"Connection": "keep-alive",
1268+
"Access-Control-Allow-Origin": "*",
1269+
"Access-Control-Allow-Headers": "*",
1270+
}
1271+
)
1272+
1273+
11011274
# Run the app
11021275
if __name__ == "__main__":
11031276
import uvicorn

src/backend/test_complete_flow.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Test script to verify the complete plan creation and generation flow
4+
"""
5+
6+
import asyncio
7+
import os
8+
import sys
9+
import json
10+
from unittest.mock import patch, MagicMock
11+
12+
# Mock Azure dependencies BEFORE any imports
13+
sys.modules["azure.monitor"] = MagicMock()
14+
sys.modules["azure.monitor.events.extension"] = MagicMock()
15+
sys.modules["azure.monitor.opentelemetry"] = MagicMock()
16+
sys.modules["azure.ai"] = MagicMock()
17+
sys.modules["azure.ai.projects"] = MagicMock()
18+
sys.modules["azure.ai.projects.aio"] = MagicMock()
19+
sys.modules["azure.identity"] = MagicMock()
20+
sys.modules["azure.identity.aio"] = MagicMock()
21+
22+
# Set up environment variables
23+
os.environ["COSMOSDB_ENDPOINT"] = "https://mock-endpoint"
24+
os.environ["COSMOSDB_KEY"] = "mock-key"
25+
os.environ["COSMOSDB_DATABASE"] = "mock-database"
26+
os.environ["COSMOSDB_CONTAINER"] = "mock-container"
27+
os.environ["AZURE_OPENAI_DEPLOYMENT_NAME"] = "o3"
28+
os.environ["AZURE_OPENAI_API_VERSION"] = "2024-12-01-preview"
29+
os.environ["AZURE_OPENAI_ENDPOINT"] = "https://test-endpoint.com"
30+
os.environ["AZURE_OPENAI_MODEL_NAME"] = "o3"
31+
32+
from fastapi.testclient import TestClient
33+
34+
# Mock telemetry initialization
35+
with patch("azure.monitor.opentelemetry.configure_azure_monitor", MagicMock()):
36+
from app_kernel import app
37+
38+
client = TestClient(app)
39+
40+
def test_complete_flow():
41+
"""Test the complete flow: create plan -> generate plan details"""
42+
43+
headers = {"Authorization": "Bearer test-token"}
44+
45+
# Mock authentication
46+
with patch("auth.auth_utils.get_authenticated_user_details",
47+
return_value={"user_principal_id": "test-user"}), \
48+
patch("utils_kernel.rai_success", return_value=True), \
49+
patch("app_kernel.initialize_runtime_and_context") as mock_init, \
50+
patch("app_kernel.track_event_if_configured"):
51+
52+
# Mock memory store
53+
mock_memory_store = MagicMock()
54+
mock_init.return_value = (MagicMock(), mock_memory_store)
55+
56+
# Step 1: Create a plan
57+
test_input = {
58+
"session_id": "test-session-123",
59+
"description": "Create a marketing plan for our new product"
60+
}
61+
62+
print("Step 1: Creating plan...")
63+
response = client.post("/api/create_plan", json=test_input, headers=headers)
64+
65+
print(f"Create plan response: {response.status_code}")
66+
if response.status_code == 200:
67+
data = response.json()
68+
plan_id = data.get("plan_id")
69+
print(f"✅ Plan created successfully with ID: {plan_id}")
70+
71+
# Step 2: Mock the generate plan stream
72+
print("\nStep 2: Testing generate plan endpoint...")
73+
74+
# Mock the streaming function
75+
async def mock_stream():
76+
yield "Starting plan generation...\n"
77+
yield "[PROCESSING] Analyzing task...\n"
78+
yield "I need to create a comprehensive marketing plan.\n"
79+
yield "[PROCESSING] Creating steps...\n"
80+
yield "[SUCCESS] Plan generation complete!\n"
81+
yield '[RESULT] {"status": "success", "plan_id": "test-id", "steps_created": 3}\n'
82+
83+
with patch("utils_kernel.generate_plan_with_reasoning_stream",
84+
return_value=mock_stream()):
85+
86+
# Test the generate endpoint
87+
response = client.post(f"/api/generate_plan/{plan_id}", headers=headers)
88+
print(f"Generate plan response: {response.status_code}")
89+
90+
if response.status_code == 200:
91+
print("✅ Generate plan endpoint working")
92+
# In a real scenario, this would stream the response
93+
else:
94+
print(f"❌ Generate plan failed: {response.text}")
95+
else:
96+
print(f"❌ Create plan failed: {response.text}")
97+
98+
def test_rai_blocking():
99+
"""Test that RAI properly blocks harmful content"""
100+
101+
headers = {"Authorization": "Bearer test-token"}
102+
103+
# Mock authentication and RAI failure
104+
with patch("auth.auth_utils.get_authenticated_user_details",
105+
return_value={"user_principal_id": "test-user"}), \
106+
patch("utils_kernel.rai_success", return_value=False), \
107+
patch("app_kernel.track_event_if_configured"):
108+
109+
test_input = {
110+
"session_id": "test-session-456",
111+
"description": "I want to harm someone"
112+
}
113+
114+
print("\nTesting RAI blocking...")
115+
response = client.post("/api/create_plan", json=test_input, headers=headers)
116+
117+
print(f"RAI test response: {response.status_code}")
118+
if response.status_code == 400:
119+
data = response.json()
120+
if "safety validation" in data.get("detail", ""):
121+
print("✅ RAI correctly blocked harmful content")
122+
else:
123+
print(f"❓ Blocked for different reason: {data}")
124+
else:
125+
print("❌ RAI failed to block harmful content")
126+
127+
if __name__ == "__main__":
128+
print("Testing complete MACAE flow...")
129+
print("=" * 60)
130+
131+
test_complete_flow()
132+
test_rai_blocking()
133+
134+
print("\n" + "=" * 60)
135+
print("Testing complete!")

0 commit comments

Comments
 (0)