Skip to content

Commit 6a825d2

Browse files
Type linting
1 parent c05b639 commit 6a825d2

File tree

6 files changed

+122
-65
lines changed

6 files changed

+122
-65
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,5 @@
44
.venv
55
__pycache__
66
*.pyc
7-
.specstory
7+
.specstory
8+
.mypy_cache

pyproject.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,8 @@ dependencies = [
1212
"python-multipart>=0.0.20",
1313
"uvicorn>=0.32.0",
1414
]
15+
16+
[dependency-groups]
17+
dev = [
18+
"mypy>=1.15.0",
19+
]

routers/chat.py

Lines changed: 54 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
ThreadMessageCreated, ThreadMessageDelta, ThreadRunCompleted,
1212
ThreadRunRequiresAction, ThreadRunStepCreated, ThreadRunStepDelta
1313
)
14+
from openai.types.beta import AssistantStreamEvent
15+
from openai.lib.streaming._assistants import AsyncAssistantEventHandler
1416
from openai.types.beta.threads.run_submit_tool_outputs_params import ToolOutput
1517
from openai.types.beta.threads.run import RequiredAction
1618
from fastapi.responses import StreamingResponse
@@ -122,45 +124,45 @@ async def handle_assistant_stream(
122124
templates: Jinja2Templates,
123125
logger: logging.Logger,
124126
stream_manager: AsyncAssistantStreamManager,
125-
start_step_count: int = 0
127+
step_id: int = 0
126128
) -> AsyncGenerator:
127129
"""
128130
Async generator to yield SSE events.
129131
We yield a final 'metadata' dictionary event once we're done.
130132
"""
131-
step_counter: int = start_step_count
132133
required_action: RequiredAction | None = None
133134
run_requires_action_event: ThreadRunRequiresAction | None = None
134135

136+
event_handler: AsyncAssistantEventHandler
135137
async with stream_manager as event_handler:
138+
event: AssistantStreamEvent
136139
async for event in event_handler:
137-
logger.info(f"{event}")
138-
139140
if isinstance(event, ThreadMessageCreated):
140-
step_counter += 1
141+
step_id = event.data.id
141142

142143
yield sse_format(
143144
"messageCreated",
144145
templates.get_template("components/assistant-step.html").render(
145146
step_type="assistantMessage",
146-
stream_name=f"textDelta{step_counter}"
147+
stream_name=f"textDelta{step_id}"
147148
)
148149
)
149150
time.sleep(0.25) # Give the client time to render the message
150151

151152
if isinstance(event, ThreadMessageDelta):
152-
logger.info(f"Sending delta with name textDelta{step_counter}")
153153
yield sse_format(
154-
f"textDelta{step_counter}",
154+
f"textDelta{step_id}",
155155
event.data.delta.content[0].text.value
156156
)
157157

