Skip to content

Commit 54ec412

Browse files
committed
changed LLM structure to be more versatile
1 parent 4e973ac commit 54ec412

File tree

2 files changed

+121
-111
lines changed

2 files changed

+121
-111
lines changed

src/agentlab/agents/tool_use_agent/agent.py

Lines changed: 5 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -188,36 +188,16 @@ def get_action(self, obs: Any) -> float:
188188
temperature=self.temperature,
189189
)
190190

191-
action = "noop()"
192-
think = ""
193-
# openai
194-
# for output in response.output:
195-
# if output.type == "function_call":
196-
# arguments = json.loads(output.arguments)
197-
# action = f"{output.name}({", ".join([f"{k}={v}" for k, v in arguments.items()])})"
198-
# self.previous_call_id = output.call_id
199-
# self.messages.append(output)
200-
# break
201-
# elif output.type == "reasoning":
202-
# if len(output.summary) > 0:
203-
# think += output.summary[0].text + "\n"
204-
# self.messages.append(output)
205-
206-
# anthropic
207-
for output in response.content:
208-
if output.type == "text":
209-
think += output.text
210-
elif output.type == "tool_use":
211-
action = f"{output.name}({', '.join([f'{k}=\"{v}\"' if isinstance(v, str) else f'{k}={v}' for k, v in output.input.items()])})"
212-
self.previous_call_id = output.id
213-
214-
self.messages.append({"role": "assistant", "content": response.content})
191+
action = response["action"]
192+
think = response["think"]
193+
self.previous_call_id = response["last_computer_call_id"]
194+
self.messages.append(response["assistant_message"])
215195

216196
return (
217197
action,
218198
bgym.AgentInfo(
219199
think=think,
220-
chat_messages=[],
200+
chat_messages=self.messages,
221201
stats={},
222202
),
223203
)

src/agentlab/llm/response_api.py

Lines changed: 116 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
1+
import json
12
import logging
3+
from abc import ABC, abstractmethod
24
from dataclasses import dataclass
35
from typing import Any, Dict, List, Optional, Union
46

7+
import anthropic
58
import openai
9+
from anthropic import Anthropic
610
from openai import OpenAI
711

8-
from .base_api import AbstractChatModel, BaseModelArgs
12+
from .base_api import BaseModelArgs
913

1014
type ContentItem = Dict[str, Any]
1115
type Message = Dict[str, Union[str, List[ContentItem]]]
@@ -134,29 +138,61 @@ def to_markdown(self) -> str:
134138
return res
135139

136140

