Skip to content

Commit 2b8c028

Browse files
type lint
1 parent 6a825d2 commit 2b8c028

File tree

5 files changed

+169
-139
lines changed

5 files changed

+169
-139
lines changed

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,6 @@
55
__pycache__
66
*.pyc
77
.specstory
8-
.mypy_cache
8+
.mypy_cache
9+
.cursorrules
10+
.repomix-output.txt

main.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import os
22
import logging
3+
from typing import List, Dict, Any, Optional
34
from dotenv import load_dotenv
45
from contextlib import asynccontextmanager
56
from fastapi import FastAPI, Request
67
from fastapi.staticfiles import StaticFiles
78
from fastapi.templating import Jinja2Templates
8-
from fastapi.responses import RedirectResponse
9+
from fastapi.responses import RedirectResponse, Response
910
from routers import chat, files, setup
1011
from utils.threads import create_thread
1112
from fastapi.exceptions import HTTPException
@@ -33,7 +34,7 @@ async def lifespan(app: FastAPI):
3334
templates = Jinja2Templates(directory="templates")
3435

3536
@app.exception_handler(Exception)
36-
async def general_exception_handler(request: Request, exc: Exception):
37+
async def general_exception_handler(request: Request, exc: Exception) -> Response:
3738
logger.error(f"Unhandled error: {exc}")
3839
return templates.TemplateResponse(
3940
"error.html",
@@ -42,7 +43,7 @@ async def general_exception_handler(request: Request, exc: Exception):
4243
)
4344

