Skip to content

Commit b813df7

Browse files
committed
Enhance StructuredDiscussion to group messages with summaries and adjust ToolUseAgent configuration for multi-action support
1 parent 60eed9e commit b813df7

File tree

1 file changed

+25
-16
lines changed

1 file changed

+25
-16
lines changed

src/agentlab/agents/tool_use_agent/multi_tool_agent.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import fnmatch
2+
import json
23
import logging
34
from abc import ABC, abstractmethod
45
from copy import copy
@@ -62,6 +63,12 @@ class MsgGroup:
6263

6364

6465
class StructuredDiscussion:
66+
"""
67+
A structured discussion that groups messages into named groups with a potential summary for each group.
68+
69+
When the discussion is flattened, only the last `keep_last_n_obs` groups are kept in the final list,
70+
the other groups are replaced by their summaries if they have one.
71+
"""
6572

6673
def __init__(self, keep_last_n_obs=None):
6774
self.groups: list[MsgGroup] = []
@@ -84,9 +91,7 @@ def flatten(self) -> list[MessageBuilder]:
8491
messages = []
8592
for i, group in enumerate(self.groups):
8693
is_tail = i >= len(self.groups) - keep_last_n_obs
87-
print(
88-
f"Processing group {i} ({group.name}), is_tail={is_tail}, len(greoup)={len(group.messages)}"
89-
)
94+
9095
if not is_tail and group.summary is not None:
9196
messages.append(group.summary)
9297
else:
@@ -103,15 +108,6 @@ def is_goal_set(self) -> bool:
103108
return len(self.groups) > 0
104109

105110

106-
# @dataclass
107-
# class BlockArgs(ABC):
108-
109-
# @abstractmethod
110-
# def make(self) -> Block:
111-
# """Make a block from the arguments."""
112-
# return self.__class__(**asdict(self))
113-
114-
115111
SYS_MSG = """You are a web agent. Based on the observation, you will decide which action to take to accomplish your goal.
116112
You strive for excellence and need to be as meticulous as possible. Make sure to explore when not sure.
117113
"""
@@ -344,7 +340,9 @@ class PromptConfig:
344340
summarizer: Summarizer = None
345341
general_hints: GeneralHints = None
346342
task_hint: TaskHint = None
347-
keep_last_n_obs: int = 2
343+
keep_last_n_obs: int = 1
344+
multiaction: bool = False
345+
action_subsets: tuple[str] = field(default_factory=lambda: ("coord",))
348346

349347

350348
@dataclass
@@ -382,7 +380,9 @@ def __init__(
382380
):
383381
self.model_args = model_args
384382
self.config = config
385-
self.action_set = bgym.HighLevelActionSet(["coord"], multiaction=False)
383+
self.action_set = bgym.HighLevelActionSet(
384+
self.config.action_subsets, multiaction=self.config.multiaction
385+
)
386386
self.tools = self.action_set.to_tool_description(api=model_args.api)
387387

388388
self.call_ids = []
@@ -444,7 +444,6 @@ def get_action(self, obs: Any) -> float:
444444
self.discussion.new_group()
445445

446446
self.obs_block.apply(self.llm, self.discussion, obs, last_llm_output=self.last_response)
447-
print("flatten for summary")
448447

449448
self.config.summarizer.apply(self.llm, self.discussion)
450449

@@ -466,6 +465,13 @@ def get_action(self, obs: Any) -> float:
466465
self._responses.append(response) # may be useful for debugging
467466
# self.messages.append(response.assistant_message) # this is tool call
468467

468+
tools_str = json.dumps(self.tools, indent=2)
469+
tools_msg = MessageBuilder("tool_description").add_text(tools_str)
470+
471+
# Adding these extra messages to visualize in gradio
472+
messages.insert(0, tools_msg) # insert at the beginning of the messages
473+
messages.append(response.tool_calls)
474+
469475
agent_info = bgym.AgentInfo(
470476
think=think,
471477
chat_messages=messages,
@@ -509,13 +515,16 @@ def get_action(self, obs: Any) -> float:
509515
use_last_error=True,
510516
use_screenshot=True,
511517
use_axtree=False,
512-
use_dom=False,
518+
use_dom=True,
513519
use_som=False,
514520
use_tabs=False,
515521
),
516522
summarizer=Summarizer(do_summary=True),
517523
general_hints=GeneralHints(use_hints=False),
518524
task_hint=TaskHint(use_task_hint=True),
525+
keep_last_n_obs=1, # keep only the last observation in the discussion
526+
multiaction=False, # whether to use multi-action or not
527+
action_subsets=("bid",),
519528
)
520529

521530
AGENT_CONFIG = ToolUseAgentArgs(

0 commit comments

Comments
 (0)