Skip to content

Commit d9c9216

Browse files
committed
simplify history format
1 parent 29ba1c4 commit d9c9216

File tree

2 files changed

+90
-124
lines changed

2 files changed

+90
-124
lines changed
Lines changed: 89 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import json
22
import logging
33
import pprint
4+
import time
45
from dataclasses import dataclass
56
from functools import partial
6-
from typing import Callable
7+
from typing import Callable, Literal
78

8-
from litellm import completion_with_retries
9+
from litellm import completion
910
from litellm.types.utils import ChatCompletionMessageToolCall, Message, ModelResponse
1011
from PIL import Image
1112
from termcolor import colored
@@ -17,81 +18,24 @@
1718

1819
logger = logging.getLogger(__name__)
1920

20-
@dataclass
21-
class Observation:
22-
data: dict # expected keys: goal_object, pruned_html, axtree_txt, screenshot, last_action_error, action_result
23-
24-
def to_messages(self) -> list[dict]:
25-
"""
26-
Convert the observation dictionary into a list of chat messages for Lite LLM
27-
"""
28-
messages = []
29-
tool_call_id = self.data.get("tool_call_id")
30-
if self.data.get("goal_object") and not tool_call_id: # its a first observation when there are no tool_call_id, so include goal
31-
goal=self.data["goal_object"][0]["text"]
32-
messages.append({
33-
"role": "user",
34-
"content": f"## Goal:\n{goal}"
35-
})
36-
text_obs = []
37-
if self.data.get("action_result"):
38-
result=self.data["action_result"]
39-
text_obs.append(f"Action Result:\n{result}")
40-
if self.data.get("pruned_html"):
41-
html=self.data["pruned_html"]
42-
text_obs.append(f"Pruned HTML:\n{html}")
43-
if self.data.get("axtree_txt"):
44-
axtree=self.data["axtree_txt"]
45-
text_obs.append(f"Accessibility Tree:\n{axtree}")
46-
if self.data.get("last_action_error"):
47-
error = self.data['last_action_error']
48-
text_obs.append(f"Action Error:\n{error}")
49-
if text_obs:
50-
if tool_call_id:
51-
message = {
52-
"role": "tool",
53-
"tool_call_id": tool_call_id,
54-
"content": "\n\n".join(text_obs),
55-
}
56-
else:
57-
message = {
58-
"role": "user",
59-
"content": "\n\n".join(text_obs),
60-
}
61-
messages.append(message)
62-
if self.data.get("screenshot"):
63-
if isinstance(self.data["screenshot"], Image.Image):
64-
image_content_url = image_to_png_base64_url(self.data["screenshot"])
65-
messages.append({
66-
"role": "user",
67-
"content": [{"type": "image_url", "image_url": {"url": image_content_url}}],
68-
})
69-
else:
70-
raise ValueError(f"Expected Image.Image, got {type(self.data['screenshot'])}")
71-
return messages
7221

73-
@dataclass
74-
class LLMOutput:
75-
"""
76-
LiteLLM output message containing all the llm response details, suitable for putting back into prompt to reuse KV cache
77-
"""
78-
message: Message
79-
def to_messages(self) -> list[Message]:
80-
return [self.message]
81-
82-
@dataclass
83-
class SystemMessage:
84-
message: str
85-
def to_messages(self) -> list[dict]:
86-
return [{"role": "system", "content": self.message}]
22+
class LLMArgs(BaseModelArgs):
23+
reasoning_effort: Literal["minimal", "low", "medium", "high"] = "low"
24+
num_retries: int = 3
8725

88-
@dataclass
89-
class UserMessage:
90-
message: str
91-
def to_messages(self) -> list[dict]:
92-
return [{"role": "user", "content": self.message}]
26+
def make_model(self) -> Callable:
27+
return partial(
28+
completion,
29+
model=self.model_name,
30+
temperature=self.temperature,
31+
max_tokens=self.max_total_tokens,
32+
max_completion_tokens=self.max_new_tokens,
33+
reasoning_effort=self.reasoning_effort,
34+
num_retries=self.num_retries,
35+
tool_choice="auto",
36+
parallel_tool_calls=False,
37+
)
9338

94-
Step = LLMOutput | Observation | SystemMessage | UserMessage
9539

