Skip to content

Commit c5ea716

Browse files
committed
Enhance summarization functionality in ToolUseAgent: add detailed initialization for summarizer, improve logging, and update hints formatting.
1 parent 39084c0 commit c5ea716

File tree

1 file changed

+40
-9
lines changed

1 file changed

+40
-9
lines changed

src/agentlab/agents/tool_use_agent/multi_tool_agent.py

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import fnmatch
2+
import logging
23
from abc import ABC, abstractmethod
34
from copy import copy
45
from dataclasses import asdict, dataclass
@@ -64,7 +65,6 @@ def apply(self, llm, messages: list[MessageBuilder], **kwargs):
6465

6566
SYS_MSG = """You are a web agent. Based on the observation, you will decide which action to take to accomplish your goal.
6667
You strive for excellence and need to be as meticulous as possible. Make sure to explore when not sure.
67-
Your chain of thought should have 3 sections: 1) Analyze the effect of the action, 2) Summarize the current state of the environment, 3) Reflect on the next action to take.
6868
"""
6969

7070

@@ -190,22 +190,45 @@ class Summarizer(Block):
190190
"""Block to summarize the last action and the current state of the environment."""
191191

192192
do_summary: bool = False
193+
high_details: bool = True
193194

194195
def apply(self, llm, messages: list[MessageBuilder]) -> dict:
195196
if not self.do_summary:
196197
return
197198

198-
msg = llm.msg.user().add_text(
199-
"Summarize the effect of the last action and the current state of the environment."
200-
)
199+
msg = llm.msg.user().add_text("""Summarize\n""")
201200

202201
messages.append(msg)
203202
# TODO need to make sure we don't force tool use here
204-
summary_response = llm(messages=messages)
203+
summary_response = llm(messages=messages, tool_choice="none")
205204

206205
summary_msg = llm.msg.assistant().add_text(summary_response.think)
207206
messages.append(summary_msg)
208207

208+
def apply_init(self, llm, messages: list[MessageBuilder]) -> dict:
209+
"""Initialize the summarizer block."""
210+
if not self.do_summary:
211+
return
212+
213+
system_msg = llm.msg.system()
214+
if self.high_details:
215+
# Add a system message to the LLM to indicate that it should summarize
216+
system_msg.add_text(
217+
"""# Summarizer instructions:\nWhen asked to summarize, do the following:
218+
1) Summarize the effect of the last action, with attention to details.
219+
2) Give a semantic description of the current state of the environment, with attention to details. If there was a repeating mistake, mention the cause of it.
220+
3) Reason about the overall task at a high level.
221+
4) What hint can be relevant for the next action? Only chose from the hints provided in the task description. Or select none.
222+
5) What is the currently activated item if any.
223+
6) Reason about the next action to take, based on the current state and the goal.
224+
"""
225+
)
226+
else:
227+
system_msg.add_text(
228+
"""When asked to summarize, give a semantic description of the current state of the environment."""
229+
)
230+
messages.append(system_msg)
231+
209232

210233
@dataclass
211234
class TaskHint(Block):
@@ -234,7 +257,9 @@ def apply(self, llm, messages: list[MessageBuilder], task_name: str) -> dict:
234257
if hint:
235258
hints.append(f"- {hint}")
236259

237-
hints_str = "Here are some hints for the task you are working on:\n" + "\n".join(hints)
260+
hints_str = "# Hints:\nHere are some hints for the task you are working on:\n" + "\n".join(
261+
hints
262+
)
238263
msg = llm.msg.user().add_text(hints_str)
239264

240265
messages.append(msg)
@@ -357,13 +382,19 @@ def get_action(self, obs: Any) -> float:
357382
self.llm.reset_stats()
358383
if len(self.messages) == 0:
359384
self.config.goal.apply(self.llm, self.messages, obs)
385+
self.config.summarizer.apply_init(self.llm, self.messages)
360386
self.config.general_hints.apply(self.llm, self.messages)
361387
self.task_hint.apply(self.llm, self.messages, self.task_name)
362388

389+
logging.info("Appending observation to messages")
363390
self.config.obs.apply(self.llm, self.messages, obs, last_llm_output=self.last_response)
391+
logging.info("Calling summarizer")
364392
self.config.summarizer.apply(self.llm, self.messages)
365-
response: LLMOutput = self.llm(messages=self.messages,
366-
cache_tool_definition=True)
393+
logging.info("Main tool calling")
394+
response: LLMOutput = self.llm(
395+
messages=self.messages, tool_choice="any", cache_tool_definition=True
396+
)
397+
logging.info(f"Obtained response {response}")
367398

368399
action = response.action
369400
think = response.think
@@ -421,7 +452,7 @@ def get_action(self, obs: Any) -> float:
421452
use_som=False,
422453
use_tabs=False,
423454
),
424-
summarizer=Summarizer(),
455+
summarizer=Summarizer(do_summary=True),
425456
general_hints=GeneralHints(use_hints=False),
426457
task_hint=TaskHint(use_task_hint=True),
427458
)

0 commit comments

Comments
 (0)