4445
@app.exception_handler(HTTPException)
45-
async def http_exception_handler(request: Request, exc: HTTPException):
46+
async def http_exception_handler(request: Request, exc: HTTPException) -> Response:
4647
logger.error(f"HTTP error: {exc.detail}")
4748
return templates.TemplateResponse(
4849
"error.html",
@@ -54,23 +55,30 @@ async def http_exception_handler(request: Request, exc: HTTPException):
5455
# TODO: Implement some kind of thread id storage or management logic to allow
5556
# user to load an old thread, delete an old thread, etc. instead of start new
5657
@app.get("/")
57-
async def read_home(request: Request, thread_id: str = None, messages: list = []):
58+
async def read_home(
59+
request: Request,
60+
thread_id: Optional[str] = None,
61+
messages: List[Dict[str, Any]] = []
62+
) -> Response:
5863
logger.info("Home page requested")
5964

6065
# Check if environment variables are missing
6166
load_dotenv(override=True)
62-
if not os.getenv("OPENAI_API_KEY") or not os.getenv("ASSISTANT_ID"):
67+
openai_api_key = os.getenv("OPENAI_API_KEY")
68+
assistant_id = os.getenv("ASSISTANT_ID")
69+
70+
if not openai_api_key or not assistant_id:
6371
return RedirectResponse(url=app.url_path_for("read_setup"))
6472

6573
# Create a new assistant chat thread if no thread ID is provided
6674
if not thread_id or thread_id == "None" or thread_id == "null":
67-
thread_id: str = await create_thread()
75+
thread_id = await create_thread()
6876

6977
return templates.TemplateResponse(
7078
"index.html",
7179
{
7280
"request": request,
73-
"assistant_id": os.getenv("ASSISTANT_ID"),
81+
"assistant_id": assistant_id,
7482
"messages": messages,
7583
"thread_id": thread_id
7684
}

routers/chat.py

Lines changed: 82 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
import time
33
from datetime import datetime
4-
from typing import Any, AsyncGenerator
4+
from typing import Any, AsyncGenerator, Dict, List, Optional, Union, cast
55
from fastapi.templating import Jinja2Templates
66
from fastapi import APIRouter, Form, Depends, Request
77
from fastapi.responses import StreamingResponse, HTMLResponse
@@ -38,10 +38,10 @@
3838

3939
# Utility function for submitting tool outputs to the assistant
4040
class ToolCallOutputs(BaseModel):
41-
tool_outputs: Any
41+
tool_outputs: Dict[str, Any]
4242
runId: str
4343

44-
async def post_tool_outputs(client: AsyncOpenAI, data: dict, thread_id: str):
44+
async def post_tool_outputs(client: AsyncOpenAI, data: Dict[str, Any], thread_id: str) -> AsyncAssistantStreamManager:
4545
"""
4646
data is expected to be something like
4747
{
@@ -55,7 +55,7 @@ async def post_tool_outputs(client: AsyncOpenAI, data: dict, thread_id: str):
5555
try:
5656
outputs_list = [
5757
ToolOutput(
58-
output=data["tool_outputs"]["output"],
58+
output=str(data["tool_outputs"]["output"]),
5959
tool_call_id=data["tool_outputs"]["tool_call_id"]
6060
)
6161
]
@@ -124,14 +124,14 @@ async def handle_assistant_stream(
124124
templates: Jinja2Templates,
125125
logger: logging.Logger,
126126
stream_manager: AsyncAssistantStreamManager,
127-
step_id: int = 0
128-
) -> AsyncGenerator:
127+
step_id: str = ""
128+
) -> AsyncGenerator[Union[Dict[str, Any], str], None]:
129129
"""
130130
Async generator to yield SSE events.
131131
We yield a final 'metadata' dictionary event once we're done.
132132
"""
133-
required_action: RequiredAction | None = None
134-
run_requires_action_event: ThreadRunRequiresAction | None = None
133+
required_action: Optional[RequiredAction] = None
134+
run_requires_action_event: Optional[ThreadRunRequiresAction] = None
135135

136136
event_handler: AsyncAssistantEventHandler
137137
async with stream_manager as event_handler:
@@ -149,11 +149,13 @@ async def handle_assistant_stream(
149149
)
150150
time.sleep(0.25) # Give the client time to render the message
151151

152-
if isinstance(event, ThreadMessageDelta):
153-
yield sse_format(
154-
f"textDelta{step_id}",
155-
event.data.delta.content[0].text.value
156-
)
152+
if isinstance(event, ThreadMessageDelta) and event.data.delta.content:
153+
content = event.data.delta.content[0]
154+
if hasattr(content, 'text') and content.text and content.text.value:
155+
yield sse_format(
156+
f"textDelta{step_id}",
157+
content.text.value
158+
)
157159

158160
if isinstance(event, ThreadRunStepCreated) and event.data.type == "tool_calls":
159161
step_id = event.data.id
@@ -167,47 +169,50 @@ async def handle_assistant_stream(
167169
)
168170
time.sleep(0.25) # Give the client time to render the message
169171

170-
if isinstance(event, ThreadRunStepDelta) and event.data.delta.step_details.type == "tool_calls":
171-
tool_call = event.data.delta.step_details.tool_calls[0]
172-
173-
# Handle function tool calls
174-
if tool_call.type == "function":
175-
if tool_call.function.name:
176-
yield sse_format(
177-
f"toolDelta{step_id}",
178-
tool_call.function.name + "<br>"
179-
)
180-
elif tool_call.function.arguments:
181-
yield sse_format(
182-
f"toolDelta{step_id}",
183-
tool_call.function.arguments
184-
)
185-
186-
# Handle code interpreter tool calls
187-
elif tool_call.type == "code_interpreter":
188-
if tool_call.code_interpreter.input:
189-
yield sse_format(
190-
f"toolDelta{step_id}",
191-
f"{tool_call.code_interpreter.input}"
192-
)
193-
if tool_call.code_interpreter.outputs:
194-
for output in tool_call.code_interpreter.outputs:
195-
if output.type == "logs":
196-
yield sse_format(
197-
f"toolDelta{step_id}",
198-
f"{output.logs}"
199-
)
200-
elif output.type == "image":
201-
yield sse_format(
202-
f"toolDelta{step_id}",
203-
f"{output.image.file_id}"
204-
)
172+
if isinstance(event, ThreadRunStepDelta) and event.data.delta.step_details and event.data.delta.step_details.type == "tool_calls":
173+
tool_calls = event.data.delta.step_details.tool_calls
174+
if tool_calls:
175+
# TODO: Support parallel function calling
176+
tool_call = tool_calls[0]
177+
178+
# Handle function tool call
179+
if tool_call.type == "function":
180+
if tool_call.function and tool_call.function.name:
181+
yield sse_format(
182+
f"toolDelta{step_id}",
183+
tool_call.function.name + "<br>"
184+
)
185+
if tool_call.function and tool_call.function.arguments:
186+
yield sse_format(
187+
f"toolDelta{step_id}",
188+
tool_call.function.arguments
189+
)
190+
191+
# Handle code interpreter tool calls
192+
elif tool_call.type == "code_interpreter":
193+
if tool_call.code_interpreter and tool_call.code_interpreter.input:
194+
yield sse_format(
195+
f"toolDelta{step_id}",
196+
str(tool_call.code_interpreter.input)
197+
)
198+
if tool_call.code_interpreter and tool_call.code_interpreter.outputs:
199+
for output in tool_call.code_interpreter.outputs:
200+
if output.type == "logs" and output.logs:
201+
yield sse_format(
202+
f"toolDelta{step_id}",
203+
str(output.logs)
204+
)
205+
elif output.type == "image" and output.image and output.image.file_id:
206+
yield sse_format(
207+
f"toolDelta{step_id}",
208+
str(output.image.file_id)
209+
)
205210

206211
# If the assistant run requires an action (a tool call), break and handle it
207212
if isinstance(event, ThreadRunRequiresAction):
208213
required_action = event.data.required_action
209214
run_requires_action_event = event
210-
if required_action.submit_tool_outputs:
215+
if required_action and required_action.submit_tool_outputs:
211216
break
212217

213218
if isinstance(event, ThreadRunCompleted):
@@ -221,45 +226,50 @@ async def handle_assistant_stream(
221226
"run_requires_action_event": run_requires_action_event
222227
}
223228

224-
async def event_generator():
229+
async def event_generator() -> AsyncGenerator[str, None]:
225230
"""
226231
Main generator for SSE events. We call our helper function to handle the assistant
227232
stream, and if the assistant requests a tool call, we do it and then re-run the stream.
228233
"""
229-
step_id = 0
230-
initial_manager = client.beta.threads.runs.stream(
234+
step_id: str = ""
235+
stream_manager: AsyncAssistantStreamManager[AsyncAssistantEventHandler] = client.beta.threads.runs.stream(
231236
assistant_id=assistant_id,
232237
thread_id=thread_id,
233238
parallel_tool_calls=False
234239
)
235240

236-
stream_manager = initial_manager
237241
while True:
242+
event: dict[str, Any] | str
238243
async for event in handle_assistant_stream(templates, logger, stream_manager, step_id):
239244
# Detect the special "metadata" event at the end of the generator
240245
if isinstance(event, dict) and event.get("type") == "metadata":
241-
required_action: RequiredAction | None = event["required_action"]
242-
step_id: int = event["step_id"]
243-
run_requires_action_event: ThreadRunRequiresAction | None = event["run_requires_action_event"]
246+
required_action = cast(Optional[RequiredAction], event.get("required_action"))
247+
step_id = cast(str, event.get("step_id", ""))
248+
run_requires_action_event = cast(Optional[ThreadRunRequiresAction], event.get("run_requires_action_event"))
244249

245250
# If the assistant still needs a tool call, do it and then re-stream
246-
if required_action and required_action.submit_tool_outputs:
251+
if required_action and required_action.submit_tool_outputs and required_action.submit_tool_outputs.tool_calls:
247252
for tool_call in required_action.submit_tool_outputs.tool_calls:
248253
if tool_call.type == "function":
249254
try:
250255
args = json.loads(tool_call.function.arguments)
251256
location = args.get("location", "Unknown")
252-
dates = args.get("dates", [datetime.today()])
257+
dates_raw = args.get("dates", [datetime.today().strftime("%Y-%m-%d")])
258+
dates = [
259+
datetime.strptime(d, "%Y-%m-%d") if isinstance(d, str) else d
260+
for d in dates_raw
261+
]
253262
except Exception as err:
254263
logger.error(f"Failed to parse function arguments: {err}")
255264
location = "Unknown"
265+
dates = [datetime.today()]
256266

257267
try:
258-
weather_output: list[dict] = get_weather(location, dates)
268+
weather_output: list = get_weather(location, dates)
259269
logger.info(f"Weather output: {weather_output}")
260270

261271
# Render the weather widget
262-
weather_widget_html: str = templates.get_template(
272+
weather_widget_html = templates.get_template(
263273
"components/weather-widget.html"
264274
).render(
265275
reports=weather_output
@@ -273,7 +283,7 @@ async def event_generator():
273283
"output": str(weather_output),
274284
"tool_call_id": tool_call.id
275285
},
276-
"runId": run_requires_action_event.data.id,
286+
"runId": run_requires_action_event.data.id if run_requires_action_event else "",
277287
}
278288
except Exception as err:
279289
error_message = f"Failed to get weather output: {err}"
@@ -284,24 +294,24 @@ async def event_generator():
284294
"output": error_message,
285295
"tool_call_id": tool_call.id
286296
},
287-
"runId": run_requires_action_event.data.id,
297+
"runId": run_requires_action_event.data.id if run_requires_action_event else "",
288298
}
289299

290-
# Afterwards, create a fresh stream_manager for the next iteration
291-
new_stream_manager: AsyncAssistantStreamManager = await post_tool_outputs(
292-
client,
293-
data_for_tool,
294-
thread_id
295-
)
296-
stream_manager = new_stream_manager
297-
# proceed to rerun the loop
298-
break
300+
# Afterwards, create a fresh stream_manager for the next iteration
301+
new_stream_manager = await post_tool_outputs(
302+
client,
303+
data_for_tool,
304+
thread_id
305+
)
306+
stream_manager = new_stream_manager
307+
# proceed to rerun the loop
308+
break
299309
else:
300310
# No more tool calls needed; we're done streaming
301311
return
302312
else:
303313
# Normal SSE events: yield them to the client
304-
yield event
314+
yield str(event)
305315

306316
return StreamingResponse(
307317
event_generator(),

0 commit comments

Comments
 (0)