Skip to content

Commit 9ed3376

Browse files
improve entry point args for agentlab-mentor to allow multiple seeds
1 parent 517aaf5 commit 9ed3376

File tree

1 file changed

+20
-13
lines changed

1 file changed

+20
-13
lines changed

src/agentlab/agents/hitl_agent/launch_hint_ui.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
Console launcher for the Human-in-the-Loop Generic Agent UI.
33
44
Usage (installed entry point):
5-
agentlab-mentor --benchmark miniwob --task-name miniwob.book-flight --seed 123 --no-headless
5+
agentlab-mentor --benchmark miniwob --task-name miniwob.book-flight --seed 123 --seed 456 --no-headless
66
77
This will run a Study with the MultipleProposalGenericAgent and the selected task.
88
"""
@@ -11,6 +11,7 @@
1111

1212
import argparse
1313
import logging
14+
import copy
1415
from typing import Optional
1516

1617
import bgym
@@ -20,9 +21,10 @@
2021
)
2122
from agentlab.experiments.study import Study
2223

24+
logger = logging.getLogger(__name__)
2325

2426
def build_benchmark(
25-
benchmark_name: str, task_name: Optional[str], seed: Optional[int], headless: bool
27+
benchmark_name: str, task_name: Optional[str], seeds: Optional[list[int]], headless: bool
2628
):
2729
# Instantiate benchmark by name using BrowserGym registry
2830
try:
@@ -32,16 +34,20 @@ def build_benchmark(
3234
raise SystemExit(f"Unknown benchmark '{benchmark_name}'. Choose one of: {choices}") from e
3335

3436
if task_name:
35-
# If a fully-qualified name is provided, filter by exact match; otherwise, allow glob
36-
if any(ch in task_name for ch in "*?[]"):
37-
benchmark = benchmark.subset_from_glob("task_name", task_name)
38-
else:
39-
benchmark = benchmark.subset_from_glob("task_name", task_name)
40-
41-
# If a specific seed is provided, set it on all env args
42-
if seed is not None:
43-
for env_args in benchmark.env_args_list:
44-
env_args.task_seed = seed
37+
benchmark = benchmark.subset_from_glob("task_name", task_name)
38+
tasks = list(set(e.task_name for e in benchmark.env_args_list))
39+
logger.warning(f'Found {len(tasks)} tasks matching "{task_name}:" \n {tasks}, using only the first one.')
40+
task = tasks[0]
41+
42+
# If specific seeds are provided, duplicate envs for each seed
43+
if seeds is not None:
44+
new_env_args_list = []
45+
task_env = next((x for x in benchmark.env_args_list if x.task_name == task))
46+
for seed in seeds:
47+
ea = copy.deepcopy(task_env)
48+
ea.task_seed = seed
49+
new_env_args_list.append(ea)
50+
benchmark.env_args_list = new_env_args_list
4551

4652
# Reasonable defaults for interactive UI
4753
for env_args in benchmark.env_args_list:
@@ -66,9 +72,10 @@ def parse_args():
6672
)
6773
p.add_argument(
6874
"--seed",
75+
action="append",
6976
type=int,
7077
default=None,
71-
help="Task seed to use for all selected tasks. If omitted, tasks keep their configured/random seed.",
78+
help="Task seed. Repeat flag for multiple seeds (e.g., --seed 1 --seed 2). If omitted, tasks keep their configured/random seed.",
7279
)
7380
p.add_argument(
7481
"--jobs",

0 commit comments

Comments
 (0)