Skip to content

Commit a291805

Browse files
Allow tool selection during assistant setup or subsequent assistant update
1 parent 2d9a06b commit a291805

File tree

10 files changed

+314
-206
lines changed

10 files changed

+314
-206
lines changed

routers/chat.py

Lines changed: 4 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import logging
22
import time
33
from datetime import datetime
4-
from typing import Any, AsyncGenerator, Dict, List, Optional, Union, cast
5-
from dataclasses import dataclass
4+
from typing import AsyncGenerator, Optional, Union
65
from fastapi.templating import Jinja2Templates
76
from fastapi import APIRouter, Form, Depends, Request
87
from fastapi.responses import StreamingResponse, HTMLResponse
@@ -13,50 +12,18 @@
1312
ThreadRunRequiresAction, ThreadRunStepCreated, ThreadRunStepDelta
1413
)
1514
from openai.types.beta import AssistantStreamEvent
16-
from openai.types.beta.threads.run_submit_tool_outputs_params import ToolOutput
1715
from openai.types.beta.threads.run import RequiredAction
1816
from openai.types.beta.threads.message_content_delta import MessageContentDelta
1917
from openai.types.beta.threads.text_delta_block import TextDeltaBlock
2018
from fastapi.responses import StreamingResponse
21-
from fastapi import APIRouter, Depends, Form, HTTPException
22-
from pydantic import BaseModel
19+
from fastapi import APIRouter, Depends, Form
2320

2421
import json
2522

26-
from utils.custom_functions import get_weather
23+
from utils.custom_functions import get_weather, post_tool_outputs
2724
from utils.sse import sse_format
25+
from utils.streaming import AssistantStreamMetadata
2826

29-
@dataclass
30-
class AssistantStreamMetadata:
31-
"""Metadata for assistant stream events that require further processing."""
32-
type: str # Always "metadata"
33-
required_action: Optional[RequiredAction]
34-
step_id: str
35-
run_requires_action_event: Optional[ThreadRunRequiresAction]
36-
37-
@classmethod
38-
def create(cls,
39-
required_action: Optional[RequiredAction],
40-
step_id: str,
41-
run_requires_action_event: Optional[ThreadRunRequiresAction]
42-
) -> "AssistantStreamMetadata":
43-
"""Factory method to create a metadata instance with validation."""
44-
return cls(
45-
type="metadata",
46-
required_action=required_action,
47-
step_id=step_id,
48-
run_requires_action_event=run_requires_action_event
49-
)
50-
51-
def requires_tool_call(self) -> bool:
52-
"""Check if this metadata indicates a required tool call."""
53-
return (self.required_action is not None
54-
and self.required_action.submit_tool_outputs is not None
55-
and bool(self.required_action.submit_tool_outputs.tool_calls))
56-
57-
def get_run_id(self) -> str:
58-
"""Get the run ID from the requires action event, or empty string if none."""
59-
return self.run_requires_action_event.data.id if self.run_requires_action_event else ""
6027

6128
logger: logging.Logger = logging.getLogger("uvicorn.error")
6229
logger.setLevel(logging.DEBUG)
@@ -70,43 +37,6 @@ def get_run_id(self) -> str:
7037
# Jinja2 templates
7138
templates = Jinja2Templates(directory="templates")
7239

73-
# Utility function for submitting tool outputs to the assistant
74-
class ToolCallOutputs(BaseModel):
75-
tool_outputs: Dict[str, Any]
76-
runId: str
77-
78-
async def post_tool_outputs(client: AsyncOpenAI, data: Dict[str, Any], thread_id: str) -> AsyncAssistantStreamManager:
79-
"""
80-
data is expected to be something like
81-
{
82-
"tool_outputs": {
83-
"output": [{"location": "City", "temperature": 70, "conditions": "Sunny"}],
84-
"tool_call_id": "call_123"
85-
},
86-
"runId": "some-run-id",
87-
}
88-
"""
89-
try:
90-
outputs_list = [
91-
ToolOutput(
92-
output=str(data["tool_outputs"]["output"]),
93-
tool_call_id=data["tool_outputs"]["tool_call_id"]
94-
)
95-
]
96-
97-
98-
stream_manager = client.beta.threads.runs.submit_tool_outputs_stream(
99-
thread_id=thread_id,
100-
run_id=data["runId"],
101-
tool_outputs=outputs_list,
102-
)
103-
104-
return stream_manager
105-
106-
except Exception as e:
107-
logger.error(f"Error submitting tool outputs: {e}")
108-
raise HTTPException(status_code=500, detail=str(e))
109-
11040

11141
# Route to submit a new user message to a thread and mount a component that
11242
# will start an assistant run stream

routers/files.py

Lines changed: 30 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
11
import os
22
import logging
3-
from typing import List, Dict, Any, AsyncIterable
3+
from typing import List, Dict
44
from dotenv import load_dotenv
55
from fastapi import APIRouter, Request, UploadFile, File, HTTPException, Depends, Form, Path
66
from fastapi.responses import StreamingResponse
77
from openai import AsyncOpenAI
8+
from openai.types.file_purpose import FilePurpose
9+
from utils.files import get_or_create_vector_store
10+
from utils.streaming import stream_file_content
811

