Skip to content

Commit 57b7b78

Browse files
committed
refactor: restructure the chat and conversation modules.
Fix regression issues encountered along the way.
1 parent 166e355 commit 57b7b78

File tree

21 files changed

+992
-314
lines changed

21 files changed

+992
-314
lines changed

chatbot-api/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
"""Package initialization file"""
1+
"""ChatBot API package"""

chatbot-api/src/app.py

Lines changed: 0 additions & 222 deletions
This file was deleted.

chatbot-api/src/chat/models.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,21 @@
11
from dataclasses import dataclass, field
from typing import List, Optional, Union, Dict
from datetime import datetime


@dataclass
class ImageContent:
    """One image part of a multimodal message (OpenAI ``image_url`` shape)."""
    type: str = "image_url"
    # Mapping such as {"url": "..."}; Optional because a part may be created
    # before the image payload is attached.  (The previous annotation claimed
    # Dict[str, str] while defaulting to None.)
    image_url: Optional[Dict[str, str]] = None


@dataclass
class TextContent:
    """One text part of a multimodal message."""
    type: str = "text"
    text: str = ""


@dataclass
class Message:
    """A single chat message.

    ``content`` is either a plain string or a list of text/image parts
    for multimodal messages.
    """
    role: str
    content: Union[str, List[Union[TextContent, ImageContent]]]
    # Model the message was (or should be) generated with, if any.
    model: Optional[str] = None
    # default_factory so each Message gets its own creation time.  A bare
    # `datetime.utcnow()` default is evaluated once at class definition and
    # silently shared by every instance — the bug this fixes.
    timestamp: datetime = field(default_factory=datetime.utcnow)
1121

chatbot-api/src/chat/provider.py

Lines changed: 55 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,72 @@
11
from typing import Protocol, AsyncIterator, List
22
from openai import OpenAI
3-
from .models import Message, ChatResponse
3+
from .models import Message, ChatResponse, TextContent, ImageContent
4+
import logging
5+
from datetime import datetime
6+
7+
logger = logging.getLogger(__name__)
48

59
class AIProvider(Protocol):
    """Structural interface that every chat backend must satisfy.

    An implementation either returns one complete ChatResponse for a
    conversation history, or streams the reply incrementally as text chunks.
    """

    def generate_response(self, messages: List[Message]) -> ChatResponse:
        """Produce a complete response for the given conversation."""
        ...

    def generate_stream(self, messages: List[Message]) -> AsyncIterator[str]:
        """Yield the response incrementally, chunk by chunk."""
        ...
913

1014
class OpenAIProvider:
    """OpenAI-compatible implementation of AIProvider."""

    # Fallback model when no message in the history names one.
    # (The endpoint is OpenAI-compatible, hence the non-OpenAI default —
    # presumably a proxy; TODO confirm.)
    DEFAULT_MODEL = "claude-3-5-sonnet"

    def __init__(self, api_key: str, api_base: str | None = None):
        self.client = OpenAI(api_key=api_key, base_url=api_base)

    def _resolve_model(self, messages: List[Message]) -> str:
        """Return the model of the most recent message that names one.

        Previously this expression was duplicated in both generate paths.
        """
        return next(
            (m.model for m in reversed(messages) if m.model),
            self.DEFAULT_MODEL,
        )

    def _format_message(self, message: Message) -> dict:
        """Format a Message into the OpenAI chat-completion payload shape."""
        if isinstance(message.content, str):
            return {"role": message.role, "content": message.content}

        # Multimodal message: a list of text/image parts.
        formatted_content = []
        for item in message.content:
            if isinstance(item, TextContent):
                formatted_content.append({"type": "text", "text": item.text})
            elif isinstance(item, ImageContent):
                formatted_content.append({
                    "type": "image_url",
                    # ImageContent.image_url defaults to None; send an empty
                    # mapping rather than null to the API.
                    "image_url": item.image_url or {},
                })
        return {"role": message.role, "content": formatted_content}

    def generate_response(self, messages: List[Message]) -> ChatResponse:
        """Generate a complete (non-streaming) response for messages."""
        response = self.client.chat.completions.create(
            model=self._resolve_model(messages),
            messages=[self._format_message(m) for m in messages],
        )
        return ChatResponse(
            content=response.choices[0].message.content,
            model=response.model,
            timestamp=datetime.utcnow(),
        )

    async def generate_stream(self, messages: List[Message]) -> AsyncIterator[str]:
        """Stream the response for messages, yielding text chunks as they arrive.

        NOTE(review): the underlying client call is synchronous, so each chunk
        fetch blocks the event loop — consider the async OpenAI client.
        """
        try:
            stream = self.client.chat.completions.create(
                model=self._resolve_model(messages),
                messages=[self._format_message(m) for m in messages],
                stream=True,
            )
            for chunk in stream:
                delta = chunk.choices[0].delta.content
                if delta:
                    yield delta
        except Exception:
            # Log with traceback, then propagate to the caller.
            logger.exception("Error while streaming chat completion")
            raise

0 commit comments

Comments
 (0)