22import logging
33from abc import ABC , abstractmethod
44from copy import copy
5- from dataclasses import asdict , dataclass
5+ from dataclasses import asdict , dataclass , field
66from pathlib import Path
77from typing import Any
88
@@ -54,6 +54,55 @@ def apply(self, llm, messages: list[MessageBuilder], **kwargs):
5454 pass
5555
5656
57+ @dataclass
58+ class MsgGroup :
59+ name : str = None
60+ messages : list [MessageBuilder ] = field (default_factory = list )
61+ summary : MessageBuilder = None
62+
63+
64+ class StructuredDiscussion :
65+
66+ def __init__ (self , keep_last_n_obs = None ):
67+ self .groups : list [MsgGroup ] = []
68+ self .keep_last_n_obs = keep_last_n_obs
69+
70+ def append (self , message : MessageBuilder ):
71+ """Append a message to the last group."""
72+ self .groups [- 1 ].messages .append (message )
73+
74+ def new_group (self , name : str = None ):
75+ """Start a new group of messages."""
76+ if name is None :
77+ name = f"group_{ len (self .groups )} "
78+ self .groups .append (MsgGroup (name ))
79+
80+ def flatten (self ) -> list [MessageBuilder ]:
81+ """Flatten the groups into a single list of messages."""
82+
83+ keep_last_n_obs = self .keep_last_n_obs or len (self .groups )
84+ messages = []
85+ for i , group in enumerate (self .groups ):
86+ is_tail = i >= len (self .groups ) - keep_last_n_obs
87+ print (
88+ f"Processing group { i } ({ group .name } ), is_tail={ is_tail } , len(greoup)={ len (group .messages )} "
89+ )
90+ if not is_tail and group .summary is not None :
91+ messages .append (group .summary )
92+ else :
93+ messages .extend (group .messages )
94+
95+ return messages
96+
97+ def set_last_summary (self , summary : MessageBuilder ):
98+ # append None to summaries until we reach the current group index
99+ self .groups [- 1 ].summary = summary
100+
101+ def is_goal_set (self ) -> bool :
102+ """Check if the goal is set in the first group."""
103+ return len (self .groups ) > 0
104+
105+
57106# @dataclass
58107# class BlockArgs(ABC):
59108
@@ -74,9 +123,9 @@ class Goal(Block):
74123
75124 goal_as_system_msg : bool = True
76125
77- def apply (self , llm , messages : list [ MessageBuilder ] , obs : dict ) -> dict :
126+ def apply (self , llm , discussion : StructuredDiscussion , obs : dict ) -> dict :
78127 system_message = llm .msg .system ().add_text (SYS_MSG )
79- messages .append (system_message )
128+ discussion .append (system_message )
80129
81130 if self .goal_as_system_msg :
82131 goal_message = llm .msg .system ()
@@ -89,7 +138,7 @@ def apply(self, llm, messages: list[MessageBuilder], obs: dict) -> dict:
89138 goal_message .add_text (content ["text" ])
90139 elif content ["type" ] == "image_url" :
91140 goal_message .add_image (content ["image_url" ])
92- messages .append (goal_message )
141+ discussion .append (goal_message )
93142
94143
95144AXTREE_NOTE = """
@@ -112,7 +161,7 @@ class Obs(Block):
112161 use_zoomed_webpage : bool = False
113162
114163 def apply (
115- self , llm , messages : list [ MessageBuilder ] , obs : dict , last_llm_output : LLMOutput
164+ self , llm , discussion : StructuredDiscussion , obs : dict , last_llm_output : LLMOutput
116165 ) -> dict :
117166
118167 if last_llm_output .tool_calls is None :
@@ -147,7 +196,7 @@ def apply(
147196 if self .use_tabs :
148197 obs_msg .add_text (_format_tabs (obs ))
149198
150- messages .append (obs_msg )
199+ discussion .append (obs_msg )
151200 return obs_msg
152201
153202
@@ -172,7 +221,7 @@ class GeneralHints(Block):
172221
173222 use_hints : bool = True
174223
175- def apply (self , llm , messages : list [ MessageBuilder ] ) -> dict :
224+ def apply (self , llm , discussion : StructuredDiscussion ) -> dict :
176225 if not self .use_hints :
177226 return
178227
@@ -182,7 +231,7 @@ def apply(self, llm, messages: list[MessageBuilder]) -> dict:
182231 """Use ControlOrMeta instead of Control and Meta for keyboard shortcuts, to be cross-platform compatible. E.g. use ControlOrMeta for mutliple selection in lists.\n """
183232 )
184233
185- messages .append (llm .msg .user ().add_text ("\n " .join (hints )))
234+ discussion .append (llm .msg .user ().add_text ("\n " .join (hints )))
186235
187236
188237@dataclass
@@ -192,20 +241,22 @@ class Summarizer(Block):
192241 do_summary : bool = False
193242 high_details : bool = True
194243
195- def apply (self , llm , messages : list [ MessageBuilder ] ) -> dict :
244+ def apply (self , llm , discussion : StructuredDiscussion ) -> dict :
196245 if not self .do_summary :
197246 return
198247
199248 msg = llm .msg .user ().add_text ("""Summarize\n """ )
200249
201- messages .append (msg )
250+ discussion .append (msg )
202251 # TODO need to make sure we don't force tool use here
203- summary_response = llm (messages = messages , tool_choice = "none" )
252+ summary_response = llm (messages = discussion . flatten () , tool_choice = "none" )
204253
205254 summary_msg = llm .msg .assistant ().add_text (summary_response .think )
206- messages .append (summary_msg )
255+ discussion .append (summary_msg )
256+ discussion .set_last_summary (summary_msg )
257+ return summary_msg
207258
208- def apply_init (self , llm , messages : list [ MessageBuilder ] ) -> dict :
259+ def apply_init (self , llm , discussion : StructuredDiscussion ) -> dict :
209260 """Initialize the summarizer block."""
210261 if not self .do_summary :
211262 return
@@ -226,7 +277,7 @@ def apply_init(self, llm, messages: list[MessageBuilder]) -> dict:
226277 system_msg .add_text (
227278 """When asked to summarize, give a semantic description of the current state of the environment."""
228279 )
229- messages .append (system_msg )
280+ discussion .append (system_msg )
230281
231282
232283@dataclass
@@ -242,7 +293,7 @@ def _init(self):
242293 # index the task_name for fast lookup
243294 # self.hint_db.set_index("task_name", inplace=True, drop=False)
244295
245- def apply (self , llm , messages : list [ MessageBuilder ] , task_name : str ) -> dict :
296+ def apply (self , llm , discussion : StructuredDiscussion , task_name : str ) -> dict :
246297 if not self .use_task_hint :
247298 return
248299
@@ -256,12 +307,14 @@ def apply(self, llm, messages: list[MessageBuilder], task_name: str) -> dict:
256307 if hint :
257308 hints .append (f"- { hint } " )
258309
259- hints_str = "# Hints:\n Here are some hints for the task you are working on:\n " + "\n " .join (
260- hints
261- )
262- msg = llm .msg .user ().add_text (hints_str )
310+ if len (hints ) > 0 :
311+ hints_str = (
312+ "# Hints:\n Here are some hints for the task you are working on:\n "
313+ + "\n " .join (hints )
314+ )
315+ msg = llm .msg .user ().add_text (hints_str )
263316
264- messages .append (msg )
317+ discussion .append (msg )
265318
266319
267320class ToolCall (Block ):
@@ -291,6 +344,7 @@ class PromptConfig:
291344 summarizer : Summarizer = None
292345 general_hints : GeneralHints = None
293346 task_hint : TaskHint = None
347+ keep_last_n_obs : int = 2
294348
295349
296350@dataclass
@@ -338,11 +392,11 @@ def __init__(
338392 self .llm .msg = self .msg_builder
339393
340394 self .task_hint = self .config .task_hint .make ()
395+ self .obs_block = self .config .obs .make ()
341396
342- self .messages : list [ MessageBuilder ] = []
397+ self .discussion = StructuredDiscussion ( self . config . keep_last_n_obs )
343398 self .last_response : LLMOutput = LLMOutput ()
344399 self ._responses : list [LLMOutput ] = []
345- self .obs_msg_set = list ()
346400
347401 def obs_preprocessor (self , obs ):
348402 obs = copy (obs )
@@ -380,19 +434,23 @@ def set_task_name(self, task_name: str):
380434 @cost_tracker_decorator
381435 def get_action (self , obs : Any ) -> float :
382436 self .llm .reset_stats ()
383- if len (self .messages ) == 0 :
384- self .config .goal .apply (self .llm , self .messages , obs )
385- self .config .summarizer .apply_init (self .llm , self .messages )
386- self .config .general_hints .apply (self .llm , self .messages )
387- self .task_hint .apply (self .llm , self .messages , self .task_name )
388-
389- obs_msg = self .config .obs .apply (
390- self .llm , self .messages , obs , last_llm_output = self .last_response
391- )
392- self .obs_msg_set
393- self .config .summarizer .apply (self .llm , self .messages )
437+ if not self .discussion .is_goal_set ():
438+ self .discussion .new_group ("goal" )
439+ self .config .goal .apply (self .llm , self .discussion , obs )
440+ self .config .summarizer .apply_init (self .llm , self .discussion )
441+ self .config .general_hints .apply (self .llm , self .discussion )
442+ self .task_hint .apply (self .llm , self .discussion , self .task_name )
443+
444+ self .discussion .new_group ()
445+
446+ self .obs_block .apply (self .llm , self .discussion , obs , last_llm_output = self .last_response )
447+ print ("flatten for summary" )
448+
449+ self .config .summarizer .apply (self .llm , self .discussion )
450+
451+ messages = self .discussion .flatten ()
394452 response : LLMOutput = self .llm (
395- messages = self . messages ,
453+ messages = messages ,
396454 tool_choice = "any" ,
397455 cache_tool_definition = True ,
398456 cache_complete_prompt = True ,
@@ -401,15 +459,16 @@ def get_action(self, obs: Any) -> float:
401459 action = response .action
402460 think = response .think
403461
404- self .messages .append (response .tool_calls )
462+ self .discussion .new_group ()
463+ self .discussion .append (response .tool_calls )
405464
406465 self .last_response = response
407466 self ._responses .append (response ) # may be useful for debugging
408467 # self.messages.append(response.assistant_message) # this is tool call
409468
410469 agent_info = bgym .AgentInfo (
411470 think = think ,
412- chat_messages = self . messages ,
471+ chat_messages = messages ,
413472 stats = self .llm .stats .stats_dict ,
414473 )
415474 return action , agent_info
0 commit comments