9640
@dataclass
9741
class AgentConfig:
@@ -112,68 +56,90 @@ class AgentConfig:
11256
2. Evaluate action success, explain impact on task and next steps.
11357
3. If you see any errors in the last observation, think about it. If there is no error, just move on.
11458
4. List next steps to move towards the goal and propose next immediate action.
115-
Then produce the function call that performs the proposed action. If the task is complete, produce the final step.
59+
Then produce the single function call that performs the proposed action. If the task is complete, produce the final step.
11660
"""
11761

118-
class LLMArgs(BaseModelArgs):
119-
reasoning_effort: str = "low"
120-
121-
def make_model(self) -> Callable:
122-
return partial(
123-
completion_with_retries,
124-
model=self.model_name,
125-
temperature=self.temperature,
126-
max_tokens=self.max_total_tokens,
127-
max_completion_tokens=self.max_new_tokens,
128-
reasoning_effort=self.reasoning_effort,
129-
)
13062

13163
class ReactToolCallAgent:
132-
def __init__(self, action_set: ToolsActionSet, llm: Callable, config: AgentConfig):
64+
def __init__(
65+
self, action_set: ToolsActionSet, llm: Callable[..., ModelResponse], config: AgentConfig
66+
):
13367
self.action_set = action_set
134-
self.history: list[Step] = [SystemMessage(message=config.system_prompt)]
68+
self.history: list[dict | Message] = [{"role": "system", "content": config.system_prompt}]
13569
self.llm = llm
13670
self.config = config
13771
self.last_tool_call_id: str = ""
13872

13973
def obs_preprocessor(self, obs: dict) -> dict:
140-
if not self.config.use_html:
141-
obs.pop("pruned_html", None)
142-
if not self.config.use_axtree:
143-
obs.pop("axtree_txt", None)
144-
if not self.config.use_screenshot:
145-
obs.pop("screenshot", None)
146-
if self.last_tool_call_id:
147-
# add tool_call_id to obs for linking observation to the last executed action
148-
obs["tool_call_id"] = self.last_tool_call_id
14974
return obs
15075

76+
def obs_to_messages(self, obs: dict) -> list[dict]:
77+
"""
78+
Convert the observation dictionary into a list of chat messages for Lite LLM
79+
"""
80+
messages = []
81+
if obs.get("goal_object") and not self.last_tool_call_id:
82+
# its a first observation when there are no tool_call_id, so include goal
83+
goal = obs["goal_object"][0]["text"]
84+
messages.append({"role": "user", "content": f"## Goal:\n{goal}"})
85+
text_obs = []
86+
if result := obs.get("action_result"):
87+
text_obs.append(f"## Action Result:\n{result}")
88+
if error := obs.get("last_action_error"):
89+
text_obs.append(f"## Action Error:\n{error}")
90+
if self.config.use_html and (html := obs.get("pruned_html")):
91+
text_obs.append(f"## HTML:\n{html}")
92+
if self.config.use_axtree and (axtree := obs.get("axtree_txt")):
93+
text_obs.append(f"## Accessibility Tree:\n{axtree}")
94+
content = "\n\n".join(text_obs)
95+
if content:
96+
if self.last_tool_call_id:
97+
message = {
98+
"role": "tool",
99+
"tool_call_id": self.last_tool_call_id,
100+
"content": content,
101+
}
102+
else:
103+
message = {"role": "user", "content": content}
104+
messages.append(message)
105+
if self.config.use_screenshot and (scr := obs.get("screenshot")):
106+
if isinstance(scr, Image.Image):
107+
image_content = [
108+
{"type": "image_url", "image_url": {"url": image_to_png_base64_url(scr)}}
109+
]
110+
messages.append({"role": "user", "content": image_content})
111+
else:
112+
raise ValueError(
113+
f"Expected Image.Image in screenshot obs, got {type(obs['screenshot'])}"
114+
)
115+
return messages
116+
151117
def get_action(self, obs: dict) -> tuple[ToolCallAction, dict]:
152-
prev_actions = [step for step in self.history if isinstance(step, LLMOutput)]
153-
if len(prev_actions) >= self.config.max_actions:
118+
actions_count = len(
119+
[msg for msg in self.history if isinstance(msg, Message) and msg.tool_calls]
120+
)
121+
if actions_count >= self.config.max_actions:
154122
logger.warning("Max actions reached, stopping agent.")
155-
stop_action = ToolCallAction(id="stop", function=FunctionCall(name="final_step", arguments={}))
123+
stop_action = ToolCallAction(
124+
id="stop", function=FunctionCall(name="final_step", arguments={})
125+
)
156126
return stop_action, {}
157-
self.history.append(Observation(data=obs))
158-
steps = self.history + [UserMessage(message=self.config.guidance)]
159-
messages = [m for step in steps for m in step.to_messages()]
127+
self.history += self.obs_to_messages(self.obs_preprocessor(obs))
160128
tools = [tool.model_dump() for tool in self.action_set.actions]
129+
messages = self.history + [{"role": "user", "content": self.config.guidance}]
130+
161131
try:
162132
logger.info(colored(f"Prompt:\n{pprint.pformat(messages, width=120)}", "blue"))
163-
response: ModelResponse = self.llm(
164-
tools=tools,
165-
messages=messages,
166-
num_retries=self.config.max_retry,
167-
)
168-
message = response.choices[0].message # type: ignore
133+
response = self.llm(tools=tools, messages=messages)
134+
message = response.choices[0].message # type: ignore
169135
except Exception as e:
170136
logger.exception(f"Error getting LLM response: {e}. Prompt: {messages}")
171137
raise e
172138
logger.info(colored(f"LLM response:\n{pprint.pformat(message, width=120)}", "green"))
173-
self.history.append(LLMOutput(message=message))
139+
140+
self.history.append(message)
174141
thoughts = self.thoughts_from_message(message)
175142
action = self.action_from_message(message)
176-
177143
return action, {"think": thoughts}
178144

179145
def thoughts_from_message(self, message) -> str:
@@ -187,7 +153,7 @@ def thoughts_from_message(self, message) -> str:
187153
logger.info(colored(f"LLM thinking block:\n{thinking}", "yellow"))
188154
thoughts.append(thinking)
189155
if message.content:
190-
logger.info(colored(f"LLM output:\n{message.content}", "cyan"))
156+
logger.info(colored(f"LLM text output:\n{message.content}", "cyan"))
191157
thoughts.append(message.content)
192158
return "\n\n".join(thoughts)
193159

@@ -199,27 +165,27 @@ def action_from_message(self, message) -> ToolCallAction:
199165
assert isinstance(tool_call.function.name, str)
200166
try:
201167
args = json.loads(tool_call.function.arguments)
202-
action = ToolCallAction(
203-
id=tool_call.id,
204-
function=FunctionCall(name=tool_call.function.name, arguments=args)
205-
)
206168
except json.JSONDecodeError as e:
207-
logger.exception(f"Error in json parsing of tool call arguments, {e}: {tool_call.function.arguments}")
169+
logger.exception(
170+
f"Error in json parsing of tool call arguments, {e}: {tool_call.function.arguments}"
171+
)
208172
raise e
209-
173+
action = ToolCallAction(
174+
id=tool_call.id, function=FunctionCall(name=tool_call.function.name, arguments=args)
175+
)
210176
self.last_tool_call_id = action.id
177+
logger.info(f"Parsed tool call action: {action}")
211178
else:
212179
raise ValueError(f"No tool call found in LLM response: {message}")
213180
return action
214-
181+
215182

216183
@dataclass
217184
class ReactToolCallAgentArgs(AgentArgs):
218-
llm_args: LLMArgs = None # type: ignore
219-
config: AgentConfig = None # type: ignore
185+
llm_args: LLMArgs | None = None
186+
config: AgentConfig | None = None
220187

221188
def make_agent(self, actions: list[ToolSpec]) -> ReactToolCallAgent:
222189
llm = self.llm_args.make_model()
223190
action_set = ToolsActionSet(actions=actions)
224191
return ReactToolCallAgent(action_set=action_set, llm=llm, config=self.config)
225-

src/agentlab/agents/tapeagent/conf/miniwob.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,4 @@ defaults:
66
name: miniwob
77
comment: MiniWob Agent
88
parallel_backend: ray
9-
n_jobs: 16
9+
n_jobs: 8

0 commit comments

Comments
 (0)