Skip to content

Commit 593e104

Browse files
authored
Merge branch 'generic_agent_hinter' into scratch/refactor-hint-retrieval
2 parents 49ebc89 + 15c5639 commit 593e104

File tree

7 files changed

+45
-16
lines changed

7 files changed

+45
-16
lines changed

src/agentlab/agents/generic_agent_hinter/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
AGENT_CLAUDE_SONNET_35,
1414
AGENT_CLAUDE_SONNET_35_VISION,
1515
AGENT_CUSTOM,
16+
AGENT_GPT5_MINI,
17+
AGENT_GPT5_NANO,
1618
AGENT_LLAMA3_70B,
1719
AGENT_LLAMA4_17B_INSTRUCT,
1820
AGENT_LLAMA31_70B,
@@ -26,9 +28,7 @@
2628
AGENT_o3_MINI,
2729
FLAGS_GPT_4o,
2830
GenericAgentArgs,
29-
AGENT_GPT5_MINI,
3031
)
31-
3232
from .generic_agent import GenericAgent, GenericAgentArgs
3333

3434
__all__ = [
@@ -50,4 +50,5 @@
5050
"AGENT_4o_MINI_VISION",
5151
"AGENT_CLAUDE_SONNET_35_VISION",
5252
"AGENT_GPT5_MINI",
53+
"AGENT_GPT5_NANO",
5354
]

src/agentlab/agents/generic_agent_hinter/agent_configs.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,10 @@
365365
chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-5-mini-2025-08-07"],
366366
flags=GPT5_MINI_FLAGS,
367367
)
368+
AGENT_GPT5_NANO = GenericAgentArgs(
369+
chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-5-nano-2025-08-07"],
370+
flags=GPT5_MINI_FLAGS,
371+
)
368372

369373
AGENT_GPT5 = GenericAgentArgs(
370374
chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-5-2025-08-07"],

src/agentlab/agents/generic_agent_hinter/generic_agent.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
the agent, including model arguments and flags for various behaviors.
99
"""
1010

11+
import os
1112
from copy import deepcopy
1213
from dataclasses import asdict, dataclass
1314
from pathlib import Path
@@ -91,6 +92,8 @@ def __init__(
9192
self.max_retry = max_retry
9293

9394
self.flags = flags
95+
if self.flags.hint_db_path is not None:
96+
assert os.path.exists(self.flags.hint_db_path), f"Hint database path {self.flags.hint_db_path} does not exist."
9497
self.action_set = self.flags.action.action_set.make_action_set()
9598
self._obs_preprocessor = dp.make_obs_preprocessor(flags.obs)
9699

@@ -113,11 +116,9 @@ def get_action(self, obs):
113116

114117
system_prompt = SystemMessage(dp.SystemPrompt().prompt)
115118

116-
queries, think_queries = self._get_queries()
117-
118119
# use those queries to retrieve from the database and pass to prompt if step-level
119120
queries_for_hints = (
120-
queries if getattr(self.flags, "hint_level", "episode") == "step" else None
121+
self._get_queries()[0] if getattr(self.flags, "hint_level", "episode") == "step" else None
121122
)
122123

123124
# get hints
@@ -211,7 +212,7 @@ def _get_queries(self):
211212
)
212213

213214
queries = ans_dict.get("queries", [])
214-
assert len(queries) == self.flags.n_retrieval_queries
215+
assert len(queries) <= self.flags.n_retrieval_queries
215216

216217
# TODO: we should probably propagate these chat_messages to be able to see them in xray
217218
return queries, ans_dict.get("think", None)

src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from agentlab.llm.llm_utils import HumanMessage, parse_html_tags_raise
1818
from browsergym.core.action.base import AbstractActionSet
1919

20+
logger = logging.getLogger(__name__)
2021

2122
@dataclass
2223
class GenericPromptFlags(dp.Flags):
@@ -359,13 +360,14 @@ def _prompt(self) -> HumanMessage:
359360
# Querying memory
360361
361362
Before choosing an action, let's search our available documentation and memory for relevant context.
362-
Generate a brief, general summary of the current status to help identify useful hints. Return your answer as follow
363+
Generate a brief, general summary of the current status to help identify useful hints. Return your answer in the following format:
363364
<think>chain of thought</think>
364-
<queries>json list of strings</queries> for the queries. Return exactly {self.n_queries}
365-
queries in the list.
365+
<queries>json list of strings of queries</queries>
366366
367-
# Concrete Example
367+
Additional instructions: List of queries should contain up to {self.n_queries} queries. Both the think and the queries blocks are required!
368368
369+
# Concrete Example
370+
```
369371
<think>
370372
I have to sort by client and country. I could use the built-in sort on each column but I'm not sure if
371373
I will be able to sort by both at the same time.
@@ -374,6 +376,10 @@ def _prompt(self) -> HumanMessage:
374376
<queries>
375377
{example_queries_str}
376378
</queries>
379+
```
380+
Note: do not generate backticks.
381+
Now proceed to generate your own thoughts and queries.
382+
Always return non-empty answer, its very important!
377383
"""
378384
)
379385

