|
1 | 1 | import fnmatch |
| 2 | +import logging |
2 | 3 | from abc import ABC, abstractmethod |
3 | 4 | from copy import copy |
4 | 5 | from dataclasses import asdict, dataclass |
@@ -64,7 +65,6 @@ def apply(self, llm, messages: list[MessageBuilder], **kwargs): |
64 | 65 |
|
65 | 66 | SYS_MSG = """You are a web agent. Based on the observation, you will decide which action to take to accomplish your goal. |
66 | 67 | You strive for excellence and need to be as meticulous as possible. Make sure to explore when not sure. |
67 | | -Your chain of thought should have 3 sections: 1) Analyze the effect of the action, 2) Summarize the current state of the environment, 3) Reflect on the next action to take. |
68 | 68 | """ |
69 | 69 |
|
70 | 70 |
|
@@ -190,22 +190,45 @@ class Summarizer(Block): |
190 | 190 | """Block to summarize the last action and the current state of the environment.""" |
191 | 191 |
|
192 | 192 | do_summary: bool = False |
| 193 | + high_details: bool = True |
193 | 194 |
|
194 | 195 | def apply(self, llm, messages: list[MessageBuilder]) -> dict: |
195 | 196 | if not self.do_summary: |
196 | 197 | return |
197 | 198 |
|
198 | | - msg = llm.msg.user().add_text( |
199 | | - "Summarize the effect of the last action and the current state of the environment." |
200 | | - ) |
| 199 | + msg = llm.msg.user().add_text("""Summarize\n""") |
201 | 200 |
|
202 | 201 | messages.append(msg) |
203 | 202 | # TODO need to make sure we don't force tool use here |
204 | | - summary_response = llm(messages=messages) |
| 203 | + summary_response = llm(messages=messages, tool_choice="none") |
205 | 204 |
|
206 | 205 | summary_msg = llm.msg.assistant().add_text(summary_response.think) |
207 | 206 | messages.append(summary_msg) |
208 | 207 |
|
| 208 | + def apply_init(self, llm, messages: list[MessageBuilder]) -> dict: |
| 209 | + """Initialize the summarizer block.""" |
| 210 | + if not self.do_summary: |
| 211 | + return |
| 212 | + |
| 213 | + system_msg = llm.msg.system() |
| 214 | + if self.high_details: |
| 215 | + # Add a system message to the LLM to indicate that it should summarize |
| 216 | + system_msg.add_text( |
| 217 | + """# Summarizer instructions:\nWhen asked to summarize, do the following: |
| 218 | + 1) Summarize the effect of the last action, with attention to details. |
| 219 | + 2) Give a semantic description of the current state of the environment, with attention to details. If there was a repeating mistake, mention the cause of it. |
| 220 | + 3) Reason about the overall task at a high level. |
| 221 | + 4) What hint can be relevant for the next action? Only chose from the hints provided in the task description. Or select none. |
| 222 | + 5) What is the currently activated item if any. |
| 223 | + 6) Reason about the next action to take, based on the current state and the goal. |
| 224 | + """ |
| 225 | + ) |
| 226 | + else: |
| 227 | + system_msg.add_text( |
| 228 | + """When asked to summarize, give a semantic description of the current state of the environment.""" |
| 229 | + ) |
| 230 | + messages.append(system_msg) |
| 231 | + |
209 | 232 |
|
210 | 233 | @dataclass |
211 | 234 | class TaskHint(Block): |
@@ -234,7 +257,9 @@ def apply(self, llm, messages: list[MessageBuilder], task_name: str) -> dict: |
234 | 257 | if hint: |
235 | 258 | hints.append(f"- {hint}") |
236 | 259 |
|
237 | | - hints_str = "Here are some hints for the task you are working on:\n" + "\n".join(hints) |
| 260 | + hints_str = "# Hints:\nHere are some hints for the task you are working on:\n" + "\n".join( |
| 261 | + hints |
| 262 | + ) |
238 | 263 | msg = llm.msg.user().add_text(hints_str) |
239 | 264 |
|
240 | 265 | messages.append(msg) |
@@ -357,13 +382,19 @@ def get_action(self, obs: Any) -> float: |
357 | 382 | self.llm.reset_stats() |
358 | 383 | if len(self.messages) == 0: |
359 | 384 | self.config.goal.apply(self.llm, self.messages, obs) |
| 385 | + self.config.summarizer.apply_init(self.llm, self.messages) |
360 | 386 | self.config.general_hints.apply(self.llm, self.messages) |
361 | 387 | self.task_hint.apply(self.llm, self.messages, self.task_name) |
362 | 388 |
|
| 389 | + logging.info("Appending observation to messages") |
363 | 390 | self.config.obs.apply(self.llm, self.messages, obs, last_llm_output=self.last_response) |
| 391 | + logging.info("Calling summarizer") |
364 | 392 | self.config.summarizer.apply(self.llm, self.messages) |
365 | | - response: LLMOutput = self.llm(messages=self.messages, |
366 | | - cache_tool_definition=True) |
| 393 | + logging.info("Main tool calling") |
| 394 | + response: LLMOutput = self.llm( |
| 395 | + messages=self.messages, tool_choice="any", cache_tool_definition=True |
| 396 | + ) |
| 397 | + logging.info(f"Obtained response {response}") |
367 | 398 |
|
368 | 399 | action = response.action |
369 | 400 | think = response.think |
@@ -421,7 +452,7 @@ def get_action(self, obs: Any) -> float: |
421 | 452 | use_som=False, |
422 | 453 | use_tabs=False, |
423 | 454 | ), |
424 | | - summarizer=Summarizer(), |
| 455 | + summarizer=Summarizer(do_summary=True), |
425 | 456 | general_hints=GeneralHints(use_hints=False), |
426 | 457 | task_hint=TaskHint(use_task_hint=True), |
427 | 458 | ) |
|
0 commit comments