Skip to content

Commit 609cd78

Browse files
Merge pull request #27 from Promptly-Technologies-LLC/display-files
Make assistant tools and vector store configurable on the setup page
2 parents 5732d1c + bd00c69 commit 609cd78

File tree

16 files changed

+1477
-1028
lines changed

16 files changed

+1477
-1028
lines changed

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,6 @@ __pycache__
77
.specstory
88
.mypy_cache
99
.cursorrules
10-
.repomix-output.txt
10+
.repomix-output.txt
11+
repomix-output.txt
12+
artifacts/

main.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@
66
from fastapi import FastAPI, Request
77
from fastapi.staticfiles import StaticFiles
88
from fastapi.templating import Jinja2Templates
9-
from fastapi.responses import RedirectResponse, Response
9+
from fastapi.responses import RedirectResponse, Response, HTMLResponse
1010
from routers import chat, files, setup
1111
from utils.threads import create_thread
12-
from fastapi.exceptions import HTTPException
12+
from fastapi.exceptions import HTTPException, RequestValidationError
1313

1414

1515
logger = logging.getLogger("uvicorn.error")
@@ -42,6 +42,25 @@ async def general_exception_handler(request: Request, exc: Exception) -> Respons
4242
status_code=500
4343
)
4444

45+
@app.exception_handler(RequestValidationError)
46+
async def validation_exception_handler(request: Request, exc: RequestValidationError):
47+
# Log the detailed validation errors
48+
logger.error(f"Validation error: {exc.errors()}")
49+
error_details = "; ".join([f"{err['loc'][-1]}: {err['msg']}" for err in exc.errors()])
50+
51+
# Check if it's an htmx request
52+
if request.headers.get("hx-request") == "true":
53+
# Return an HTML fragment suitable for htmx swapping
54+
error_html = f'<div id="file-list-container"><p class="errorMessage">Validation Error: {error_details}</p></div>' # Assuming target is file-list-container
55+
return HTMLResponse(content=error_html, status_code=200)
56+
else:
57+
# Return the full error page for standard requests
58+
return templates.TemplateResponse(
59+
"error.html",
60+
{"request": request, "error_message": f"Invalid input: {error_details}"},
61+
status_code=422,
62+
)
63+
4564
@app.exception_handler(HTTPException)
4665
async def http_exception_handler(request: Request, exc: HTTPException) -> Response:
4766
logger.error(f"HTTP error: {exc.detail}")

routers/chat.py

Lines changed: 27 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -1,61 +1,29 @@
11
import logging
22
import time
33
from datetime import datetime
4-
from typing import Any, AsyncGenerator, Dict, List, Optional, Union, cast
5-
from dataclasses import dataclass
4+
from typing import AsyncGenerator, Optional, Union
65
from fastapi.templating import Jinja2Templates
76
from fastapi import APIRouter, Form, Depends, Request
87
from fastapi.responses import StreamingResponse, HTMLResponse
98
from openai import AsyncOpenAI
10-
from openai.resources.beta.threads.runs.runs import AsyncAssistantStreamManager
9+
from openai.lib.streaming._assistants import AsyncAssistantStreamManager, AsyncAssistantEventHandler
1110
from openai.types.beta.assistant_stream_event import (
1211
ThreadMessageCreated, ThreadMessageDelta, ThreadRunCompleted,
1312
ThreadRunRequiresAction, ThreadRunStepCreated, ThreadRunStepDelta
1413
)
1514
from openai.types.beta import AssistantStreamEvent
16-
from openai.lib.streaming._assistants import AsyncAssistantEventHandler
17-
from openai.types.beta.threads.run_submit_tool_outputs_params import ToolOutput
18-
from openai.types.beta.threads.run import RequiredAction, Run
15+
from openai.types.beta.threads.run import RequiredAction
16+
from openai.types.beta.threads.message_content_delta import MessageContentDelta
17+
from openai.types.beta.threads.text_delta_block import TextDeltaBlock
1918
from fastapi.responses import StreamingResponse
20-
from fastapi import APIRouter, Depends, Form, HTTPException
21-
from pydantic import BaseModel
19+
from fastapi import APIRouter, Depends, Form
2220

2321
import json
2422

25-
from utils.custom_functions import get_weather
23+
from utils.custom_functions import get_weather, post_tool_outputs
2624
from utils.sse import sse_format
25+
from utils.streaming import AssistantStreamMetadata
2726

28-
@dataclass
29-
class AssistantStreamMetadata:
30-
"""Metadata for assistant stream events that require further processing."""
31-
type: str # Always "metadata"
32-
required_action: Optional[RequiredAction]
33-
step_id: str
34-
run_requires_action_event: Optional[ThreadRunRequiresAction]
35-
36-
@classmethod
37-
def create(cls,
38-
required_action: Optional[RequiredAction],
39-
step_id: str,
40-
run_requires_action_event: Optional[ThreadRunRequiresAction]
41-
) -> "AssistantStreamMetadata":
42-
"""Factory method to create a metadata instance with validation."""
43-
return cls(
44-
type="metadata",
45-
required_action=required_action,
46-
step_id=step_id,
47-
run_requires_action_event=run_requires_action_event
48-
)
49-
50-
def requires_tool_call(self) -> bool:
51-
"""Check if this metadata indicates a required tool call."""
52-
return (self.required_action is not None
53-
and self.required_action.submit_tool_outputs is not None
54-
and bool(self.required_action.submit_tool_outputs.tool_calls))
55-
56-
def get_run_id(self) -> str:
57-
"""Get the run ID from the requires action event, or empty string if none."""
58-
return self.run_requires_action_event.data.id if self.run_requires_action_event else ""
5927

