11import fnmatch
2+ import json
23import logging
34from abc import ABC , abstractmethod
45from copy import copy
@@ -62,6 +63,12 @@ class MsgGroup:
6263
6364
6465class StructuredDiscussion :
66+ """
67+ A structured discussion that groups messages into named groups with a potential summary for each group.
68+
69+ When the discussion is flattened, only the last `keep_last_n_obs` groups are kept in the final list,
70+ the other groups are replaced by their summaries if they have one.
71+ """
6572
6673 def __init__ (self , keep_last_n_obs = None ):
6774 self .groups : list [MsgGroup ] = []
@@ -84,9 +91,7 @@ def flatten(self) -> list[MessageBuilder]:
8491 messages = []
8592 for i , group in enumerate (self .groups ):
8693 is_tail = i >= len (self .groups ) - keep_last_n_obs
87- print (
88- f"Processing group { i } ({ group .name } ), is_tail={ is_tail } , len(greoup)={ len (group .messages )} "
89- )
94+
9095 if not is_tail and group .summary is not None :
9196 messages .append (group .summary )
9297 else :
@@ -103,15 +108,6 @@ def is_goal_set(self) -> bool:
103108 return len (self .groups ) > 0
104109
105110
106- # @dataclass
107- # class BlockArgs(ABC):
108-
109- # @abstractmethod
110- # def make(self) -> Block:
111- # """Make a block from the arguments."""
112- # return self.__class__(**asdict(self))
113-
114-
115111SYS_MSG = """You are a web agent. Based on the observation, you will decide which action to take to accomplish your goal.
116112You strive for excellence and need to be as meticulous as possible. Make sure to explore when not sure.
117113"""
@@ -344,7 +340,9 @@ class PromptConfig:
344340 summarizer : Summarizer = None
345341 general_hints : GeneralHints = None
346342 task_hint : TaskHint = None
347- keep_last_n_obs : int = 2
343+ keep_last_n_obs : int = 1
344+ multiaction : bool = False
345+ action_subsets : tuple [str ] = field (default_factory = lambda : ("coord" ,))
348346
349347
350348@dataclass
@@ -382,7 +380,9 @@ def __init__(
382380 ):
383381 self .model_args = model_args
384382 self .config = config
385- self .action_set = bgym .HighLevelActionSet (["coord" ], multiaction = False )
383+ self .action_set = bgym .HighLevelActionSet (
384+ self .config .action_subsets , multiaction = self .config .multiaction
385+ )
386386 self .tools = self .action_set .to_tool_description (api = model_args .api )
387387
388388 self .call_ids = []
@@ -444,7 +444,6 @@ def get_action(self, obs: Any) -> float:
444444 self .discussion .new_group ()
445445
446446 self .obs_block .apply (self .llm , self .discussion , obs , last_llm_output = self .last_response )
447- print ("flatten for summary" )
448447
449448 self .config .summarizer .apply (self .llm , self .discussion )
450449
@@ -466,6 +465,13 @@ def get_action(self, obs: Any) -> float:
466465 self ._responses .append (response ) # may be useful for debugging
467466 # self.messages.append(response.assistant_message) # this is tool call
468467
468+ tools_str = json .dumps (self .tools , indent = 2 )
469+ tools_msg = MessageBuilder ("tool_description" ).add_text (tools_str )
470+
471+ # Adding these extra messages to visualize in gradio
472+ messages .insert (0 , tools_msg ) # insert at the beginning of the messages
473+ messages .append (response .tool_calls )
474+
469475 agent_info = bgym .AgentInfo (
470476 think = think ,
471477 chat_messages = messages ,
@@ -509,13 +515,16 @@ def get_action(self, obs: Any) -> float:
509515 use_last_error = True ,
510516 use_screenshot = True ,
511517 use_axtree = False ,
512- use_dom = False ,
518+ use_dom = True ,
513519 use_som = False ,
514520 use_tabs = False ,
515521 ),
516522 summarizer = Summarizer (do_summary = True ),
517523 general_hints = GeneralHints (use_hints = False ),
518524 task_hint = TaskHint (use_task_hint = True ),
525+ keep_last_n_obs = 1 , # keep only the last observation in the discussion
526+ multiaction = False , # whether to use multi-action or not
527+ action_subsets = ("bid" ,),
519528)
520529
521530AGENT_CONFIG = ToolUseAgentArgs (
0 commit comments