Skip to content

Commit b18e1e0

Browse files
Merge remote-tracking branch 'origin/allac/next-agent' into aj/tool_use_agent_chat_completion_support
2 parents 8ebbd7f + 7f4c018 commit b18e1e0

File tree

1 file changed

+27
-19
lines changed

1 file changed

+27
-19
lines changed

src/agentlab/agents/tool_use_agent/multi_tool_agent.py

Lines changed: 27 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
import fnmatch
2+
import json
23
import logging
34
from abc import ABC, abstractmethod
45
from copy import copy
@@ -62,10 +63,16 @@ class MsgGroup:
6263

6364

6465
class StructuredDiscussion:
66+
"""
67+
A structured discussion that groups messages into named groups with a potential summary for each group.
68+
69+
When the discussion is flattened, only the last `keep_last_n_obs` groups are kept in the final list;
70+
the other groups are replaced by their summaries if they have one.
71+
"""
6572

6673
def __init__(self, keep_last_n_obs=None):
6774
self.groups: list[MsgGroup] = []
68-
self.keep_last_n_obs: int| None = keep_last_n_obs
75+
self.keep_last_n_obs: int | None = keep_last_n_obs
6976

7077
def append(self, message: MessageBuilder):
7178
"""Append a message to the last group."""
@@ -84,15 +91,12 @@ def flatten(self) -> list[MessageBuilder]:
8491
messages = []
8592
for i, group in enumerate(self.groups):
8693
is_tail = i >= len(self.groups) - keep_last_n_obs
87-
print(
88-
f"Processing group {i} ({group.name}), is_tail={is_tail}, len(greoup)={len(group.messages)}"
89-
)
90-
# Include only summary if group not in last n groups.
94+
9195
if not is_tail and group.summary is not None:
9296
messages.append(group.summary)
9397
else:
9498
messages.extend(group.messages)
95-
# Mark all summarized messages for caching
99+
# Mark all summarized messages for caching
96100
if i == len(self.groups) - keep_last_n_obs:
97101
messages[i].mark_all_previous_msg_for_caching()
98102
return messages
@@ -106,15 +110,6 @@ def is_goal_set(self) -> bool:
106110
return len(self.groups) > 0
107111

108112

109-
# @dataclass
110-
# class BlockArgs(ABC):
111-
112-
# @abstractmethod
113-
# def make(self) -> Block:
114-
# """Make a block from the arguments."""
115-
# return self.__class__(**asdict(self))
116-
117-
118113
SYS_MSG = """You are a web agent. Based on the observation, you will decide which action to take to accomplish your goal.
119114
You strive for excellence and need to be as meticulous as possible. Make sure to explore when not sure.
120115
"""
@@ -347,7 +342,9 @@ class PromptConfig:
347342
summarizer: Summarizer = None
348343
general_hints: GeneralHints = None
349344
task_hint: TaskHint = None
350-
keep_last_n_obs: int = 2
345+
keep_last_n_obs: int = 1
346+
multiaction: bool = False
347+
action_subsets: tuple[str] = field(default_factory=lambda: ("coord",))
351348

352349

353350
@dataclass
@@ -385,7 +382,9 @@ def __init__(
385382
):
386383
self.model_args = model_args
387384
self.config = config
388-
self.action_set = bgym.HighLevelActionSet(["coord"], multiaction=False)
385+
self.action_set = bgym.HighLevelActionSet(
386+
self.config.action_subsets, multiaction=self.config.multiaction
387+
)
389388
self.tools = self.action_set.to_tool_description(api=model_args.api)
390389

391390
self.call_ids = []
@@ -447,7 +446,6 @@ def get_action(self, obs: Any) -> float:
447446
self.discussion.new_group()
448447

449448
self.obs_block.apply(self.llm, self.discussion, obs, last_llm_output=self.last_response)
450-
print("flatten for summary")
451449

452450
self.config.summarizer.apply(self.llm, self.discussion)
453451

@@ -470,6 +468,13 @@ def get_action(self, obs: Any) -> float:
470468
self._responses.append(response) # may be useful for debugging
471469
# self.messages.append(response.assistant_message) # this is tool call
472470

471+
tools_str = json.dumps(self.tools, indent=2)
472+
tools_msg = MessageBuilder("tool_description").add_text(tools_str)
473+
474+
# Adding these extra messages to visualize in gradio
475+
messages.insert(0, tools_msg) # insert at the beginning of the messages
476+
messages.append(response.tool_calls)
477+
473478
agent_info = bgym.AgentInfo(
474479
think=think,
475480
chat_messages=messages,
@@ -513,13 +518,16 @@ def get_action(self, obs: Any) -> float:
513518
use_last_error=True,
514519
use_screenshot=True,
515520
use_axtree=False,
516-
use_dom=False,
521+
use_dom=True,
517522
use_som=False,
518523
use_tabs=False,
519524
),
520525
summarizer=Summarizer(do_summary=True),
521526
general_hints=GeneralHints(use_hints=False),
522527
task_hint=TaskHint(use_task_hint=True),
528+
keep_last_n_obs=1, # keep only the last observation in the discussion
529+
multiaction=False, # whether to use multi-action or not
530+
action_subsets=("bid",),
523531
)
524532

525533
AGENT_CONFIG = ToolUseAgentArgs(

0 commit comments

Comments
 (0)