import logging
from datetime import datetime, timezone
from typing import AsyncIterator, List, Protocol

from openai import OpenAI

from .models import ChatResponse, ImageContent, Message, TextContent

logger = logging.getLogger(__name__)
class AIProvider(Protocol):
    """Structural interface that chat-completion backends must satisfy."""

    def generate_response(self, messages: List[Message]) -> ChatResponse:
        """Return one complete response for the conversation *messages*."""
        ...

    def generate_stream(self, messages: List[Message]) -> AsyncIterator[str]:
        """Return an async iterator of incremental response text chunks."""
        ...
913
class OpenAIProvider:
    """OpenAI-backed implementation of the ``AIProvider`` protocol."""

    # Fallback model used when no message in the conversation names one.
    # NOTE(review): "claude-3-5-sonnet" is not an OpenAI model id; this can only
    # work when ``api_base`` points at an OpenAI-compatible router that serves
    # Anthropic models — confirm, or switch to an OpenAI model (e.g. "gpt-4o").
    DEFAULT_MODEL = "claude-3-5-sonnet"

    def __init__(self, api_key: str, api_base: str | None = None):
        """Create the underlying client.

        Args:
            api_key: API key for OpenAI or an OpenAI-compatible server.
            api_base: Optional base-URL override (``None`` = official endpoint).
        """
        self.client = OpenAI(api_key=api_key, base_url=api_base)

    def _resolve_model(self, messages: List[Message]) -> str:
        """Return the model of the most recent message that sets one, else the default."""
        return next((m.model for m in reversed(messages) if m.model), self.DEFAULT_MODEL)

    def _format_message(self, message: Message) -> dict:
        """Convert a ``Message`` into the OpenAI chat-completions payload shape."""
        if isinstance(message.content, str):
            # Plain-text message: pass the string straight through.
            return {"role": message.role, "content": message.content}

        # Multimodal message: content is a sequence of text / image parts.
        formatted_content = []
        for item in message.content:
            if isinstance(item, TextContent):
                formatted_content.append({"type": "text", "text": item.text})
            elif isinstance(item, ImageContent):
                formatted_content.append({
                    "type": "image_url",
                    "image_url": item.image_url,
                })
        return {"role": message.role, "content": formatted_content}

    def generate_response(self, messages: List[Message]) -> ChatResponse:
        """Run a blocking chat completion and wrap the first choice.

        Returns:
            ChatResponse with the first choice's text, the model the server
            actually used, and a timezone-aware UTC timestamp.
        """
        response = self.client.chat.completions.create(
            model=self._resolve_model(messages),
            messages=[self._format_message(m) for m in messages],
        )
        return ChatResponse(
            content=response.choices[0].message.content,
            model=response.model,
            # Aware UTC timestamp; datetime.utcnow() is deprecated (3.12+)
            # and returned a naive datetime.
            timestamp=datetime.now(timezone.utc),
        )

    async def generate_stream(self, messages: List[Message]) -> AsyncIterator[str]:
        """Yield response text deltas as they arrive from the API.

        Raises:
            Exception: re-raises any client/API error after logging it.
        """
        # NOTE(review): this iterates a *synchronous* stream inside an async
        # def, blocking the event loop while waiting on the network; the
        # proper fix is AsyncOpenAI + ``async for`` — TODO confirm and migrate.
        try:
            stream = self.client.chat.completions.create(
                model=self._resolve_model(messages),
                messages=[self._format_message(m) for m in messages],
                stream=True,
            )
            for chunk in stream:
                # Some stream chunks carry no choices (e.g. usage-only frames);
                # guard before indexing.
                if chunk.choices and chunk.choices[0].delta.content:
                    yield chunk.choices[0].delta.content
        except Exception:
            # Log with full traceback, then propagate so callers can handle it.
            logger.exception("Error in stream")
            raise