Skip to content

Commit 60eed9e

Browse files
committed
Add StructuredDiscussion class to manage message groups and improve message handling in ToolUseAgent
1 parent dbc065c commit 60eed9e

File tree

1 file changed

+95
-36
lines changed

1 file changed

+95
-36
lines changed

src/agentlab/agents/tool_use_agent/multi_tool_agent.py

Lines changed: 95 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import logging
33
from abc import ABC, abstractmethod
44
from copy import copy
5-
from dataclasses import asdict, dataclass
5+
from dataclasses import asdict, dataclass, field
66
from pathlib import Path
77
from typing import Any
88

@@ -54,6 +54,55 @@ def apply(self, llm, messages: list[MessageBuilder], **kwargs):
5454
pass
5555

5656

57+
@dataclass
58+
class MsgGroup:
59+
name: str = None
60+
messages: list[MessageBuilder] = field(default_factory=list)
61+
summary: MessageBuilder = None
62+
63+
64+
class StructuredDiscussion:
65+
66+
def __init__(self, keep_last_n_obs=None):
67+
self.groups: list[MsgGroup] = []
68+
self.keep_last_n_obs = keep_last_n_obs
69+
70+
def append(self, message: MessageBuilder):
71+
"""Append a message to the last group."""
72+
self.groups[-1].messages.append(message)
73+
74+
def new_group(self, name: str = None):
75+
"""Start a new group of messages."""
76+
if name is None:
77+
name = f"group_{len(self.groups)}"
78+
self.groups.append(MsgGroup(name))
79+
80+
def flatten(self) -> list[MessageBuilder]:
81+
"""Flatten the groups into a single list of messages."""
82+
83+
keep_last_n_obs = self.keep_last_n_obs or len(self.groups)
84+
messages = []
85+
for i, group in enumerate(self.groups):
86+
is_tail = i >= len(self.groups) - keep_last_n_obs
87+
print(
88+
f"Processing group {i} ({group.name}), is_tail={is_tail}, len(greoup)={len(group.messages)}"
89+
)
90+
if not is_tail and group.summary is not None:
91+
messages.append(group.summary)
92+
else:
93+
messages.extend(group.messages)
94+
95+
return messages
96+
97+
def set_last_summary(self, summary: MessageBuilder):
98+
# append None to summaries until we reach the current group index
99+
self.groups[-1].summary = summary
100+
101+
def is_goal_set(self) -> bool:
102+
"""Check if the goal is set in the first group."""
103+
return len(self.groups) > 0
104+
105+
57106
# @dataclass
58107
# class BlockArgs(ABC):
59108

@@ -74,9 +123,9 @@ class Goal(Block):
74123

75124
goal_as_system_msg: bool = True
76125

77-
def apply(self, llm, messages: list[MessageBuilder], obs: dict) -> dict:
126+
def apply(self, llm, discussion: StructuredDiscussion, obs: dict) -> dict:
78127
system_message = llm.msg.system().add_text(SYS_MSG)
79-
messages.append(system_message)
128+
discussion.append(system_message)
80129

81130
if self.goal_as_system_msg:
82131
goal_message = llm.msg.system()
@@ -89,7 +138,7 @@ def apply(self, llm, messages: list[MessageBuilder], obs: dict) -> dict:
89138
goal_message.add_text(content["text"])
90139
elif content["type"] == "image_url":
91140
goal_message.add_image(content["image_url"])
92-
messages.append(goal_message)
141+
discussion.append(goal_message)
93142

94143