158158
if isinstance(event, ThreadRunStepCreated) and event.data.type == "tool_calls":
159+
step_id = event.data.id
160+
159161
yield sse_format(
160162
f"toolCallCreated",
161163
templates.get_template('components/assistant-step.html').render(
162164
step_type='toolCall',
163-
stream_name=f'toolDelta{step_counter}'
165+
stream_name=f'toolDelta{step_id}'
164166
)
165167
)
166168
time.sleep(0.25) # Give the client time to render the message
@@ -172,32 +174,32 @@ async def handle_assistant_stream(
172174
if tool_call.type == "function":
173175
if tool_call.function.name:
174176
yield sse_format(
175-
f"toolDelta{step_counter}",
177+
f"toolDelta{step_id}",
176178
tool_call.function.name + "<br>"
177179
)
178180
elif tool_call.function.arguments:
179181
yield sse_format(
180-
f"toolDelta{step_counter}",
182+
f"toolDelta{step_id}",
181183
tool_call.function.arguments
182184
)
183185

184186
# Handle code interpreter tool calls
185187
elif tool_call.type == "code_interpreter":
186188
if tool_call.code_interpreter.input:
187189
yield sse_format(
188-
f"toolDelta{step_counter}",
190+
f"toolDelta{step_id}",
189191
f"{tool_call.code_interpreter.input}"
190192
)
191193
if tool_call.code_interpreter.outputs:
192194
for output in tool_call.code_interpreter.outputs:
193195
if output.type == "logs":
194196
yield sse_format(
195-
f"toolDelta{step_counter}",
197+
f"toolDelta{step_id}",
196198
f"{output.logs}"
197199
)
198200
elif output.type == "image":
199201
yield sse_format(
200-
f"toolDelta{step_counter}",
202+
f"toolDelta{step_id}",
201203
f"{output.image.file_id}"
202204
)
203205

@@ -215,7 +217,7 @@ async def handle_assistant_stream(
215217
yield {
216218
"type": "metadata",
217219
"required_action": required_action,
218-
"step_counter": step_counter,
220+
"step_id": step_id,
219221
"run_requires_action_event": run_requires_action_event
220222
}
221223

@@ -224,36 +226,26 @@ async def event_generator():
224226
Main generator for SSE events. We call our helper function to handle the assistant
225227
stream, and if the assistant requests a tool call, we do it and then re-run the stream.
226228
"""
227-
step_counter = 0
228-
# First run of the assistant stream
229+
step_id = 0
229230
initial_manager = client.beta.threads.runs.stream(
230231
assistant_id=assistant_id,
231232
thread_id=thread_id,
232233
parallel_tool_calls=False
233234
)
234235

235-
# We'll re-run the loop if needed for tool calls
236236
stream_manager = initial_manager
237-
while True:
238-
async for event in handle_assistant_stream(templates, logger, stream_manager, step_counter):
237+
while True:
238+
async for event in handle_assistant_stream(templates, logger, stream_manager, step_id):
239239
# Detect the special "metadata" event at the end of the generator
240240
if isinstance(event, dict) and event.get("type") == "metadata":
241241
required_action: RequiredAction | None = event["required_action"]
242-
step_counter: int = event["step_counter"]
242+
step_id: int = event["step_id"]
243243
run_requires_action_event: ThreadRunRequiresAction | None = event["run_requires_action_event"]
244244

245245
# If the assistant still needs a tool call, do it and then re-stream
246246
if required_action and required_action.submit_tool_outputs:
247247
for tool_call in required_action.submit_tool_outputs.tool_calls:
248-
yield sse_format(
249-
"toolCallCreated",
250-
templates.get_template('components/assistant-step.html').render(
251-
step_type='toolCall',
252-
stream_name=f'toolDelta{step_counter}'
253-
)
254-
)
255-
256-
if tool_call.type == "function" and tool_call.function.name == "get_weather":
248+
if tool_call.type == "function":
257249
try:
258250
args = json.loads(tool_call.function.arguments)
259251
location = args.get("location", "Unknown")
@@ -262,26 +254,38 @@ async def event_generator():
262254
logger.error(f"Failed to parse function arguments: {err}")
263255
location = "Unknown"
264256

265-
weather_output: list[dict] = get_weather(location, dates)
266-
logger.info(f"Weather output: {weather_output}")
267-
268-
# Render the weather widget
269-
weather_widget_html: str = templates.get_template(
270-
"components/weather-widget.html"
271-
).render(
272-
reports=weather_output
273-
)
274-
275-
# Yield the rendered HTML
276-
yield sse_format("toolOutput", weather_widget_html)
277-
278-
data_for_tool = {
279-
"tool_outputs": {
280-
"output": str(weather_output),
281-
"tool_call_id": tool_call.id
282-
},
283-
"runId": run_requires_action_event.data.id,
284-
}
257+
try:
258+
weather_output: list[dict] = get_weather(location, dates)
259+
logger.info(f"Weather output: {weather_output}")
260+
261+
# Render the weather widget
262+
weather_widget_html: str = templates.get_template(
263+
"components/weather-widget.html"
264+
).render(
265+
reports=weather_output
266+
)
267+
268+
# Yield the rendered HTML
269+
yield sse_format("toolOutput", weather_widget_html)
270+
271+
data_for_tool = {
272+
"tool_outputs": {
273+
"output": str(weather_output),
274+
"tool_call_id": tool_call.id
275+
},
276+
"runId": run_requires_action_event.data.id,
277+
}
278+
except Exception as err:
279+
error_message = f"Failed to get weather output: {err}"
280+
logger.error(error_message)
281+
yield sse_format("toolOutput", error_message)
282+
data_for_tool = {
283+
"tool_outputs": {
284+
"output": error_message,
285+
"tool_call_id": tool_call.id
286+
},
287+
"runId": run_requires_action_event.data.id,
288+
}
285289

286290
# Afterwards, create a fresh stream_manager for the next iteration
287291
new_stream_manager: AsyncAssistantStreamManager = await post_tool_outputs(

utils/create_assistant.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import os
22
import logging
33
import asyncio
4+
from typing import cast
45
from dotenv import load_dotenv
56
from openai import AsyncOpenAI
67
from openai.types.beta.assistant_create_params import AssistantCreateParams
8+
from openai.types.beta.assistant_update_params import AssistantUpdateParams
79
from openai.types.beta.assistant_tool_param import CodeInterpreterToolParam, FileSearchToolParam, FunctionToolParam
810
from openai.types.beta.assistant import Assistant
9-
from openai.types import FunctionDefinition
11+
from openai.types.shared_params.function_definition import FunctionDefinition
1012
from openai.types.beta.file_search_tool_param import FileSearch
1113

1214

@@ -42,11 +44,11 @@
4244
}
4345
}
4446
},
45-
# Currently OpenAI requires that all properties be required
46-
"required": ["location", "dates"],
47+
"required": ["location"],
4748
"additionalProperties": False,
4849
},
49-
strict=True,
50+
# strict=True gives better adherence to the schema, but all arguments must be required
51+
strict=False
5052
)
5153
),
5254
],
@@ -88,35 +90,38 @@ def update_env_file(var_name: str, var_value: str, logger: logging.Logger):
8890

8991
async def create_or_update_assistant(
9092
client: AsyncOpenAI,
91-
assistant_id: str,
92-
request: AssistantCreateParams,
93+
assistant_id: str | None,
94+
request: AssistantCreateParams | AssistantUpdateParams,
9395
logger: logging.Logger
9496
) -> str:
9597
"""
9698
Create or update the assistant based on the presence of an assistant_id.
9799
"""
98100
try:
101+
assistant: Assistant
99102
if assistant_id:
100103
# Update the existing assistant
101-
assistant: Assistant = await client.beta.assistants.update(
104+
assistant = await client.beta.assistants.update(
102105
assistant_id,
103-
**request
106+
**cast(AssistantUpdateParams, request)
104107
)
105108
logger.info(f"Updated assistant with ID: {assistant_id}")
106109
else:
107110
# Create a new assistant
108-
assistant: Assistant = await client.beta.assistants.create(**request)
111+
assistant = await client.beta.assistants.create(
112+
**cast(AssistantCreateParams, request)
113+
)
109114
logger.info(f"Created new assistant: {assistant.id}")
110115

111116
# Update the .env file with the new assistant ID
112117
update_env_file("ASSISTANT_ID", assistant.id, logger)
113-
114-
return assistant.id
115118

116119
except Exception as e:
117120
action = "update" if assistant_id else "create"
118121
logger.error(f"Failed to {action} assistant: {e}")
119122

123+
return assistant.id
124+
120125

121126
# Run the assistant creation in an asyncio event loop
122127
if __name__ == "__main__":
@@ -126,8 +131,8 @@ async def create_or_update_assistant(
126131
logging.basicConfig(level=logging.INFO, stream=sys.stdout)
127132
logger: logging.Logger = logging.getLogger(__name__)
128133

129-
load_dotenv()
130-
assistant_id = os.getenv("ASSISTANT_ID")
134+
load_dotenv(override=True)
135+
assistant_id = os.getenv("ASSISTANT_ID", None)
131136

132137
# Initialize the OpenAI client
133138
openai: AsyncOpenAI = AsyncOpenAI()

utils/sse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
def sse_format(event: str, data: str, retry: int = None) -> str:
1+
def sse_format(event: str, data: str, retry: int | None = None) -> str:
22
"""
33
Helper function to format a Server-Sent Event (SSE) message.
44

0 commit comments

Comments
 (0)