1
1
import logging
2
2
import time
3
3
from datetime import datetime
4
- from typing import Any , AsyncGenerator , Dict , List , Optional , Union , cast
5
- from dataclasses import dataclass
4
+ from typing import AsyncGenerator , Optional , Union
6
5
from fastapi .templating import Jinja2Templates
7
6
from fastapi import APIRouter , Form , Depends , Request
8
7
from fastapi .responses import StreamingResponse , HTMLResponse
9
8
from openai import AsyncOpenAI
10
- from openai .resources . beta . threads . runs . runs import AsyncAssistantStreamManager
9
+ from openai .lib . streaming . _assistants import AsyncAssistantStreamManager , AsyncAssistantEventHandler
11
10
from openai .types .beta .assistant_stream_event import (
12
11
ThreadMessageCreated , ThreadMessageDelta , ThreadRunCompleted ,
13
12
ThreadRunRequiresAction , ThreadRunStepCreated , ThreadRunStepDelta
14
13
)
15
14
from openai .types .beta import AssistantStreamEvent
16
- from openai .lib . streaming . _assistants import AsyncAssistantEventHandler
17
- from openai .types .beta .threads .run_submit_tool_outputs_params import ToolOutput
18
- from openai .types .beta .threads .run import RequiredAction , Run
15
+ from openai .types . beta . threads . run import RequiredAction
16
+ from openai .types .beta .threads .message_content_delta import MessageContentDelta
17
+ from openai .types .beta .threads .text_delta_block import TextDeltaBlock
19
18
from fastapi .responses import StreamingResponse
20
- from fastapi import APIRouter , Depends , Form , HTTPException
21
- from pydantic import BaseModel
19
+ from fastapi import APIRouter , Depends , Form
22
20
23
21
import json
24
22
25
- from utils .custom_functions import get_weather
23
+ from utils .custom_functions import get_weather , post_tool_outputs
26
24
from utils .sse import sse_format
25
+ from utils .streaming import AssistantStreamMetadata
27
26
28
- @dataclass
29
- class AssistantStreamMetadata :
30
- """Metadata for assistant stream events that require further processing."""
31
- type : str # Always "metadata"
32
- required_action : Optional [RequiredAction ]
33
- step_id : str
34
- run_requires_action_event : Optional [ThreadRunRequiresAction ]
35
-
36
- @classmethod
37
- def create (cls ,
38
- required_action : Optional [RequiredAction ],
39
- step_id : str ,
40
- run_requires_action_event : Optional [ThreadRunRequiresAction ]
41
- ) -> "AssistantStreamMetadata" :
42
- """Factory method to create a metadata instance with validation."""
43
- return cls (
44
- type = "metadata" ,
45
- required_action = required_action ,
46
- step_id = step_id ,
47
- run_requires_action_event = run_requires_action_event
48
- )
49
-
50
- def requires_tool_call (self ) -> bool :
51
- """Check if this metadata indicates a required tool call."""
52
- return (self .required_action is not None
53
- and self .required_action .submit_tool_outputs is not None
54
- and bool (self .required_action .submit_tool_outputs .tool_calls ))
55
-
56
- def get_run_id (self ) -> str :
57
- """Get the run ID from the requires action event, or empty string if none."""
58
- return self .run_requires_action_event .data .id if self .run_requires_action_event else ""
59
27
60
28
logger : logging .Logger = logging .getLogger ("uvicorn.error" )
61
29
logger .setLevel (logging .DEBUG )
@@ -69,43 +37,6 @@ def get_run_id(self) -> str:
69
37
# Jinja2 templates
70
38
templates = Jinja2Templates (directory = "templates" )
71
39
72
- # Utility function for submitting tool outputs to the assistant
73
- class ToolCallOutputs (BaseModel ):
74
- tool_outputs : Dict [str , Any ]
75
- runId : str
76
-
77
- async def post_tool_outputs (client : AsyncOpenAI , data : Dict [str , Any ], thread_id : str ) -> AsyncAssistantStreamManager :
78
- """
79
- data is expected to be something like
80
- {
81
- "tool_outputs": {
82
- "output": [{"location": "City", "temperature": 70, "conditions": "Sunny"}],
83
- "tool_call_id": "call_123"
84
- },
85
- "runId": "some-run-id",
86
- }
87
- """
88
- try :
89
- outputs_list = [
90
- ToolOutput (
91
- output = str (data ["tool_outputs" ]["output" ]),
92
- tool_call_id = data ["tool_outputs" ]["tool_call_id" ]
93
- )
94
- ]
95
-
96
-
97
- stream_manager = client .beta .threads .runs .submit_tool_outputs_stream (
98
- thread_id = thread_id ,
99
- run_id = data ["runId" ],
100
- tool_outputs = outputs_list ,
101
- )
102
-
103
- return stream_manager
104
-
105
- except Exception as e :
106
- logger .error (f"Error submitting tool outputs: { e } " )
107
- raise HTTPException (status_code = 500 , detail = str (e ))
108
-
109
40
110
41
# Route to submit a new user message to a thread and mount a component that
111
42
# will start an assistant run stream
@@ -170,8 +101,13 @@ async def handle_assistant_stream(
170
101
async with stream_manager as event_handler :
171
102
event : AssistantStreamEvent
172
103
async for event in event_handler :
104
+ # Debug logging for all events
105
+ logger .debug (f"SSE Event Type: { type (event ).__name__ } " )
106
+ logger .debug (f"SSE Event Data: { event .data } " )
107
+
173
108
if isinstance (event , ThreadMessageCreated ):
174
109
step_id = event .data .id
110
+ logger .debug (f"Message Created - Step ID: { step_id } " )
175
111
176
112
yield sse_format (
177
113
"messageCreated" ,
@@ -183,15 +119,16 @@ async def handle_assistant_stream(
183
119
time .sleep (0.25 ) # Give the client time to render the message
184
120
185
121
if isinstance (event , ThreadMessageDelta ) and event .data .delta .content :
186
- content = event .data .delta .content [0 ]
187
- if hasattr (content , 'text' ) and content .text and content .text .value :
122
+ content : MessageContentDelta = event .data .delta .content [0 ]
123
+ if isinstance (content , TextDeltaBlock ) and content .text and content .text .value :
188
124
yield sse_format (
189
125
f"textDelta{ step_id } " ,
190
126
content .text .value
191
127
)
192
128
193
129
if isinstance (event , ThreadRunStepCreated ) and event .data .type == "tool_calls" :
194
130
step_id = event .data .id
131
+ logger .debug (f"Tool Call Created - Step ID: { step_id } " )
195
132
196
133
yield sse_format (
197
134
f"toolCallCreated" ,
@@ -207,6 +144,7 @@ async def handle_assistant_stream(
207
144
if tool_calls :
208
145
# TODO: Support parallel function calling
209
146
tool_call = tool_calls [0 ]
147
+ logger .debug (f"Tool Call Delta - Type: { tool_call .type } " )
210
148
211
149
# Handle function tool call
212
150
if tool_call .type == "function" :
@@ -224,27 +162,33 @@ async def handle_assistant_stream(
224
162
# Handle code interpreter tool calls
225
163
elif tool_call .type == "code_interpreter" :
226
164
if tool_call .code_interpreter and tool_call .code_interpreter .input :
165
+ logger .debug (f"Code Interpreter Input: { tool_call .code_interpreter .input } " )
227
166
yield sse_format (
228
167
f"toolDelta{ step_id } " ,
229
168
str (tool_call .code_interpreter .input )
230
169
)
231
170
if tool_call .code_interpreter and tool_call .code_interpreter .outputs :
232
171
for output in tool_call .code_interpreter .outputs :
172
+ logger .debug (f"Code Interpreter Output Type: { output .type } " )
233
173
if output .type == "logs" and output .logs :
234
174
yield sse_format (
235
175
f"toolDelta{ step_id } " ,
236
176
str (output .logs )
237
177
)
238
178
elif output .type == "image" and output .image and output .image .file_id :
179
+ logger .debug (f"Image Output - File ID: { output .image .file_id } " )
180
+ # Create the image HTML on the backend
181
+ image_html = f'<img src="/assistants/{ assistant_id } /files/{ output .image .file_id } /content" class="code-interpreter-image">'
239
182
yield sse_format (
240
- f"toolDelta { step_id } " ,
241
- str ( output . image . file_id )
183
+ f"imageOutput " ,
184
+ image_html
242
185
)
243
186
244
187
# If the assistant run requires an action (a tool call), break and handle it
245
188
if isinstance (event , ThreadRunRequiresAction ):
246
189
required_action = event .data .required_action
247
190
run_requires_action_event = event
191
+ logger .debug ("Run Requires Action Event" )
248
192
if required_action and required_action .submit_tool_outputs :
249
193
break
250
194
@@ -284,8 +228,8 @@ async def event_generator() -> AsyncGenerator[str, None]:
284
228
location = args .get ("location" , "Unknown" )
285
229
dates_raw = args .get ("dates" , [datetime .today ().strftime ("%Y-%m-%d" )])
286
230
dates = [
287
- datetime .strptime (d , "%Y-%m-%d" ) if isinstance ( d , str ) else d
288
- for d in dates_raw
231
+ datetime .strptime (d , "%Y-%m-%d" )
232
+ for d in dates_raw if isinstance ( d , str )
289
233
]
290
234
except Exception as err :
291
235
logger .error (f"Failed to parse function arguments: { err } " )
0 commit comments