 type Message = Dict[str, Union[str, List[ContentItem]]]


+@dataclass
+class ResponseLLMOutput:
+    """Serializable object for the output of a response LLM."""
+
+    raw_response: Any
+    think: str
+    action: str
+    last_computer_call_id: str
+    assistant_message: Any
+
+
 class MessageBuilder:
     def __init__(self, role: str):
         self.role = role
@@ -63,13 +74,17 @@ def to_openai(self) -> List[Message]:
             # tool messages can only take text with openai
             # we need to split the first content element if it's text and use it
             # then open a new (user) message with the rest
-            res[0]["tool_call_id"] = self.tool_call_id
+            # a function_call_output dict has keys "call_id", "type" and "output"
+            res[0]["call_id"] = self.tool_call_id
+            res[0]["type"] = "function_call_output"
+            res[0].pop("role", None)  # make sure to remove role
             text_content = (
                 content.pop(0)["text"]
                 if "text" in content[0]
                 else "Tool call answer in next message"
             )
-            res[0]["content"] = text_content
+            res[0]["output"] = text_content
+            res[0].pop("content", None)  # make sure to remove content
             res.append({"role": "user", "content": content})

         return res
@@ -116,6 +131,8 @@ def to_anthropic(self) -> List[Message]:
         ]
         return res

+    def to_chat_completion(self) -> List[Message]: ...
+
     def to_markdown(self) -> str:
         content = []
         for item in self.content:
@@ -159,12 +176,12 @@ def __call__(self, messages: list[dict | MessageBuilder]) -> dict:
         return self._parse_response(response)

     @abstractmethod
-    def _call_api(self, messages: list[dict | MessageBuilder]) -> dict:
+    def _call_api(self, messages: list[dict | MessageBuilder]) -> Any:
         """Make a call to the model API and return the raw response."""
         pass

     @abstractmethod
-    def _parse_response(self, response: dict) -> dict:
+    def _parse_response(self, response: Any) -> ResponseLLMOutput:
         """Parse the raw response from the model API and return a structured response."""
         pass

@@ -187,11 +204,17 @@ def __init__(
         )
         self.client = OpenAI(api_key=api_key)

-    def _call_api(self, messages: list[dict | MessageBuilder]) -> dict:
+    def _call_api(self, messages: list[Any | MessageBuilder]) -> dict:
+        input = []
+        for msg in messages:
+            if isinstance(msg, MessageBuilder):
+                input += msg.to_openai()
+            else:
+                input.append(msg)
         try:
             response = self.client.responses.create(
                 model=self.model_name,
-                input=messages,
+                input=input,
                 temperature=self.temperature,
                 # previous_response_id=content.get("previous_response_id", None),
                 max_output_tokens=self.max_tokens,
@@ -208,27 +231,25 @@ def _call_api(self, messages: list[dict | MessageBuilder]) -> dict:
             raise e

     def _parse_response(self, response: dict) -> dict:
-        result = {
-            "raw_response": response,
-            "think": "",
-            "action": "noop()",
-            "last_computer_call_id": None,
-            "assistant_message": {
-                "role": "assistant",
-                "content": response.output,
-            },
-        }
+        result = ResponseLLMOutput(
+            raw_response=response,
+            think="",
+            action="noop()",
+            last_computer_call_id=None,
+            assistant_message=None,
+        )
         for output in response.output:
             if output.type == "function_call":
                 arguments = json.loads(output.arguments)
-                result["action"] = (
+                result.action = (
                     f"{output.name}({", ".join([f"{k}={v}" for k, v in arguments.items()])})"
                 )
-                result["last_computer_call_id"] = output.call_id
+                result.last_computer_call_id = output.call_id
+                result.assistant_message = output
                 break
             elif output.type == "reasoning":
                 if len(output.summary) > 0:
-                    result["think"] += output.summary[0].text + "\n"
+                    result.think += output.summary[0].text + "\n"
         return result


@@ -251,10 +272,16 @@ def __init__(
         self.client = Anthropic(api_key=api_key)

     def _call_api(self, messages: list[dict | MessageBuilder]) -> dict:
+        input = []
+        for msg in messages:
+            if isinstance(msg, MessageBuilder):
+                input += msg.to_anthropic()
+            else:
+                input.append(msg)
         try:
             response = self.client.messages.create(
                 model=self.model_name,
-                messages=messages,
+                messages=input,
                 temperature=self.temperature,
                 max_tokens=self.max_tokens,
                 **self.extra_kwargs,
@@ -265,24 +292,22 @@ def _call_api(self, messages: list[dict | MessageBuilder]) -> dict:
             raise e

     def _parse_response(self, response: dict) -> dict:
-        result = {
-            "raw_response": response,
-            "think": "",
-            "action": "noop()",
-            "last_computer_call_id": None,
-            "assistant_message": {
+        result = ResponseLLMOutput(
+            raw_response=response,
+            think="",
+            action="noop()",
+            last_computer_call_id=None,
+            assistant_message={
                 "role": "assistant",
                 "content": response.content,
             },
-        }
+        )
         for output in response.content:
             if output.type == "tool_use":
-                result["action"] = (
-                    f"{output.name}({', '.join([f'{k}=\"{v}\"' if isinstance(v, str) else f'{k}={v}' for k, v in output.input.items()])})"
-                )
-                result["last_computer_call_id"] = output.id
+                result.action = f"{output.name}({', '.join([f'{k}=\"{v}\"' if isinstance(v, str) else f'{k}={v}' for k, v in output.input.items()])})"
+                result.last_computer_call_id = output.id
             elif output.type == "text":
-                result["think"] += output.text
+                result.think += output.text
         return result


@@ -358,6 +383,8 @@ class OpenAIResponseModelArgs(BaseModelArgs):
     """Serializable object for instantiating a generic chat model with an OpenAI
     model."""

+    api = "openai"
+
     def make_model(self, extra_kwargs=None):
         return OpenAIResponseModel(
             model_name=self.model_name,
@@ -372,6 +399,8 @@ class ClaudeResponseModelArgs(BaseModelArgs):
     """Serializable object for instantiating a generic chat model with an OpenAI
     model."""

+    api = "anthropic"
+
     def make_model(self, extra_kwargs=None):
         return ClaudeResponseModel(
             model_name=self.model_name,
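
Taken together, the diff replaces the ad-hoc result dicts in both _parse_response implementations with the new ResponseLLMOutput dataclass and makes _call_api expand MessageBuilder instances into provider-specific message lists before calling the SDKs. The sketch below is not part of the commit; it is a minimal, self-contained illustration of how a caller might consume the dataclass instead of the old dict. The run_step helper and the demo values are hypothetical.

from dataclasses import dataclass
from typing import Any


@dataclass
class ResponseLLMOutput:
    """Local mirror of the dataclass added in the diff, so the demo runs standalone."""

    raw_response: Any
    think: str
    action: str
    last_computer_call_id: str
    assistant_message: Any


def run_step(output: ResponseLLMOutput) -> str:
    # Attribute access replaces the old result["think"] / result["action"] lookups.
    if output.think:
        print(f"reasoning: {output.think.strip()}")
    return output.action


if __name__ == "__main__":
    demo = ResponseLLMOutput(
        raw_response=None,
        think="Clicking the submit button.\n",
        action='click(bid="42")',
        last_computer_call_id="call_123",
        assistant_message=None,
    )
    print(run_step(demo))  # -> click(bid="42")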