1+ import json
12import logging
3+ from abc import ABC , abstractmethod
24from dataclasses import dataclass
35from typing import Any , Dict , List , Optional , Union
46
7+ import anthropic
58import openai
9+ from anthropic import Anthropic
610from openai import OpenAI
711
8- from .base_api import AbstractChatModel , BaseModelArgs
12+ from .base_api import BaseModelArgs
913
1014type ContentItem = Dict [str , Any ]
1115type Message = Dict [str , Union [str , List [ContentItem ]]]
@@ -134,29 +138,61 @@ def to_markdown(self) -> str:
134138 return res
135139
136140
137- class ResponseModel ( AbstractChatModel ):
141+ class BaseResponseModel ( ABC ):
138142 def __init__ (
139143 self ,
140- model_name ,
141- api_key = None ,
142- temperature = 0.5 ,
143- max_tokens = 100 ,
144- extra_kwargs = None ,
144+ model_name : str ,
145+ api_key : Optional [ str ] = None ,
146+ temperature : float = 0.5 ,
147+ max_tokens : int = 100 ,
148+ extra_kwargs : Optional [ Dict [ str , Any ]] = None ,
145149 ):
146150 self .model_name = model_name
147151 self .api_key = api_key
148152 self .temperature = temperature
149153 self .max_tokens = max_tokens
150154 self .extra_kwargs = extra_kwargs or {}
155+
156+ def __call__ (self , messages : list [dict | MessageBuilder ]) -> dict :
157+ """Make a call to the model and return the parsed response."""
158+ response = self ._call_api (messages )
159+ return self ._parse_response (response )
160+
161+ @abstractmethod
162+ def _call_api (self , messages : list [dict | MessageBuilder ]) -> dict :
163+ """Make a call to the model API and return the raw response."""
164+ pass
165+
166+ @abstractmethod
167+ def _parse_response (self , response : dict ) -> dict :
168+ """Parse the raw response from the model API and return a structured response."""
169+ pass
170+
171+
172+ class OpenAIResponseModel (BaseResponseModel ):
173+ def __init__ (
174+ self ,
175+ model_name : str ,
176+ api_key : Optional [str ] = None ,
177+ temperature : float = 0.5 ,
178+ max_tokens : int = 100 ,
179+ extra_kwargs : Optional [Dict [str , Any ]] = None ,
180+ ):
181+ super ().__init__ (
182+ model_name = model_name ,
183+ api_key = api_key ,
184+ temperature = temperature ,
185+ max_tokens = max_tokens ,
186+ extra_kwargs = extra_kwargs ,
187+ )
151188 self .client = OpenAI (api_key = api_key )
152189
153- def __call__ (self , content : dict , temperature : float = None ) -> dict :
154- temperature = temperature if temperature is not None else self .temperature
190+ def _call_api (self , messages : list [dict | MessageBuilder ]) -> dict :
155191 try :
156192 response = self .client .responses .create (
157193 model = self .model_name ,
158- input = content ,
159- temperature = temperature ,
194+ input = messages ,
195+ temperature = self . temperature ,
160196 # previous_response_id=content.get("previous_response_id", None),
161197 max_output_tokens = self .max_tokens ,
162198 ** self .extra_kwargs ,
@@ -171,10 +207,39 @@ def __call__(self, content: dict, temperature: float = None) -> dict:
171207 logging .error (f"Failed to get a response from the API: { e } " )
172208 raise e
173209
210+ def _parse_response (self , response : dict ) -> dict :
211+ result = {
212+ "raw_response" : response ,
213+ "think" : "" ,
214+ "action" : "noop()" ,
215+ "last_computer_call_id" : None ,
216+ "assistant_message" : {
217+ "role" : "assistant" ,
218+ "content" : response .output ,
219+ },
220+ }
221+ for output in response .output :
222+ if output .type == "function_call" :
223+ arguments = json .loads (output .arguments )
224+ result ["action" ] = (
225+ f"{ output .name } ({ ", " .join ([f"{ k } ={ v } " for k , v in arguments .items ()])} )"
226+ )
227+ result ["last_computer_call_id" ] = output .call_id
228+ break
229+ elif output .type == "reasoning" :
230+ if len (output .summary ) > 0 :
231+ result ["think" ] += output .summary [0 ].text + "\n "
232+ return result
233+
174234
175- class OpenAIResponseModel ( ResponseModel ):
235+ class ClaudeResponseModel ( BaseResponseModel ):
176236 def __init__ (
177- self , model_name , api_key = None , temperature = 0.5 , max_tokens = 100 , extra_kwargs = None
237+ self ,
238+ model_name : str ,
239+ api_key : Optional [str ] = None ,
240+ temperature : float = 0.5 ,
241+ max_tokens : int = 100 ,
242+ extra_kwargs : Optional [Dict [str , Any ]] = None ,
178243 ):
179244 super ().__init__ (
180245 model_name = model_name ,
@@ -183,37 +248,45 @@ def __init__(
183248 max_tokens = max_tokens ,
184249 extra_kwargs = extra_kwargs ,
185250 )
251+ self .client = Anthropic (api_key = api_key )
252+
253+ def _call_api (self , messages : list [dict | MessageBuilder ]) -> dict :
254+ try :
255+ response = self .client .messages .create (
256+ model = self .model_name ,
257+ messages = messages ,
258+ temperature = self .temperature ,
259+ max_tokens = self .max_tokens ,
260+ ** self .extra_kwargs ,
261+ )
262+ return response
263+ except Exception as e :
264+ logging .error (f"Failed to get a response from the API: { e } " )
265+ raise e
266+
267+ def _parse_response (self , response : dict ) -> dict :
268+ result = {
269+ "raw_response" : response ,
270+ "think" : "" ,
271+ "action" : "noop()" ,
272+ "last_computer_call_id" : None ,
273+ "assistant_message" : {
274+ "role" : "assistant" ,
275+ "content" : response .content ,
276+ },
277+ }
278+ for output in response .content :
279+ if output .type == "tool_use" :
280+ result ["action" ] = (
281+ f"{ output .name } ({ ', ' .join ([f'{ k } =\" { v } \" ' if isinstance (v , str ) else f'{ k } ={ v } ' for k , v in output .input .items ()])} )"
282+ )
283+ result ["last_computer_call_id" ] = output .id
284+ elif output .type == "text" :
285+ result ["think" ] += output .text
286+ return result
287+
186288
187- def __call__ (self , messages : list [dict ], temperature : float = None ) -> dict :
188- return super ().__call__ (messages , temperature )
189- # outputs = response.output
190- # last_computer_call_id = None
191- # answer_type = "call"
192- # reasoning = "No reasoning"
193- # for output in outputs:
194- # if output.type == "reasoning":
195- # reasoning = output.summary[0].text
196- # elif output.type == "computer_call":
197- # action = output.action
198- # last_computer_call_id = output.call_id
199- # res = response_to_text(action)
200- # elif output.type == "message":
201- # res = "noop()"
202- # answer_type = "message"
203- # else:
204- # logging.warning(f"Unrecognized output type: {output.type}")
205- # continue
206- # return {
207- # "think": reasoning,
208- # "action": res,
209- # "last_computer_call_id": last_computer_call_id,
210- # "last_response_id": response.id,
211- # "outputs": outputs,
212- # "answer_type": answer_type,
213- # }
214-
215-
216- def response_to_text (action ):
289+ def cua_response_to_text (action ):
217290 """
218291 Given a computer action (e.g., click, double_click, scroll, etc.),
219292 convert it to a text description.
@@ -294,49 +367,6 @@ def make_model(self, extra_kwargs=None):
294367 )
295368
296369
297- import anthropic
298-
299-
300- class ClaudeResponseModel (ResponseModel ):
301- def __init__ (
302- self ,
303- model_name ,
304- api_key = None ,
305- temperature = 0.5 ,
306- max_tokens = 100 ,
307- extra_kwargs = None ,
308- ):
309- super ().__init__ (
310- model_name = model_name ,
311- api_key = api_key ,
312- temperature = temperature ,
313- max_tokens = max_tokens ,
314- extra_kwargs = extra_kwargs ,
315- )
316- self .client = anthropic .Client (api_key = api_key )
317- self .model_name = model_name
318- self .temperature = temperature
319- self .max_tokens = max_tokens
320- self .extra_kwargs = extra_kwargs or {}
321- self .model_name = model_name
322- self .api_key = api_key
323-
324- def __call__ (self , messages : list [dict ], temperature : float = None ) -> dict :
325- temperature = temperature if temperature is not None else self .temperature
326- try :
327- response = self .client .messages .create (
328- model = self .model_name ,
329- messages = messages ,
330- temperature = temperature ,
331- max_tokens = self .max_tokens ,
332- ** self .extra_kwargs ,
333- )
334- return response
335- except Exception as e :
336- logging .error (f"Failed to get a response from the API: { e } " )
337- raise e
338-
339-
340370@dataclass
341371class ClaudeResponseModelArgs (BaseModelArgs ):
342372 """Serializable object for instantiating a generic chat model with an OpenAI
0 commit comments