95144
AXTREE_NOTE = """
@@ -112,7 +161,7 @@ class Obs(Block):
112161
use_zoomed_webpage: bool = False
113162

114163
def apply(
115-
self, llm, messages: list[MessageBuilder], obs: dict, last_llm_output: LLMOutput
164+
self, llm, discussion: StructuredDiscussion, obs: dict, last_llm_output: LLMOutput
116165
) -> dict:
117166

118167
if last_llm_output.tool_calls is None:
@@ -147,7 +196,7 @@ def apply(
147196
if self.use_tabs:
148197
obs_msg.add_text(_format_tabs(obs))
149198

150-
messages.append(obs_msg)
199+
discussion.append(obs_msg)
151200
return obs_msg
152201

153202

@@ -172,7 +221,7 @@ class GeneralHints(Block):
172221

173222
use_hints: bool = True
174223

175-
def apply(self, llm, messages: list[MessageBuilder]) -> dict:
224+
def apply(self, llm, discussion: StructuredDiscussion) -> dict:
176225
if not self.use_hints:
177226
return
178227

@@ -182,7 +231,7 @@ def apply(self, llm, messages: list[MessageBuilder]) -> dict:
182231
"""Use ControlOrMeta instead of Control and Meta for keyboard shortcuts, to be cross-platform compatible. E.g. use ControlOrMeta for mutliple selection in lists.\n"""
183232
)
184233

185-
messages.append(llm.msg.user().add_text("\n".join(hints)))
234+
discussion.append(llm.msg.user().add_text("\n".join(hints)))
186235

187236

188237
@dataclass
@@ -192,20 +241,22 @@ class Summarizer(Block):
192241
do_summary: bool = False
193242
high_details: bool = True
194243

195-
def apply(self, llm, messages: list[MessageBuilder]) -> dict:
244+
def apply(self, llm, discussion: StructuredDiscussion) -> dict:
196245
if not self.do_summary:
197246
return
198247

199248
msg = llm.msg.user().add_text("""Summarize\n""")
200249

201-
messages.append(msg)
250+
discussion.append(msg)
202251
# TODO need to make sure we don't force tool use here
203-
summary_response = llm(messages=messages, tool_choice="none")
252+
summary_response = llm(messages=discussion.flatten(), tool_choice="none")
204253

205254
summary_msg = llm.msg.assistant().add_text(summary_response.think)
206-
messages.append(summary_msg)
255+
discussion.append(summary_msg)
256+
discussion.set_last_summary(summary_msg)
257+
return summary_msg
207258

208-
def apply_init(self, llm, messages: list[MessageBuilder]) -> dict:
259+
def apply_init(self, llm, discussion: StructuredDiscussion) -> dict:
209260
"""Initialize the summarizer block."""
210261
if not self.do_summary:
211262
return
@@ -226,7 +277,7 @@ def apply_init(self, llm, messages: list[MessageBuilder]) -> dict:
226277
system_msg.add_text(
227278
"""When asked to summarize, give a semantic description of the current state of the environment."""
228279
)
229-
messages.append(system_msg)
280+
discussion.append(system_msg)
230281

231282

232283
@dataclass
@@ -242,7 +293,7 @@ def _init(self):
242293
# index the task_name for fast lookup
243294
# self.hint_db.set_index("task_name", inplace=True, drop=False)
244295

245-
def apply(self, llm, messages: list[MessageBuilder], task_name: str) -> dict:
296+
def apply(self, llm, discussion: StructuredDiscussion, task_name: str) -> dict:
246297
if not self.use_task_hint:
247298
return
248299

@@ -256,12 +307,14 @@ def apply(self, llm, messages: list[MessageBuilder], task_name: str) -> dict:
256307
if hint:
257308
hints.append(f"- {hint}")
258309

259-
hints_str = "# Hints:\nHere are some hints for the task you are working on:\n" + "\n".join(
260-
hints
261-
)
262-
msg = llm.msg.user().add_text(hints_str)
310+
if len(hints) > 0:
311+
hints_str = (
312+
"# Hints:\nHere are some hints for the task you are working on:\n"
313+
+ "\n".join(hints)
314+
)
315+
msg = llm.msg.user().add_text(hints_str)
263316

264-
messages.append(msg)
317+
discussion.append(msg)
265318

266319

267320
class ToolCall(Block):
@@ -291,6 +344,7 @@ class PromptConfig:
291344
summarizer: Summarizer = None
292345
general_hints: GeneralHints = None
293346
task_hint: TaskHint = None
347+
keep_last_n_obs: int = 2
294348

295349

296350
@dataclass
@@ -338,11 +392,11 @@ def __init__(
338392
self.llm.msg = self.msg_builder
339393

340394
self.task_hint = self.config.task_hint.make()
395+
self.obs_block = self.config.obs.make()
341396

342-
self.messages: list[MessageBuilder] = []
397+
self.discussion = StructuredDiscussion(self.config.keep_last_n_obs)
343398
self.last_response: LLMOutput = LLMOutput()
344399
self._responses: list[LLMOutput] = []
345-
self.obs_msg_set = list()
346400

347401
def obs_preprocessor(self, obs):
348402
obs = copy(obs)
@@ -380,19 +434,23 @@ def set_task_name(self, task_name: str):
380434
@cost_tracker_decorator
381435
def get_action(self, obs: Any) -> float:
382436
self.llm.reset_stats()
383-
if len(self.messages) == 0:
384-
self.config.goal.apply(self.llm, self.messages, obs)
385-
self.config.summarizer.apply_init(self.llm, self.messages)
386-
self.config.general_hints.apply(self.llm, self.messages)
387-
self.task_hint.apply(self.llm, self.messages, self.task_name)
388-
389-
obs_msg = self.config.obs.apply(
390-
self.llm, self.messages, obs, last_llm_output=self.last_response
391-
)
392-
self.obs_msg_set
393-
self.config.summarizer.apply(self.llm, self.messages)
437+
if not self.discussion.is_goal_set():
438+
self.discussion.new_group("goal")
439+
self.config.goal.apply(self.llm, self.discussion, obs)
440+
self.config.summarizer.apply_init(self.llm, self.discussion)
441+
self.config.general_hints.apply(self.llm, self.discussion)
442+
self.task_hint.apply(self.llm, self.discussion, self.task_name)
443+
444+
self.discussion.new_group()
445+
446+
self.obs_block.apply(self.llm, self.discussion, obs, last_llm_output=self.last_response)
447+
print("flatten for summary")
448+
449+
self.config.summarizer.apply(self.llm, self.discussion)
450+
451+
messages = self.discussion.flatten()
394452
response: LLMOutput = self.llm(
395-
messages=self.messages,
453+
messages=messages,
396454
tool_choice="any",
397455
cache_tool_definition=True,
398456
cache_complete_prompt=True,
@@ -401,15 +459,16 @@ def get_action(self, obs: Any) -> float:
401459
action = response.action
402460
think = response.think
403461

404-
self.messages.append(response.tool_calls)
462+
self.discussion.new_group()
463+
self.discussion.append(response.tool_calls)
405464

406465
self.last_response = response
407466
self._responses.append(response) # may be useful for debugging
408467
# self.messages.append(response.assistant_message) # this is tool call
409468

410469
agent_info = bgym.AgentInfo(
411470
think=think,
412-
chat_messages=self.messages,
471+
chat_messages=messages,
413472
stats=self.llm.stats.stats_dict,
414473
)
415474
return action, agent_info

0 commit comments

Comments
 (0)