912
logger = logging.getLogger("uvicorn.error")
1013

1114
# Get assistant ID from environment variables
12-
load_dotenv()
15+
load_dotenv(override=True)
1316
assistant_id_env = os.getenv("ASSISTANT_ID")
1417
if not assistant_id_env:
1518
raise ValueError("ASSISTANT_ID environment variable not set")
@@ -20,21 +23,8 @@
2023
tags=["assistants_files"]
2124
)
2225

23-
# Helper function to get or create a vector store
24-
async def get_or_create_vector_store(assistantId: str, client: AsyncOpenAI = Depends(lambda: AsyncOpenAI())) -> str:
25-
assistant = await client.beta.assistants.retrieve(assistantId)
26-
if assistant.tool_resources and assistant.tool_resources.file_search and assistant.tool_resources.file_search.vector_store_ids:
27-
return assistant.tool_resources.file_search.vector_store_ids[0]
28-
vector_store = await client.vector_stores.create(name="sample-assistant-vector-store")
29-
await client.beta.assistants.update(
30-
assistantId,
31-
tool_resources={
32-
"file_search": {
33-
"vector_store_ids": [vector_store.id],
34-
},
35-
}
36-
)
37-
return vector_store.id
26+
27+
#TODO: Correctly return HTML, not JSON, from the routes below
3828

3929
@router.get("/")
4030
async def list_files(client: AsyncOpenAI = Depends(lambda: AsyncOpenAI())) -> List[Dict[str, str]]:
@@ -57,14 +47,16 @@ async def list_files(client: AsyncOpenAI = Depends(lambda: AsyncOpenAI())) -> Li
5747
})
5848
return files_array
5949

