11import logging
22from dataclasses import dataclass
3+ from typing import Any , Dict , List , Optional , Union
34
45import openai
56from openai import OpenAI
67
78from .base_api import AbstractChatModel , BaseModelArgs
89
10+ type ContentItem = Dict [str , Any ]
11+ type Message = Dict [str , Union [str , List [ContentItem ]]]
12+
13+
14+ class MessageBuilder :
15+ def __init__ (self , role : str ):
16+ self .role = role
17+ self .content : List [ContentItem ] = []
18+ self .tool_call_id = None
19+
20+ @staticmethod
21+ def system () -> "MessageBuilder" :
22+ return MessageBuilder (role = "system" )
23+
24+ @staticmethod
25+ def user () -> "MessageBuilder" :
26+ return MessageBuilder (role = "user" )
27+
28+ @staticmethod
29+ def assistant () -> "MessageBuilder" :
30+ return MessageBuilder (role = "assistant" )
31+
32+ @staticmethod
33+ def tool () -> "MessageBuilder" :
34+ return MessageBuilder (role = "tool" )
35+
36+ def add_text (self , text : str ) -> "MessageBuilder" :
37+ self .content .append ({"text" : text })
38+ return self
39+
40+ def add_image (self , image : str ) -> "MessageBuilder" :
41+ self .content .append ({"image" : image })
42+ return self
43+
44+ def add_tool_id (self , tool_id : str ) -> "MessageBuilder" :
45+ self .tool_call_id = tool_id
46+ return self
47+
48+ def to_openai (self ) -> List [Message ]:
49+ content = []
50+ for item in self .content :
51+ if "text" in item :
52+ content .append ({"type" : "input_text" , "text" : item ["text" ]})
53+ elif "image" in item :
54+ content .append ({"type" : "input_image" , "image_url" : item ["image" ]})
55+ res = [{"role" : self .role , "content" : content }]
56+
57+ if self .role == "tool" :
58+ assert self .tool_call_id is not None , "Tool call ID is required for tool messages"
59+ # tool messages can only take text with openai
60+ # we need to split the first content element if it's text and use it
61+ # then open a new (user) message with the rest
62+ res [0 ]["tool_call_id" ] = self .tool_call_id
63+ text_content = (
64+ content .pop (0 )["text" ]
65+ if "text" in content [0 ]
66+ else "Tool call answer in next message"
67+ )
68+ res [0 ]["content" ] = text_content
69+ res .append ({"role" : "user" , "content" : content })
70+
71+ return res
72+
73+ def to_anthropic (self ) -> List [Message ]:
74+ content = []
75+ for item in self .content :
76+ if "text" in item :
77+ content .append ({"type" : "text" , "text" : item ["text" ]})
78+ elif "image" in item :
79+ content .append (
80+ {
81+ "type" : "image" ,
82+ "source" : {
83+ "type" : "base64" , # currently only base64 is supported
84+ "media_type" : "image/png" , # currently only png is supported
85+ "data" : item ["image" ],
86+ },
87+ }
88+ )
89+ res = [{"role" : self .role , "content" : content }]
90+
91+ if self .role == "tool" :
92+ assert self .tool_call_id is not None , "Tool call ID is required for tool messages"
93+ res [0 ]["role" ] = "user"
94+ res [0 ]["content" ] = {
95+ "type" : "tool_result" ,
96+ "tool_use_id" : self .tool_call_id ,
97+ "content" : res [0 ]["content" ],
98+ }
99+ return res
100+
9101
10102class ResponseModel (AbstractChatModel ):
11103 def __init__ (
@@ -29,15 +121,15 @@ def __call__(self, content: dict, temperature: float = None) -> dict:
29121 response = self .client .responses .create (
30122 model = self .model_name ,
31123 input = content ,
32- # temperature=temperature,
124+ temperature = temperature ,
33125 # previous_response_id=content.get("previous_response_id", None),
34126 max_output_tokens = self .max_tokens ,
35127 ** self .extra_kwargs ,
36128 tool_choice = "required" ,
37- reasoning = {
38- "effort" : "low" ,
39- "summary" : "detailed" ,
40- },
129+ # reasoning={
130+ # "effort": "low",
131+ # "summary": "detailed",
132+ # },
41133 )
42134 return response
43135 except openai .OpenAIError as e :
@@ -165,3 +257,60 @@ def make_model(self, extra_kwargs=None):
165257 max_tokens = self .max_new_tokens ,
166258 extra_kwargs = extra_kwargs ,
167259 )
260+
261+
262+ import anthropic
263+
264+
265+ class ClaudeResponseModel (ResponseModel ):
266+ def __init__ (
267+ self ,
268+ model_name ,
269+ api_key = None ,
270+ temperature = 0.5 ,
271+ max_tokens = 100 ,
272+ extra_kwargs = None ,
273+ ):
274+ super ().__init__ (
275+ model_name = model_name ,
276+ api_key = api_key ,
277+ temperature = temperature ,
278+ max_tokens = max_tokens ,
279+ extra_kwargs = extra_kwargs ,
280+ )
281+ self .client = anthropic .Client (api_key = api_key )
282+ self .model_name = model_name
283+ self .temperature = temperature
284+ self .max_tokens = max_tokens
285+ self .extra_kwargs = extra_kwargs or {}
286+ self .model_name = model_name
287+ self .api_key = api_key
288+
289+ def __call__ (self , messages : list [dict ], temperature : float = None ) -> dict :
290+ temperature = temperature if temperature is not None else self .temperature
291+ try :
292+ response = self .client .messages .create (
293+ model = self .model_name ,
294+ messages = messages ,
295+ temperature = temperature ,
296+ max_tokens = self .max_tokens ,
297+ ** self .extra_kwargs ,
298+ )
299+ return response
300+ except Exception as e :
301+ logging .error (f"Failed to get a response from the API: { e } " )
302+ raise e
303+
304+
305+ @dataclass
306+ class ClaudeResponseModelArgs (BaseModelArgs ):
307+ """Serializable object for instantiating a generic chat model with an OpenAI
308+ model."""
309+
310+ def make_model (self , extra_kwargs = None ):
311+ return ClaudeResponseModel (
312+ model_name = self .model_name ,
313+ temperature = self .temperature ,
314+ max_tokens = self .max_new_tokens ,
315+ extra_kwargs = extra_kwargs ,
316+ )
0 commit comments