@@ -384,8 +390,19 @@ def shrink(self):
384390
self.obs.shrink()
385391

386392
def _parse_answer(self, text_answer):
387-
ans_dict = parse_html_tags_raise(
388-
text_answer, keys=["think", "queries"], merge_multiple=True
389-
)
390-
ans_dict["queries"] = json.loads(ans_dict.get("queries", "[]"))
393+
try:
394+
ans_dict = parse_html_tags_raise(
395+
text_answer, keys=["think", "queries"], merge_multiple=True
396+
)
397+
except Exception as e:
398+
t = text_answer.replace("\n", "\\n")
399+
logger.warning(f"Failed to parse llm answer: {e}. RAW answer: '{t}'. Will retry")
400+
raise e
401+
raw_queries = ans_dict.get("queries", "[]")
402+
try:
403+
ans_dict["queries"] = json.loads(raw_queries)
404+
except Exception as e:
405+
t = text_answer.replace("\n", "\\n")
406+
logger.warning(f"Failed to parse queries: {e}. Queries block content: '{ans_dict['queries']}'. RAW llm answer: '{t}'. Will retry")
407+
raise e
391408
return ans_dict

src/agentlab/agents/tool_use_agent/tool_use_agent.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
from agentlab.utils.hinting import HintsSource
4545

4646
logger = logging.getLogger(__name__)
47+
logger.setLevel(logging.INFO)
4748

4849

4950
@dataclass

src/agentlab/experiments/launch_exp.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
import os
23
from importlib import import_module
34
from pathlib import Path
45

@@ -7,6 +8,8 @@
78
from agentlab.experiments.exp_utils import run_exp
89
from agentlab.experiments.loop import ExpArgs, yield_all_exp_results
910

11+
RAY_PUBLIC_DASHBOARD = os.environ.get("RAY_PUBLIC_DASHBOARD", "false") == "true"
12+
1013

1114
def run_experiments(
1215
n_jobs,
@@ -82,7 +85,9 @@ def run_experiments(
8285
elif parallel_backend == "ray":
8386
from agentlab.experiments.graph_execution_ray import execute_task_graph, ray
8487

85-
ray.init(num_cpus=n_jobs)
88+
ray.init(
89+
num_cpus=n_jobs, dashboard_host="0.0.0.0" if RAY_PUBLIC_DASHBOARD else "127.0.0.1"
90+
)
8691
try:
8792
execute_task_graph(exp_args_list, avg_step_timeout=avg_step_timeout)
8893
finally:

src/agentlab/experiments/loop.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -907,14 +907,14 @@ def _move_old_exp(exp_dir):
907907

908908
def _get_env_name(task_name: str):
909909
"""Register tasks if needed (lazy import) and return environment name."""
910-
911910
# lazy import
912911
if task_name.startswith("miniwob"):
913912
import browsergym.miniwob
914913
elif task_name.startswith("workarena"):
915914
import browsergym.workarena
916915
elif task_name.startswith("webarena"):
917916
import browsergym.webarena
917+
import browsergym.webarenalite
918918
elif task_name.startswith("visualwebarena"):
919919
import browsergym.visualwebarena
920920
elif task_name.startswith("assistantbench"):

0 commit comments

Comments
 (0)