50+
51+
# Take a purpose parameter, defaulting to "assistants"
6052
@router.post("/")
61-
async def upload_file(file: UploadFile = File(...)) -> Dict[str, str]:
53+
async def upload_file(file: UploadFile = File(...), purpose: FilePurpose = Form(default="assistants")) -> Dict[str, str]:
6254
try:
6355
client = AsyncOpenAI()
6456
vector_store_id = await get_or_create_vector_store(assistant_id)
6557
openai_file = await client.files.create(
6658
file=file.file,
67-
purpose="assistants"
59+
purpose=purpose
6860
)
6961
await client.vector_stores.files.create(
7062
vector_store_id=vector_store_id,
@@ -74,11 +66,25 @@ async def upload_file(file: UploadFile = File(...)) -> Dict[str, str]:
7466
except Exception as e:
7567
raise HTTPException(status_code=500, detail=str(e))
7668

77-
async def stream_file_content(content: bytes) -> AsyncIterable[bytes]:
78-
yield content
69+
70+
@router.delete("/delete")
71+
async def delete_file(
72+
request: Request,
73+
fileId: str = Form(...),
74+
client: AsyncOpenAI = Depends(lambda: AsyncOpenAI())
75+
) -> Dict[str, str]:
76+
vector_store_id = await get_or_create_vector_store(assistant_id, client)
77+
await client.vector_stores.files.delete(vector_store_id=vector_store_id, file_id=fileId)
78+
return {"message": "File deleted successfully"}
79+
80+
81+
# --- Streaming file content ---
82+
83+
84+
7985

8086
@router.get("/{file_id}")
81-
async def get_file(
87+
async def download_assistant_file(
8288
file_id: str = Path(..., description="The ID of the file to retrieve"),
8389
client: AsyncOpenAI = Depends(lambda: AsyncOpenAI())
8490
) -> StreamingResponse:
@@ -96,18 +102,9 @@ async def get_file(
96102
except Exception as e:
97103
raise HTTPException(status_code=500, detail=str(e))
98104

99-
@router.delete("/delete")
100-
async def delete_file(
101-
request: Request,
102-
fileId: str = Form(...),
103-
client: AsyncOpenAI = Depends(lambda: AsyncOpenAI())
104-
) -> Dict[str, str]:
105-
vector_store_id = await get_or_create_vector_store(assistant_id, client)
106-
await client.vector_stores.files.delete(vector_store_id=vector_store_id, file_id=fileId)
107-
return {"message": "File deleted successfully"}
108105

109106
@router.get("/{file_id}/content")
110-
async def get_file_content(
107+
async def get_assistant_image_content(
111108
file_id: str,
112109
client: AsyncOpenAI = Depends(lambda: AsyncOpenAI())
113110
) -> StreamingResponse:
@@ -119,7 +116,7 @@ async def get_file_content(
119116
# Get the file content from OpenAI
120117
file_content = await client.files.content(file_id)
121118
file_bytes = file_content.read() # Remove await since read() returns bytes directly
122-
119+
123120
# Return the file content as a streaming response
124121
# Note: In a production environment, you might want to add caching
125122
return StreamingResponse(

routers/setup.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import logging
22
import os
3-
from typing import Optional
3+
from typing import Optional, List
44
from dotenv import load_dotenv
55
from fastapi import APIRouter, Depends, HTTPException, Form, Request
66
from fastapi.responses import RedirectResponse, Response
77
from fastapi.templating import Jinja2Templates
88
from openai import AsyncOpenAI
9+
from openai.types.beta import Assistant
910

10-
from utils.create_assistant import create_or_update_assistant, request as assistant_create_request
11+
from utils.create_assistant import create_or_update_assistant, ToolTypes
1112
from utils.create_assistant import update_env_file
1213

1314
# Configure logger
@@ -45,27 +46,44 @@ async def set_openai_api_key(api_key: str = Form(...)) -> RedirectResponse:
4546

4647
# Add new setup route
4748
@router.get("/")
48-
async def read_setup(request: Request, message: Optional[str] = "") -> Response:
49+
async def read_setup(
50+
request: Request,
51+
client: AsyncOpenAI = Depends(lambda: AsyncOpenAI()),
52+
message: Optional[str] = ""
53+
) -> Response:
4954
# Check if assistant ID is missing
55+
current_tools = []
5056
load_dotenv(override=True)
5157
openai_api_key = os.getenv("OPENAI_API_KEY")
5258
assistant_id = os.getenv("ASSISTANT_ID")
5359

5460
if not openai_api_key:
5561
message = "OpenAI API key is missing."
56-
elif not assistant_id:
57-
message = "Assistant ID is missing."
5862
else:
59-
message = "All set up!"
63+
if assistant_id:
64+
try:
65+
assistant = await client.beta.assistants.retrieve(assistant_id)
66+
current_tools = [tool.type for tool in assistant.tools]
67+
except Exception as e:
68+
logger.error(f"Failed to retrieve assistant {assistant_id}: {e}")
69+
# If we can't retrieve the assistant, proceed as if it doesn't exist
70+
assistant_id = None
71+
message = "Error retrieving existing assistant. Please create a new one."
6072

6173
return templates.TemplateResponse(
6274
"setup.html",
63-
{"request": request, "message": message}
75+
{
76+
"request": request,
77+
"message": message,
78+
"assistant_id": assistant_id,
79+
"current_tools": current_tools
80+
}
6481
)
6582

6683

6784
@router.post("/assistant")
6885
async def create_update_assistant(
86+
tool_types: List[ToolTypes] = Form(...),
6987
client: AsyncOpenAI = Depends(lambda: AsyncOpenAI())
7088
) -> RedirectResponse:
7189
"""
@@ -76,7 +94,7 @@ async def create_update_assistant(
7694
new_assistant_id = await create_or_update_assistant(
7795
client=client,
7896
assistant_id=current_assistant_id,
79-
request=assistant_create_request,
97+
tool_types=tool_types,
8098
logger=logger
8199
)
82100

static/styles.css

Lines changed: 52 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ body {
1919
}
2020

2121
body {
22+
font-family: system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Oxygen, Ubuntu, Cantarell, "Open Sans", "Helvetica Neue", sans-serif;
2223
color: rgb(var(--foreground-rgb));
2324
}
2425

@@ -57,16 +58,8 @@ pre {
5758
.logo {
5859
width: 32px;
5960
height: 32px;
60-
position: absolute;
61-
margin: 16px;
62-
top: 0;
63-
right: 0;
64-
}
65-
66-
@media (max-width: 1100px) {
67-
.logo {
68-
display: none;
69-
}
61+
filter: invert(100%) sepia(0%) saturate(0%) hue-rotate(0deg) brightness(100%) contrast(100%);
62+
/* Removed absolute positioning, handled by nav flexbox */
7063
}
7164

7265
.main {
@@ -576,3 +569,52 @@ pre {
576569
border-radius: 0.5rem;
577570
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
578571
}
572+
573+
.checkboxGroup {
574+
display: flex;
575+
flex-direction: column;
576+
gap: 0.5rem;
577+
margin-bottom: 1rem;
578+
}
579+
580+
.checkboxGroup label {
581+
display: flex;
582+
align-items: center;
583+
gap: 0.5rem;
584+
}
585+
586+
.errorMessage {
587+
color: #e53e3e;
588+
font-size: 0.875rem;
589+
}
590+
591+
/* --- Nav --- */
592+
.nav {
593+
position: relative; /* Context for absolute logo if needed later */
594+
display: flex;
595+
align-items: center; /* Vertically center items */
596+
justify-content: flex-end; /* Pushes items to the right */
597+
gap: 16px; /* Space between nav-links div and logo */
598+
padding: 16px; /* Match logo margin */
599+
height: 64px; /* Logo height (32) + top/bottom padding (16*2) */
600+
box-sizing: border-box;
601+
width: 100%;
602+
background-color: #000; /* Black background */
603+
border-bottom: 1px solid #eee; /* Optional: separator */
604+
}
605+
606+
.nav-links {
607+
display: flex;
608+
gap: 16px; /* Space between links */
609+
}
610+
611+
.nav a {
612+
color: white;
613+
font-weight: bold;
614+
text-decoration: none;
615+
}
616+
617+
.nav a:hover {
618+
text-decoration: underline;
619+
}
620+
/* --- End Nav --- */

0 commit comments

Comments
 (0)