Skip to content

Commit 5732d1c

Browse files
New metadata class
1 parent 2b8c028 commit 5732d1c

File tree

1 file changed

+51
-23
lines changed

1 file changed

+51
-23
lines changed

routers/chat.py

Lines changed: 51 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import time
33
from datetime import datetime
44
from typing import Any, AsyncGenerator, Dict, List, Optional, Union, cast
5+
from dataclasses import dataclass
56
from fastapi.templating import Jinja2Templates
67
from fastapi import APIRouter, Form, Depends, Request
78
from fastapi.responses import StreamingResponse, HTMLResponse
@@ -14,7 +15,7 @@
1415
from openai.types.beta import AssistantStreamEvent
1516
from openai.lib.streaming._assistants import AsyncAssistantEventHandler
1617
from openai.types.beta.threads.run_submit_tool_outputs_params import ToolOutput
17-
from openai.types.beta.threads.run import RequiredAction
18+
from openai.types.beta.threads.run import RequiredAction, Run
1819
from fastapi.responses import StreamingResponse
1920
from fastapi import APIRouter, Depends, Form, HTTPException
2021
from pydantic import BaseModel
@@ -24,6 +25,38 @@
2425
from utils.custom_functions import get_weather
2526
from utils.sse import sse_format
2627

28+
@dataclass
29+
class AssistantStreamMetadata:
30+
"""Metadata for assistant stream events that require further processing."""
31+
type: str # Always "metadata"
32+
required_action: Optional[RequiredAction]
33+
step_id: str
34+
run_requires_action_event: Optional[ThreadRunRequiresAction]
35+
36+
@classmethod
37+
def create(cls,
38+
required_action: Optional[RequiredAction],
39+
step_id: str,
40+
run_requires_action_event: Optional[ThreadRunRequiresAction]
41+
) -> "AssistantStreamMetadata":
42+
"""Factory method to create a metadata instance with validation."""
43+
return cls(
44+
type="metadata",
45+
required_action=required_action,
46+
step_id=step_id,
47+
run_requires_action_event=run_requires_action_event
48+
)
49+
50+
def requires_tool_call(self) -> bool:
51+
"""Check if this metadata indicates a required tool call."""
52+
return (self.required_action is not None
53+
and self.required_action.submit_tool_outputs is not None
54+
and bool(self.required_action.submit_tool_outputs.tool_calls))
55+
56+
def get_run_id(self) -> str:
57+
"""Get the run ID from the requires action event, or empty string if none."""
58+
return self.run_requires_action_event.data.id if self.run_requires_action_event else ""
59+
2760
logger: logging.Logger = logging.getLogger("uvicorn.error")
2861
logger.setLevel(logging.DEBUG)
2962

@@ -125,10 +158,10 @@ async def handle_assistant_stream(
125158
logger: logging.Logger,
126159
stream_manager: AsyncAssistantStreamManager,
127160
step_id: str = ""
128-
) -> AsyncGenerator[Union[Dict[str, Any], str], None]:
161+
) -> AsyncGenerator[Union[AssistantStreamMetadata, str], None]:
129162
"""
130163
Async generator to yield SSE events.
131-
We yield a final 'metadata' dictionary event once we're done.
164+
We yield a final AssistantStreamMetadata instance once we're done.
132165
"""
133166
required_action: Optional[RequiredAction] = None
134167
run_requires_action_event: Optional[ThreadRunRequiresAction] = None
@@ -218,18 +251,17 @@ async def handle_assistant_stream(
218251
if isinstance(event, ThreadRunCompleted):
219252
yield sse_format("endStream", "DONE")
220253

221-
# At the end (or break) of this async generator, we yield a final "metadata" object
222-
yield {
223-
"type": "metadata",
224-
"required_action": required_action,
225-
"step_id": step_id,
226-
"run_requires_action_event": run_requires_action_event
227-
}
254+
# At the end (or break) of this async generator, yield a final AssistantStreamMetadata
255+
yield AssistantStreamMetadata.create(
256+
required_action=required_action,
257+
step_id=step_id,
258+
run_requires_action_event=run_requires_action_event
259+
)
228260

229261
async def event_generator() -> AsyncGenerator[str, None]:
230262
"""
231263
Main generator for SSE events. We call our helper function to handle the assistant
232-
stream, and if the assistant requests a tool call, we do it and then re-run the stream.
264+
stream, and if the assistant requests a tool call, we do it and then re-stream the stream.
233265
"""
234266
step_id: str = ""
235267
stream_manager: AsyncAssistantStreamManager[AsyncAssistantEventHandler] = client.beta.threads.runs.stream(
@@ -239,17 +271,13 @@ async def event_generator() -> AsyncGenerator[str, None]:
239271
)
240272

241273
while True:
242-
event: dict[str, Any] | str
274+
event: Union[AssistantStreamMetadata, str]
243275
async for event in handle_assistant_stream(templates, logger, stream_manager, step_id):
244-
# Detect the special "metadata" event at the end of the generator
245-
if isinstance(event, dict) and event.get("type") == "metadata":
246-
required_action = cast(Optional[RequiredAction], event.get("required_action"))
247-
step_id = cast(str, event.get("step_id", ""))
248-
run_requires_action_event = cast(Optional[ThreadRunRequiresAction], event.get("run_requires_action_event"))
249-
250-
# If the assistant still needs a tool call, do it and then re-stream
251-
if required_action and required_action.submit_tool_outputs and required_action.submit_tool_outputs.tool_calls:
252-
for tool_call in required_action.submit_tool_outputs.tool_calls:
276+
if isinstance(event, AssistantStreamMetadata):
277+
# Use the helper methods from our class
278+
step_id = event.step_id
279+
if event.requires_tool_call():
280+
for tool_call in event.required_action.submit_tool_outputs.tool_calls: # type: ignore
253281
if tool_call.type == "function":
254282
try:
255283
args = json.loads(tool_call.function.arguments)
@@ -283,7 +311,7 @@ async def event_generator() -> AsyncGenerator[str, None]:
283311
"output": str(weather_output),
284312
"tool_call_id": tool_call.id
285313
},
286-
"runId": run_requires_action_event.data.id if run_requires_action_event else "",
314+
"runId": event.get_run_id(),
287315
}
288316
except Exception as err:
289317
error_message = f"Failed to get weather output: {err}"
@@ -294,7 +322,7 @@ async def event_generator() -> AsyncGenerator[str, None]:
294322
"output": error_message,
295323
"tool_call_id": tool_call.id
296324
},
297-
"runId": run_requires_action_event.data.id if run_requires_action_event else "",
325+
"runId": event.get_run_id(),
298326
}
299327

300328
# Afterwards, create a fresh stream_manager for the next iteration

0 commit comments

Comments
 (0)