6028
logger: logging.Logger = logging.getLogger("uvicorn.error")
6129
logger.setLevel(logging.DEBUG)
@@ -69,43 +37,6 @@ def get_run_id(self) -> str:
6937
# Jinja2 templates
7038
templates = Jinja2Templates(directory="templates")
7139

72-
# Utility function for submitting tool outputs to the assistant
73-
class ToolCallOutputs(BaseModel):
74-
tool_outputs: Dict[str, Any]
75-
runId: str
76-
77-
async def post_tool_outputs(client: AsyncOpenAI, data: Dict[str, Any], thread_id: str) -> AsyncAssistantStreamManager:
78-
"""
79-
data is expected to be something like
80-
{
81-
"tool_outputs": {
82-
"output": [{"location": "City", "temperature": 70, "conditions": "Sunny"}],
83-
"tool_call_id": "call_123"
84-
},
85-
"runId": "some-run-id",
86-
}
87-
"""
88-
try:
89-
outputs_list = [
90-
ToolOutput(
91-
output=str(data["tool_outputs"]["output"]),
92-
tool_call_id=data["tool_outputs"]["tool_call_id"]
93-
)
94-
]
95-
96-
97-
stream_manager = client.beta.threads.runs.submit_tool_outputs_stream(
98-
thread_id=thread_id,
99-
run_id=data["runId"],
100-
tool_outputs=outputs_list,
101-
)
102-
103-
return stream_manager
104-
105-
except Exception as e:
106-
logger.error(f"Error submitting tool outputs: {e}")
107-
raise HTTPException(status_code=500, detail=str(e))
108-
10940

11041
# Route to submit a new user message to a thread and mount a component that
11142
# will start an assistant run stream
@@ -170,8 +101,13 @@ async def handle_assistant_stream(
170101
async with stream_manager as event_handler:
171102
event: AssistantStreamEvent
172103
async for event in event_handler:
104+
# Debug logging for all events
105+
logger.debug(f"SSE Event Type: {type(event).__name__}")
106+
logger.debug(f"SSE Event Data: {event.data}")
107+
173108
if isinstance(event, ThreadMessageCreated):
174109
step_id = event.data.id
110+
logger.debug(f"Message Created - Step ID: {step_id}")
175111

176112
yield sse_format(
177113
"messageCreated",
@@ -183,15 +119,16 @@ async def handle_assistant_stream(
183119
time.sleep(0.25) # Give the client time to render the message
184120

185121
if isinstance(event, ThreadMessageDelta) and event.data.delta.content:
186-
content = event.data.delta.content[0]
187-
if hasattr(content, 'text') and content.text and content.text.value:
122+
content: MessageContentDelta = event.data.delta.content[0]
123+
if isinstance(content, TextDeltaBlock) and content.text and content.text.value:
188124
yield sse_format(
189125
f"textDelta{step_id}",
190126
content.text.value
191127
)
192128

193129
if isinstance(event, ThreadRunStepCreated) and event.data.type == "tool_calls":
194130
step_id = event.data.id
131+
logger.debug(f"Tool Call Created - Step ID: {step_id}")
195132

196133
yield sse_format(
197134
f"toolCallCreated",
@@ -207,6 +144,7 @@ async def handle_assistant_stream(
207144
if tool_calls:
208145
# TODO: Support parallel function calling
209146
tool_call = tool_calls[0]
147+
logger.debug(f"Tool Call Delta - Type: {tool_call.type}")
210148

211149
# Handle function tool call
212150
if tool_call.type == "function":
@@ -224,27 +162,33 @@ async def handle_assistant_stream(
224162
# Handle code interpreter tool calls
225163
elif tool_call.type == "code_interpreter":
226164
if tool_call.code_interpreter and tool_call.code_interpreter.input:
165+
logger.debug(f"Code Interpreter Input: {tool_call.code_interpreter.input}")
227166
yield sse_format(
228167
f"toolDelta{step_id}",
229168
str(tool_call.code_interpreter.input)
230169
)
231170
if tool_call.code_interpreter and tool_call.code_interpreter.outputs:
232171
for output in tool_call.code_interpreter.outputs:
172+
logger.debug(f"Code Interpreter Output Type: {output.type}")
233173
if output.type == "logs" and output.logs:
234174
yield sse_format(
235175
f"toolDelta{step_id}",
236176
str(output.logs)
237177
)
238178
elif output.type == "image" and output.image and output.image.file_id:
179+
logger.debug(f"Image Output - File ID: {output.image.file_id}")
180+
# Create the image HTML on the backend
181+
image_html = f'<img src="/assistants/{assistant_id}/files/{output.image.file_id}/content" class="code-interpreter-image">'
239182
yield sse_format(
240-
f"toolDelta{step_id}",
241-
str(output.image.file_id)
183+
f"imageOutput",
184+
image_html
242185
)
243186

244187
# If the assistant run requires an action (a tool call), break and handle it
245188
if isinstance(event, ThreadRunRequiresAction):
246189
required_action = event.data.required_action
247190
run_requires_action_event = event
191+
logger.debug("Run Requires Action Event")
248192
if required_action and required_action.submit_tool_outputs:
249193
break
250194

@@ -284,8 +228,8 @@ async def event_generator() -> AsyncGenerator[str, None]:
284228
location = args.get("location", "Unknown")
285229
dates_raw = args.get("dates", [datetime.today().strftime("%Y-%m-%d")])
286230
dates = [
287-
datetime.strptime(d, "%Y-%m-%d") if isinstance(d, str) else d
288-
for d in dates_raw
231+
datetime.strptime(d, "%Y-%m-%d")
232+
for d in dates_raw if isinstance(d, str)
289233
]
290234
except Exception as err:
291235
logger.error(f"Failed to parse function arguments: {err}")

0 commit comments

Comments
 (0)