Skip to content

Commit a5c4849

Browse files
Draft HTMX implementation
1 parent b957892 commit a5c4849

File tree

4 files changed

+43
-140
lines changed

4 files changed

+43
-140
lines changed

main.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from fastapi.templating import Jinja2Templates
88
from fastapi.responses import RedirectResponse
99
from routers import files, messages, tools, api_keys, assistants
10+
from utils.threads import create_thread
1011

1112

1213
logger = logging.getLogger("uvicorn.error")
@@ -55,17 +56,23 @@ async def read_home(request: Request):
5556
}
5657
)
5758

59+
# TODO: Implement some kind of thread id storage or management logic to allow
60+
# user to load an old thread, delete an old thread, etc. instead of start new
5861
@app.get("/basic-chat")
5962
async def read_basic_chat(request: Request, messages: list = [], thread_id: str = None):
6063
# Get assistant ID from environment variables
6164
load_dotenv()
6265
assistant_id = os.getenv("ASSISTANT_ID")
66+
67+
# Create a new assistant chat thread if no thread ID is provided
68+
if not thread_id or thread_id == "None" or thread_id == "null":
69+
thread_id: str = await create_thread()
6370

6471
return templates.TemplateResponse(
6572
"examples/basic-chat.html",
6673
{
6774
"request": request,
68-
"assistant_id": assistant_id, # Add assistant_id to template context
75+
"assistant_id": assistant_id,
6976
"messages": messages,
7077
"thread_id": thread_id
7178
}
@@ -76,6 +83,10 @@ async def read_file_search(request: Request, messages: list = [], thread_id: str
7683
# Get assistant ID from environment variables
7784
load_dotenv()
7885
assistant_id = os.getenv("ASSISTANT_ID")
86+
87+
# Create a new assistant chat thread if no thread ID is provided
88+
if not thread_id or thread_id == "None" or thread_id == "null":
89+
thread_id: str = await create_thread()
7990

8091
return templates.TemplateResponse(
8192
"examples/file-search.html",
@@ -92,6 +103,10 @@ async def read_function_calling(request: Request, messages: list = [], thread_id
92103
# Get assistant ID from environment variables
93104
load_dotenv()
94105
assistant_id = os.getenv("ASSISTANT_ID")
106+
107+
# Create a new assistant chat thread if no thread ID is provided
108+
if not thread_id or thread_id == "None" or thread_id == "null":
109+
thread_id: str = await create_thread()
95110

96111
# Define the condition class map
97112
conditionClassMap = {
@@ -122,6 +137,10 @@ async def read_all(request: Request, messages: list = [], thread_id: str = None)
122137
# Get assistant ID from environment variables
123138
load_dotenv()
124139
assistant_id = os.getenv("ASSISTANT_ID")
140+
141+
# Create a new assistant chat thread if no thread ID is provided
142+
if not thread_id or thread_id == "None" or thread_id == "null":
143+
thread_id: str = await create_thread()
125144

126145
return templates.TemplateResponse(
127146
"examples/all.html",

routers/messages.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
import os
22
import logging
33
from dotenv import load_dotenv
4+
from fastapi.templating import Jinja2Templates
45
from fastapi import APIRouter, Form, HTTPException, Depends
56
from fastapi.responses import StreamingResponse
67
from openai import AsyncOpenAI
78
from openai.resources.beta.threads.runs.runs import AsyncAssistantStreamManager
89
import json
9-
from utils.threads import create_thread
1010

1111
logger: logging.Logger = logging.getLogger("uvicorn.error")
1212
logger.setLevel(logging.DEBUG)
@@ -17,28 +17,26 @@
1717
tags=["assistants_messages"]
1818
)
1919

20+
# Load Jinja2 templates
21+
templates = Jinja2Templates(directory="templates")
2022

2123
# Send a new message to a thread
22-
@router.post("/send_message")
24+
@router.post("/send")
2325
async def post_message(
2426
userInput: str = Form(...),
25-
thread_id: str | None = Form(None),
27+
thread_id: str = Form(),
2628
client: AsyncOpenAI = Depends(lambda: AsyncOpenAI())
2729
) -> dict:
28-
# Create a new assistant chat thread if no thread ID is provided
29-
if not thread_id or thread_id == "None" or thread_id == "null":
30-
thread_id: str = await create_thread()
31-
3230
# Create a new message in the thread
3331
await client.beta.threads.messages.create(
3432
thread_id=thread_id,
3533
role="user",
3634
content=userInput
3735
)
38-
39-
return {"thread_id": thread_id}
4036

41-
@router.get("/stream_response")
37+
return templates.TemplateResponse("components/chat-turn.html")
38+
39+
@router.get("/receive")
4240
async def stream_response(
4341
thread_id: str | None = None,
4442
client: AsyncOpenAI = Depends(lambda: AsyncOpenAI())
@@ -55,10 +53,10 @@ async def event_generator():
5553
)
5654
async with stream as stream_manager:
5755
async for text in stream_manager.text_deltas:
58-
yield f"data: {json.dumps({'text': text, 'thread_id': thread_id})}\n\n"
56+
yield f"data: {text}"
5957

6058
# Send a done event when the stream is complete
61-
yield f"data: {json.dumps({'complete': True})}\n\n"
59+
yield f"event: EndMessage"
6260

6361
return StreamingResponse(
6462
event_generator(),

templates/components/chat-turn.html

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
<div class="userMessage">{{ userInput }}</div>
2+
<div
3+
class="assistantMessage"
4+
hx-ext="sse"
5+
sse-connect="/assistants/{{ assistant_id }}/messages/receive"
6+
sse-swap="message"
7+
hx-swap="beforeend"
8+
></div>

templates/components/chat.html

Lines changed: 5 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
<div class="chatContainer">
2-
<div class="messages">
2+
<div id="messages" class="messages">
33
{% for msg in messages %}
44
{% if msg.role == "user" %}
55
<div class="userMessage">{{ msg.text }}</div>
@@ -17,10 +17,9 @@
1717
</div>
1818
{% endif %}
1919
{% endfor %}
20-
<div id="messagesEndRef"></div>
2120
</div>
2221
<form id="chatForm" class="inputForm clearfix">
23-
{% if thread_id is not none %}value="{{ thread_id }}"{% endif %}
22+
<input type="hidden" value="{{ thread_id }}">
2423
<input
2524
type="text"
2625
class="input"
@@ -31,133 +30,12 @@
3130
<button
3231
type="submit"
3332
class="button"
33+
hx-get="/assistants/{assistant_id}/messages/send"
34+
hx-target="#messages"
35+
hx-swap="beforeEnd"
3436
{% if inputDisabled %}disabled{% endif %}
3537
>
3638
Send
3739
</button>
3840
</form>
3941
</div>
40-
41-
<script>
42-
document.getElementById('chatForm').addEventListener('submit', function(e) {
43-
e.preventDefault();
44-
45-
const form = e.target;
46-
const input = form.querySelector('#userInput');
47-
let threadId = form.querySelector('input[name="thread_id"]')?.value;
48-
const messagesDiv = document.querySelector('.messages');
49-
50-
// Don't send empty messages
51-
if (!input.value.trim()) return;
52-
53-
// Append user message immediately
54-
const userMessageDiv = document.createElement('div');
55-
userMessageDiv.className = 'userMessage';
56-
userMessageDiv.textContent = input.value;
57-
messagesDiv.insertBefore(userMessageDiv, document.getElementById('messagesEndRef'));
58-
59-
// Store message and clear input before sending
60-
const messageText = input.value;
61-
input.value = '';
62-
63-
// Scroll to bottom after user message
64-
messagesEndRef.scrollIntoView({ behavior: 'smooth' });
65-
66-
// Create form data
67-
const formData = new FormData();
68-
formData.append('userInput', messageText);
69-
if (threadId) {
70-
formData.append('thread_id', threadId);
71-
}
72-
73-
// First send the message via POST
74-
fetch('/send_message', {
75-
method: 'POST',
76-
body: formData
77-
}).then(response => {
78-
if (!response.ok) {
79-
throw new Error('Network response was not ok');
80-
}
81-
return response.json();
82-
}).then(data => {
83-
// Update the thread_id if we got a new one
84-
if (data.thread_id) {
85-
const threadInput = form.querySelector('input[name="thread_id"]');
86-
if (!threadInput) {
87-
const newThreadInput = document.createElement('input');
88-
newThreadInput.type = 'hidden';
89-
newThreadInput.name = 'thread_id';
90-
newThreadInput.value = data.thread_id;
91-
form.appendChild(newThreadInput);
92-
} else {
93-
threadInput.value = data.thread_id;
94-
}
95-
threadId = data.thread_id;
96-
}
97-
98-
// Create URL params
99-
const urlParams = new URLSearchParams();
100-
if (threadId && threadId !== "None") {
101-
urlParams.append('thread_id', threadId);
102-
}
103-
104-
// Create and store EventSource reference
105-
const eventSource = new EventSource('/stream_response?' + urlParams.toString());
106-
107-
// Add cleanup when page is unloaded
108-
window.addEventListener('beforeunload', () => {
109-
if (eventSource) {
110-
eventSource.close();
111-
}
112-
});
113-
114-
let currentMessageDiv = null;
115-
116-
eventSource.onmessage = (event) => {
117-
const data = JSON.parse(event.data);
118-
119-
// Check if this is a completion message
120-
if (data.complete) {
121-
clearTimeout(streamTimeout);
122-
eventSource.close();
123-
return;
124-
}
125-
126-
// Create message div if it doesn't exist
127-
if (!currentMessageDiv) {
128-
currentMessageDiv = document.createElement('div');
129-
currentMessageDiv.className = 'assistantMessage';
130-
messagesDiv.insertBefore(currentMessageDiv, document.getElementById('messagesEndRef'));
131-
}
132-
133-
// Append new text
134-
currentMessageDiv.innerHTML += data.text;
135-
136-
// Scroll to bottom
137-
messagesEndRef.scrollIntoView({ behavior: 'smooth' });
138-
};
139-
140-
// Enhanced error handling
141-
eventSource.onerror = (error) => {
142-
console.error('EventSource failed:', error);
143-
eventSource.close();
144-
// Optionally add an error message to the chat
145-
const errorDiv = document.createElement('div');
146-
errorDiv.className = 'errorMessage';
147-
errorDiv.textContent = 'Message failed to send. Please try again.';
148-
messagesDiv.insertBefore(errorDiv, document.getElementById('messagesEndRef'));
149-
};
150-
151-
// Add timeout
152-
const streamTimeout = setTimeout(() => {
153-
if (eventSource.readyState !== EventSource.CLOSED) {
154-
eventSource.close();
155-
console.warn('Stream timed out');
156-
}
157-
}, 30000); // 30 second timeout
158-
}).catch(error => {
159-
console.error('Error:', error);
160-
// Maybe add some user-facing error message here
161-
});
162-
});
163-
</script>

0 commit comments

Comments
 (0)