Skip to content

Commit abb897f

Browse files
authored
Merge branch 'generic_agent_hinter' into step-wise-retrieval
2 parents ee2653a + 69048c4 commit abb897f

22 files changed

+2315
-71
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,3 +109,4 @@ hint = [
109109
[project.scripts]
110110
agentlab-assistant = "agentlab.ui_assistant:main"
111111
agentlab-xray = "agentlab.analyze.agent_xray:main"
112+
agentlab-mentor = "agentlab.agents.hitl_agent.launch_hint_ui:main"

src/agentlab/agents/agent_utils.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
1+
import copy
2+
13
from PIL import Image, ImageDraw
24
from playwright.sync_api import Page
35

6+
from agentlab.analyze import overlay_utils
7+
from agentlab.llm.llm_utils import img_to_base_64
8+
49

510
def draw_mouse_pointer(image: Image.Image, x: int, y: int) -> Image.Image:
611
"""
@@ -128,3 +133,24 @@ def zoom_webpage(page: Page, zoom_factor: float = 1.5):
128133

129134
page.evaluate(f"document.documentElement.style.zoom='{zoom_factor*100}%'")
130135
return page
136+
137+
138+
def overlay_action(obs, action):
    """Draw ``action`` on a copy of the observation's screenshot.

    The screenshot and the element properties are deep-copied before use, so
    this function does not modify ``obs["screenshot"]`` or
    ``obs["extra_element_properties"]``.

    Args:
        obs: Observation mapping. Reads ``obs["screenshot"]`` (handed to
            ``Image.fromarray``) and ``obs["extra_element_properties"]``
            (a mapping whose values may carry a ``"bbox"`` entry).
        action: Action forwarded unchanged to ``overlay_utils.annotate_action``.

    Returns:
        Whatever ``img_to_base_64`` returns for the annotated image.
    """
    import os

    act_img = Image.fromarray(copy.deepcopy(obs["screenshot"]))
    new_obs_properties = copy.deepcopy(obs["extra_element_properties"])

    # NOTE(review): any non-empty value (even "0") enables the retina path.
    if os.getenv("AGENTLAB_USE_RETINA"):
        # HACK: divide everything by 2 in the obs
        # TODO: make this more robust by changing logic in annotate_action directly (or maybe in the obs section?)
        for value in new_obs_properties.values():
            try:
                value["bbox"] = [elem / 2 for elem in value["bbox"]]
            except (KeyError, TypeError):
                # Best effort: entries with a missing, None or non-numeric
                # bbox are left as they are.
                continue

    overlay_utils.annotate_action(act_img, action, properties=new_obs_properties)
    return img_to_base_64(act_img)

src/agentlab/agents/generic_agent/__init__.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,22 @@
2222
AGENT_4o_MINI,
2323
AGENT_4o_MINI_VISION,
2424
AGENT_4o_VISION,
25+
AGENT_AZURE_4o_MINI,
26+
AGENT_AZURE_4o,
27+
AGENT_AZURE_4o_VISION,
28+
AGENT_AZURE_4o_MINI_VISION,
29+
AGENT_AZURE_41,
30+
AGENT_AZURE_41_MINI,
31+
AGENT_AZURE_41_NANO,
32+
AGENT_AZURE_41_VISION,
33+
AGENT_AZURE_41_MINI_VISION,
34+
AGENT_AZURE_41_NANO_VISION,
35+
AGENT_AZURE_5,
36+
AGENT_AZURE_5_MINI,
37+
AGENT_AZURE_5_NANO,
38+
AGENT_AZURE_5_VISION,
39+
AGENT_AZURE_5_MINI_VISION,
40+
AGENT_AZURE_5_NANO_VISION,
2541
AGENT_o1_MINI,
2642
AGENT_o3_MINI,
2743
FLAGS_GPT_4o,
@@ -46,6 +62,22 @@
4662
"AGENT_37_SONNET",
4763
"AGENT_4o_VISION",
4864
"AGENT_4o_MINI_VISION",
65+
"AGENT_AZURE_4o_MINI",
66+
"AGENT_AZURE_4o",
67+
"AGENT_AZURE_4o_VISION",
68+
"AGENT_AZURE_4o_MINI_VISION",
69+
"AGENT_AZURE_41",
70+
"AGENT_AZURE_41_MINI",
71+
"AGENT_AZURE_41_NANO",
72+
"AGENT_AZURE_41_VISION",
73+
"AGENT_AZURE_41_MINI_VISION",
74+
"AGENT_AZURE_41_NANO_VISION",
75+
"AGENT_AZURE_5",
76+
"AGENT_AZURE_5_MINI",
77+
"AGENT_AZURE_5_NANO",
78+
"AGENT_AZURE_5_VISION",
79+
"AGENT_AZURE_5_MINI_VISION",
80+
"AGENT_AZURE_5_NANO_VISION",
4981
"AGENT_CLAUDE_SONNET_35_VISION",
5082
"AGENT_GPT5_MINI",
5183
]

src/agentlab/agents/generic_agent/agent_configs.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,43 @@
262262
chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-4o-mini-2024-07-18"],
263263
flags=FLAGS_GPT_4o,
264264
)
265+
266+
# Agents built on the "azure/..." entries of CHAT_MODEL_ARGS_DICT.
# Every one of them reuses the text-only FLAGS_GPT_4o flag set.

# --- GPT-4o family ---
AGENT_AZURE_4o_MINI = GenericAgentArgs(
    flags=FLAGS_GPT_4o,
    chat_model_args=CHAT_MODEL_ARGS_DICT["azure/gpt-4o-mini-2024-07-18"],
)
AGENT_AZURE_4o = GenericAgentArgs(
    flags=FLAGS_GPT_4o,
    chat_model_args=CHAT_MODEL_ARGS_DICT["azure/gpt-4o-2024-08-06"],
)

# --- GPT-4.1 family ---
AGENT_AZURE_41 = GenericAgentArgs(
    flags=FLAGS_GPT_4o,
    chat_model_args=CHAT_MODEL_ARGS_DICT["azure/gpt-4.1-2025-04-14"],
)
AGENT_AZURE_41_MINI = GenericAgentArgs(
    flags=FLAGS_GPT_4o,
    chat_model_args=CHAT_MODEL_ARGS_DICT["azure/gpt-4.1-mini-2025-04-14"],
)
AGENT_AZURE_41_NANO = GenericAgentArgs(
    flags=FLAGS_GPT_4o,
    chat_model_args=CHAT_MODEL_ARGS_DICT["azure/gpt-4.1-nano-2025-04-14"],
)

# --- GPT-5 family ---
AGENT_AZURE_5 = GenericAgentArgs(
    flags=FLAGS_GPT_4o,
    chat_model_args=CHAT_MODEL_ARGS_DICT["azure/gpt-5-2025-08-07"],
)
AGENT_AZURE_5_MINI = GenericAgentArgs(
    flags=FLAGS_GPT_4o,
    chat_model_args=CHAT_MODEL_ARGS_DICT["azure/gpt-5-mini-2025-08-07"],
)
AGENT_AZURE_5_NANO = GenericAgentArgs(
    flags=FLAGS_GPT_4o,
    chat_model_args=CHAT_MODEL_ARGS_DICT["azure/gpt-5-nano-2025-08-07"],
)
301+
265302
AGENT_CLAUDE_SONNET_35 = GenericAgentArgs(
266303
chat_model_args=CHAT_MODEL_ARGS_DICT["openrouter/anthropic/claude-3.5-sonnet:beta"],
267304
flags=FLAGS_GPT_4o,
@@ -298,6 +335,45 @@
298335
flags=FLAGS_GPT_4o_VISION,
299336
)
300337

338+
# Vision counterparts of the "azure/..." agents: same CHAT_MODEL_ARGS_DICT
# entries, paired with FLAGS_GPT_4o_VISION instead of FLAGS_GPT_4o.

# --- GPT-4o family ---
AGENT_AZURE_4o_VISION = GenericAgentArgs(
    flags=FLAGS_GPT_4o_VISION,
    chat_model_args=CHAT_MODEL_ARGS_DICT["azure/gpt-4o-2024-08-06"],
)
AGENT_AZURE_4o_MINI_VISION = GenericAgentArgs(
    flags=FLAGS_GPT_4o_VISION,
    chat_model_args=CHAT_MODEL_ARGS_DICT["azure/gpt-4o-mini-2024-07-18"],
)

# --- GPT-4.1 family ---
AGENT_AZURE_41_VISION = GenericAgentArgs(
    flags=FLAGS_GPT_4o_VISION,
    chat_model_args=CHAT_MODEL_ARGS_DICT["azure/gpt-4.1-2025-04-14"],
)
AGENT_AZURE_41_MINI_VISION = GenericAgentArgs(
    flags=FLAGS_GPT_4o_VISION,
    chat_model_args=CHAT_MODEL_ARGS_DICT["azure/gpt-4.1-mini-2025-04-14"],
)
AGENT_AZURE_41_NANO_VISION = GenericAgentArgs(
    flags=FLAGS_GPT_4o_VISION,
    chat_model_args=CHAT_MODEL_ARGS_DICT["azure/gpt-4.1-nano-2025-04-14"],
)

# --- GPT-5 family ---
AGENT_AZURE_5_VISION = GenericAgentArgs(
    flags=FLAGS_GPT_4o_VISION,
    chat_model_args=CHAT_MODEL_ARGS_DICT["azure/gpt-5-2025-08-07"],
)
AGENT_AZURE_5_MINI_VISION = GenericAgentArgs(
    flags=FLAGS_GPT_4o_VISION,
    chat_model_args=CHAT_MODEL_ARGS_DICT["azure/gpt-5-mini-2025-08-07"],
)
AGENT_AZURE_5_NANO_VISION = GenericAgentArgs(
    flags=FLAGS_GPT_4o_VISION,
    chat_model_args=CHAT_MODEL_ARGS_DICT["azure/gpt-5-nano-2025-08-07"],
)
376+
301377
AGENT_CLAUDE_SONNET_35_VISION = GenericAgentArgs(
302378
chat_model_args=CHAT_MODEL_ARGS_DICT["openrouter/anthropic/claude-3.5-sonnet:beta"],
303379
flags=FLAGS_GPT_4o_VISION,

src/agentlab/agents/generic_agent_hinter/generic_agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ class GenericAgentArgs(AgentArgs):
3838

3939
def __post_init__(self):
4040
try: # some attributes might be temporarily args.CrossProd for hyperparameter generation
41-
self.agent_name = f"GenericAgent-{self.chat_model_args.model_name}".replace("/", "_")
41+
self.agent_name = f"GenericAgent-hinter-{self.chat_model_args.model_name}".replace("/", "_")
4242
except AttributeError:
4343
pass
4444

src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py

Lines changed: 93 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,13 @@ class GenericPromptFlags(dp.Flags):
6161
add_missparsed_messages: bool = True
6262
max_trunc_itr: int = 20
6363
flag_group: str = None
64+
# hint flags
65+
hint_type: Literal["human", "llm", "docs"] = "human"
66+
hint_index_type: Literal["sparse", "dense"] = "sparse"
67+
hint_query_type: Literal["direct", "llm", "emb"] = "direct"
68+
hint_index_path: str = None
69+
hint_retriever_path: str = None
70+
hint_num_results: int = 5
6471
n_retrieval_queries: int = 3
6572
hint_level: Literal["episode", "step"] = "episode"
6673

@@ -120,6 +127,13 @@ def time_for_caution():
120127
hint_retrieval_mode=flags.task_hint_retrieval_mode,
121128
llm=llm,
122129
skip_hints_for_current_task=flags.skip_hints_for_current_task,
130+
# hint related
131+
hint_type=flags.hint_type,
132+
hint_index_type=flags.hint_index_type,
133+
hint_query_type=flags.hint_query_type,
134+
hint_index_path=flags.hint_index_path,
135+
hint_retriever_path=flags.hint_retriever_path,
136+
hint_num_results=flags.hint_num_results,
123137
hint_level=flags.hint_level,
124138
queries=queries,
125139
)
@@ -307,14 +321,26 @@ def __init__(
307321
use_task_hint: bool,
308322
hint_db_path: str,
309323
goal: str,
310-
hint_retrieval_mode: Literal["direct", "llm", "emb"],
311-
skip_hints_for_current_task: bool,
312324
llm: ChatModel,
325+
hint_type: Literal["human", "llm", "docs"] = "human",
326+
hint_index_type: Literal["sparse", "dense"] = "sparse",
327+
hint_query_type: Literal["direct", "llm", "emb"] = "direct",
328+
hint_index_path: str = None,
329+
hint_retriever_path: str = None,
330+
hint_num_results: int = 5,
331+
skip_hints_for_current_task: bool = False,
332+
hint_retrieval_mode: Literal["direct", "llm", "emb"] = "direct",
313333
hint_level: Literal["episode", "step"] = "episode",
314334
queries: list[str] | None = None,
315335
) -> None:
316336
super().__init__(visible=use_task_hint)
317337
self.use_task_hint = use_task_hint
338+
self.hint_type = hint_type
339+
self.hint_index_type = hint_index_type
340+
self.hint_query_type = hint_query_type
341+
self.hint_index_path = hint_index_path
342+
self.hint_retriever_path = hint_retriever_path
343+
self.hint_num_results = hint_num_results
318344
self.hint_db_rel_path = "hint_db.csv"
319345
self.hint_db_path = hint_db_path # Allow external path override
320346
self.hint_retrieval_mode: Literal["direct", "llm", "emb"] = hint_retrieval_mode
@@ -343,29 +369,47 @@ def __init__(
343369
def _init(self):
344370
"""Initialize the block."""
345371
try:
346-
# Use external path if provided, otherwise fall back to relative path
347-
if self.hint_db_path and Path(self.hint_db_path).exists():
348-
hint_db_path = Path(self.hint_db_path)
372+
if self.hint_type == "docs":
373+
if self.hint_index_type == "sparse":
374+
print("Loading sparse hint index")
375+
import bm25s
376+
self.hint_index = bm25s.BM25.load(self.hint_index_path, load_corpus=True)
377+
print("Sparse hint index loaded successfully")
378+
elif self.hint_index_type == "dense":
379+
print("Loading dense hint index and retriever")
380+
from datasets import load_from_disk
381+
from sentence_transformers import SentenceTransformer
382+
self.hint_index = load_from_disk(self.hint_index_path)
383+
self.hint_index.load_faiss_index("embeddings", self.hint_index_path.removesuffix("/") + ".faiss")
384+
print("Dense hint index loaded successfully")
385+
self.hint_retriever = SentenceTransformer(self.hint_retriever_path)
386+
print("Hint retriever loaded successfully")
387+
else:
388+
raise ValueError(f"Unknown hint index type: {self.hint_index_type}")
349389
else:
350-
hint_db_path = Path(__file__).parent / self.hint_db_rel_path
351-
352-
if hint_db_path.exists():
353-
self.hint_db = pd.read_csv(hint_db_path, header=0, index_col=None, dtype=str)
354-
# Verify the expected columns exist
355-
if "task_name" not in self.hint_db.columns or "hint" not in self.hint_db.columns:
356-
print(
357-
f"Warning: Hint database missing expected columns. Found: {list(self.hint_db.columns)}"
358-
)
390+
# Use external path if provided, otherwise fall back to relative path
391+
if self.hint_db_path and Path(self.hint_db_path).exists():
392+
hint_db_path = Path(self.hint_db_path)
393+
else:
394+
hint_db_path = Path(__file__).parent / self.hint_db_rel_path
395+
396+
if hint_db_path.exists():
397+
self.hint_db = pd.read_csv(hint_db_path, header=0, index_col=None, dtype=str)
398+
# Verify the expected columns exist
399+
if "task_name" not in self.hint_db.columns or "hint" not in self.hint_db.columns:
400+
print(
401+
f"Warning: Hint database missing expected columns. Found: {list(self.hint_db.columns)}"
402+
)
403+
self.hint_db = pd.DataFrame(columns=["task_name", "hint"])
404+
else:
405+
print(f"Warning: Hint database not found at {hint_db_path}")
359406
self.hint_db = pd.DataFrame(columns=["task_name", "hint"])
360-
else:
361-
print(f"Warning: Hint database not found at {hint_db_path}")
362-
self.hint_db = pd.DataFrame(columns=["task_name", "hint"])
363-
364-
self.hints_source = HintsSource(
365-
hint_db_path=hint_db_path.as_posix(),
366-
hint_retrieval_mode=self.hint_retrieval_mode,
367-
skip_hints_for_current_task=self.skip_hints_for_current_task,
368-
)
407+
408+
self.hints_source = HintsSource(
409+
hint_db_path=hint_db_path.as_posix(),
410+
hint_retrieval_mode=self.hint_retrieval_mode,
411+
skip_hints_for_current_task=self.skip_hints_for_current_task,
412+
)
369413
except Exception as e:
370414
# Fallback to empty database on any error
371415
print(f"Warning: Could not load hint database: {e}")
@@ -376,6 +420,32 @@ def get_hints_for_task(self, task_name: str) -> str:
376420
if not self.use_task_hint:
377421
return ""
378422

423+
if self.hint_type == "docs":
424+
if not hasattr(self, "hint_index"):
425+
self._init()
426+
427+
if self.hint_query_type == "goal":
428+
query = self.goal
429+
elif self.hint_query_type == "llm":
430+
query = self.llm.generate(self._prompt + self._abstract_ex + self._concrete_ex)
431+
else:
432+
raise ValueError(f"Unknown hint query type: {self.hint_query_type}")
433+
434+
if self.hint_index_type == "sparse":
435+
query_tokens = bm25s.tokenize(query)
436+
docs = self.hint_index.search(query_tokens, k=self.hint_num_results)
437+
docs = docs["text"]
438+
elif self.hint_index_type == "dense":
439+
query_embedding = self.hint_retriever.encode(query)
440+
_, docs = self.hint_index.get_nearest_examples("embeddings", query_embedding, k=self.hint_num_results)
441+
docs = docs["text"]
442+
443+
hints_str = (
444+
"# Hints:\nHere are some hints for the task you are working on:\n"
445+
+ "\n".join(docs)
446+
)
447+
return hints_str
448+
379449
# Ensure hint_db is initialized
380450
if not hasattr(self, "hint_db"):
381451
self._init()
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from typing_extensions import Protocol
2+
3+
from agentlab.agents.agent_args import AgentArgs
4+
5+
6+
class MultiCandidateAgent(Protocol):
7+
"""
8+
Protocol for agents that generate multiple candidates for get_action.
9+
10+
This protocol defines the contract for agents that can generate
11+
multiple candidate actions and allow selection of one of them for execution.
12+
"""
13+
14+
def get_candidate_generations(
15+
self, obs: dict, hint: list[str] | None = None, n_candidates: int = 3
16+
) -> "list[dict]":
17+
"""
18+
Generate multiple candidate actions for the given observation.
19+
20+
You can pass extra info in agent_info to update internal state of the
21+
agent based on the selected candidate. Your internal state management
22+
should be robust to multiple calls to the get_candidate_generations method
23+
in a single step.
24+
25+
Args:
26+
obs: The current observation dictionary containing environment state
27+
hint: Optional list of hint strings to guide candidate generation
28+
n_candidates: Number of candidate actions to generate
29+
"""
30+
...
31+
32+
def update_agent_state_from_selected_candidate(self, output: dict):
33+
"""
34+
Update the agent's internal state based on the selected candidate.
35+
This can include any memory or planning updates.
36+
37+
Args:
38+
output: The selected candidate action dictionary
39+
"""
40+
pass
41+
42+
43+
class MultiCandidateAgentArgs(AgentArgs):
    """Agent arguments whose ``make_agent`` is annotated to return a MultiCandidateAgent."""

    def make_agent(self) -> MultiCandidateAgent:
        ...

    def __post_init__(self):
        """Run the parent hook, then prepend 'MC-' to a non-empty agent_name."""
        super().__post_init__()
        # A missing attribute and an empty/None name are both left alone.
        current_name = getattr(self, "agent_name", None)
        if current_name:
            self.agent_name = "MC-" + current_name

0 commit comments

Comments
 (0)