137-
class BaseResponseModel(ABC):
    """Common base for provider-specific LLM response models.

    Subclasses implement ``_call_api`` (the raw provider call) and
    ``_parse_response`` (normalization into the structured dict that the
    agent consumes: ``raw_response``, ``think``, ``action``,
    ``last_computer_call_id`` and ``assistant_message``).
    """

    def __init__(
        self,
        model_name: str,
        api_key: Optional[str] = None,
        temperature: float = 0.5,
        max_tokens: int = 100,
        extra_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """Store the sampling configuration shared by all providers.

        Args:
            model_name: Provider-side model identifier.
            api_key: Credential forwarded to the provider client, if any.
            temperature: Sampling temperature used on every call.
            max_tokens: Upper bound on generated tokens per call.
            extra_kwargs: Extra keyword arguments passed verbatim to the
                provider's completion endpoint.
        """
        self.model_name = model_name
        self.api_key = api_key
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.extra_kwargs = extra_kwargs or {}

    def __call__(self, messages: "list[dict | MessageBuilder]") -> dict:
        """Make a call to the model and return the parsed response."""
        response = self._call_api(messages)
        return self._parse_response(response)

    @abstractmethod
    def _call_api(self, messages: "list[dict | MessageBuilder]") -> Any:
        """Make a call to the model API and return the raw response.

        NOTE: concrete implementations return the provider SDK's response
        object (not a plain dict), so this is annotated ``Any``.
        """

    @abstractmethod
    def _parse_response(self, response: Any) -> dict:
        """Parse the raw response from the model API and return a structured response."""
170+
171+
172+
class OpenAIResponseModel(BaseResponseModel):
173+
def __init__(
174+
self,
175+
model_name: str,
176+
api_key: Optional[str] = None,
177+
temperature: float = 0.5,
178+
max_tokens: int = 100,
179+
extra_kwargs: Optional[Dict[str, Any]] = None,
180+
):
181+
super().__init__(
182+
model_name=model_name,
183+
api_key=api_key,
184+
temperature=temperature,
185+
max_tokens=max_tokens,
186+
extra_kwargs=extra_kwargs,
187+
)
151188
self.client = OpenAI(api_key=api_key)
152189

153-
def __call__(self, content: dict, temperature: float = None) -> dict:
154-
temperature = temperature if temperature is not None else self.temperature
190+
def _call_api(self, messages: list[dict | MessageBuilder]) -> dict:
155191
try:
156192
response = self.client.responses.create(
157193
model=self.model_name,
158-
input=content,
159-
temperature=temperature,
194+
input=messages,
195+
temperature=self.temperature,
160196
# previous_response_id=content.get("previous_response_id", None),
161197
max_output_tokens=self.max_tokens,
162198
**self.extra_kwargs,
@@ -171,10 +207,39 @@ def __call__(self, content: dict, temperature: float = None) -> dict:
171207
logging.error(f"Failed to get a response from the API: {e}")
172208
raise e
173209

210+
def _parse_response(self, response: dict) -> dict:
211+
result = {
212+
"raw_response": response,
213+
"think": "",
214+
"action": "noop()",
215+
"last_computer_call_id": None,
216+
"assistant_message": {
217+
"role": "assistant",
218+
"content": response.output,
219+
},
220+
}
221+
for output in response.output:
222+
if output.type == "function_call":
223+
arguments = json.loads(output.arguments)
224+
result["action"] = (
225+
f"{output.name}({", ".join([f"{k}={v}" for k, v in arguments.items()])})"
226+
)
227+
result["last_computer_call_id"] = output.call_id
228+
break
229+
elif output.type == "reasoning":
230+
if len(output.summary) > 0:
231+
result["think"] += output.summary[0].text + "\n"
232+
return result
233+
174234

175-
class OpenAIResponseModel(ResponseModel):
235+
class ClaudeResponseModel(BaseResponseModel):
176236
def __init__(
177-
self, model_name, api_key=None, temperature=0.5, max_tokens=100, extra_kwargs=None
237+
self,
238+
model_name: str,
239+
api_key: Optional[str] = None,
240+
temperature: float = 0.5,
241+
max_tokens: int = 100,
242+
extra_kwargs: Optional[Dict[str, Any]] = None,
178243
):
179244
super().__init__(
180245
model_name=model_name,
@@ -183,37 +248,45 @@ def __init__(
183248
max_tokens=max_tokens,
184249
extra_kwargs=extra_kwargs,
185250
)
251+
self.client = Anthropic(api_key=api_key)
252+
253+
def _call_api(self, messages: list[dict | MessageBuilder]) -> dict:
254+
try:
255+
response = self.client.messages.create(
256+
model=self.model_name,
257+
messages=messages,
258+
temperature=self.temperature,
259+
max_tokens=self.max_tokens,
260+
**self.extra_kwargs,
261+
)
262+
return response
263+
except Exception as e:
264+
logging.error(f"Failed to get a response from the API: {e}")
265+
raise e
266+
267+
def _parse_response(self, response: dict) -> dict:
268+
result = {
269+
"raw_response": response,
270+
"think": "",
271+
"action": "noop()",
272+
"last_computer_call_id": None,
273+
"assistant_message": {
274+
"role": "assistant",
275+
"content": response.content,
276+
},
277+
}
278+
for output in response.content:
279+
if output.type == "tool_use":
280+
result["action"] = (
281+
f"{output.name}({', '.join([f'{k}=\"{v}\"' if isinstance(v, str) else f'{k}={v}' for k, v in output.input.items()])})"
282+
)
283+
result["last_computer_call_id"] = output.id
284+
elif output.type == "text":
285+
result["think"] += output.text
286+
return result
287+
186288

187-
def __call__(self, messages: list[dict], temperature: float = None) -> dict:
188-
return super().__call__(messages, temperature)
189-
# outputs = response.output
190-
# last_computer_call_id = None
191-
# answer_type = "call"
192-
# reasoning = "No reasoning"
193-
# for output in outputs:
194-
# if output.type == "reasoning":
195-
# reasoning = output.summary[0].text
196-
# elif output.type == "computer_call":
197-
# action = output.action
198-
# last_computer_call_id = output.call_id
199-
# res = response_to_text(action)
200-
# elif output.type == "message":
201-
# res = "noop()"
202-
# answer_type = "message"
203-
# else:
204-
# logging.warning(f"Unrecognized output type: {output.type}")
205-
# continue
206-
# return {
207-
# "think": reasoning,
208-
# "action": res,
209-
# "last_computer_call_id": last_computer_call_id,
210-
# "last_response_id": response.id,
211-
# "outputs": outputs,
212-
# "answer_type": answer_type,
213-
# }
214-
215-
216-
def response_to_text(action):
289+
def cua_response_to_text(action):
217290
"""
218291
Given a computer action (e.g., click, double_click, scroll, etc.),
219292
convert it to a text description.
@@ -294,49 +367,6 @@ def make_model(self, extra_kwargs=None):
294367
)
295368

296369

297-
import anthropic
298-
299-
300-
class ClaudeResponseModel(ResponseModel):
301-
def __init__(
302-
self,
303-
model_name,
304-
api_key=None,
305-
temperature=0.5,
306-
max_tokens=100,
307-
extra_kwargs=None,
308-
):
309-
super().__init__(
310-
model_name=model_name,
311-
api_key=api_key,
312-
temperature=temperature,
313-
max_tokens=max_tokens,
314-
extra_kwargs=extra_kwargs,
315-
)
316-
self.client = anthropic.Client(api_key=api_key)
317-
self.model_name = model_name
318-
self.temperature = temperature
319-
self.max_tokens = max_tokens
320-
self.extra_kwargs = extra_kwargs or {}
321-
self.model_name = model_name
322-
self.api_key = api_key
323-
324-
def __call__(self, messages: list[dict], temperature: float = None) -> dict:
325-
temperature = temperature if temperature is not None else self.temperature
326-
try:
327-
response = self.client.messages.create(
328-
model=self.model_name,
329-
messages=messages,
330-
temperature=temperature,
331-
max_tokens=self.max_tokens,
332-
**self.extra_kwargs,
333-
)
334-
return response
335-
except Exception as e:
336-
logging.error(f"Failed to get a response from the API: {e}")
337-
raise e
338-
339-
340370
@dataclass
341371
class ClaudeResponseModelArgs(BaseModelArgs):
342372
"""Serializable object for instantiating a generic chat model with an OpenAI

0 commit comments

Comments
 (0)