6
6
from fastapi .responses import StreamingResponse , HTMLResponse
7
7
from openai import AsyncOpenAI
8
8
from openai .resources .beta .threads .runs .runs import AsyncAssistantStreamManager
9
- from openai .types .beta .assistant_stream_event import ThreadMessageCreated , ThreadMessageDelta , ThreadRunCompleted
9
+ from openai .types .beta .assistant_stream_event import ThreadMessageCreated , ThreadMessageDelta , ThreadRunCompleted , ThreadRunRequiresAction
10
10
from fastapi .responses import StreamingResponse
11
11
from fastapi import APIRouter , Depends , Form , HTTPException
12
12
from pydantic import BaseModel
13
+ import json
13
14
15
+ # Import our get_weather method
16
+ from utils .weather import get_weather
14
17
15
18
logger : logging .Logger = logging .getLogger ("uvicorn.error" )
16
19
logger .setLevel (logging .DEBUG )
@@ -30,20 +33,25 @@ class ToolCallOutputs(BaseModel):
30
33
runId : str
31
34
32
35
async def post_tool_outputs (client : AsyncOpenAI , data : dict , thread_id : str ):
33
-
36
+ """
37
+ data is expected to be something like
38
+
39
+ {
40
+ "tool_outputs": {"location": "City", "temperature": 70, "conditions": "Sunny"},
41
+ "runId": "some-run-id",
42
+ }
43
+ """
34
44
try :
35
- # Parse the JSON body into the ToolCallOutputs model
36
- tool_call_outputs = ToolCallOutputs (** data )
37
-
38
- # Submit tool outputs stream
39
- stream = await client .beta .threads .runs .submit_tool_outputs_stream (
40
- thread_id ,
41
- tool_call_outputs .runId ,
42
- {"tool_outputs" : tool_call_outputs .tool_outputs }
45
+ outputs_list = [data ["tool_outputs" ]]
46
+
47
+ stream_manager = client .beta .threads .runs .submit_tool_outputs_stream (
48
+ thread_id = thread_id ,
49
+ run_id = data ["runId" ],
50
+ tool_outputs = outputs_list ,
43
51
)
44
52
45
- # Return the stream as a response
46
- return stream . to_readable_stream ()
53
+ return stream_manager
54
+
47
55
except Exception as e :
48
56
logger .error (f"Error submitting tool outputs: { e } " )
49
57
raise HTTPException (status_code = 500 , detail = str (e ))
@@ -120,6 +128,34 @@ async def event_generator():
120
128
f"data: { event .data .delta .content [0 ].text .value } \n \n "
121
129
)
122
130
131
+ if isinstance (event , ThreadRunRequiresAction ):
132
+ required_action = event .data .required_action
133
+ if required_action and required_action .submit_tool_outputs :
134
+ for tool_call in required_action .submit_tool_outputs .tool_calls :
135
+ yield (
136
+ f"event: toolCallCreated\n "
137
+ f"data: { templates .get_template ('components/assistant-step.html' ).render (
138
+ step_type = 'toolCall' , stream_name = f'toolDelta{ step_counter } '
139
+ ).replace ('\n ' , '' )} \n \n "
140
+ )
141
+
142
+ if tool_call .type == "function" and tool_call .function .name == "get_weather" :
143
+ try :
144
+ args = json .loads (tool_call .function .arguments )
145
+ location = args .get ("location" , "Unknown" )
146
+ except Exception as err :
147
+ logger .error (f"Failed to parse function arguments: { err } " )
148
+ location = "Unknown"
149
+
150
+ weather_output = get_weather (location )
151
+ logger .info (f"Weather output: { weather_output } " )
152
+
153
+ data_for_tool = {
154
+ "tool_outputs" : weather_output ,
155
+ "runId" : event .data .id ,
156
+ }
157
+ await post_tool_outputs (client , data_for_tool , thread_id )
158
+
123
159
if isinstance (event , ThreadRunCompleted ):
124
160
yield "event: endStream\n data: DONE\n \n "
125
161
0 commit comments