Skip to content

Commit 55ce26a

Browse files
committed
same hints retrieval for both generic and tooluse agents
1 parent 24a14f2 commit 55ce26a

File tree

3 files changed

+109
-50
lines changed

3 files changed

+109
-50
lines changed

src/agentlab/agents/generic_agent_hinter/generic_agent.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,10 +111,11 @@ def get_action(self, obs):
111111
previous_plan=self.plan,
112112
step=self.plan_step,
113113
flags=self.flags,
114+
llm=self.chat_llm,
114115
)
115116

116117
# Set task name for task hints if available
117-
if self.flags.use_task_hint and hasattr(self, 'task_name'):
118+
if self.flags.use_task_hint and hasattr(self, "task_name"):
118119
main_prompt.set_task_name(self.task_name)
119120

120121
max_prompt_tokens, max_trunc_itr = self._get_maxes()

src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py

Lines changed: 46 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,16 @@
66

77
import logging
88
from dataclasses import dataclass
9+
from pathlib import Path
10+
from typing import Literal
911

10-
from browsergym.core import action
12+
import pandas as pd
1113
from browsergym.core.action.base import AbstractActionSet
1214

1315
from agentlab.agents import dynamic_prompting as dp
16+
from agentlab.agents.tool_use_agent.tool_use_agent import HintsSource
17+
from agentlab.llm.chat_api import ChatModel
1418
from agentlab.llm.llm_utils import HumanMessage, parse_html_tags_raise
15-
import fnmatch
16-
import pandas as pd
17-
from pathlib import Path
1819

1920

2021
@dataclass
@@ -49,6 +50,7 @@ class GenericPromptFlags(dp.Flags):
4950
use_abstract_example: bool = False
5051
use_hints: bool = False
5152
use_task_hint: bool = False
53+
task_hint_retrieval_mode: Literal["direct", "llm", "emb"] = "direct"
5254
hint_db_path: str = None
5355
enable_chat: bool = False
5456
max_prompt_tokens: int = None
@@ -70,10 +72,12 @@ def __init__(
7072
previous_plan: str,
7173
step: int,
7274
flags: GenericPromptFlags,
75+
llm: ChatModel,
7376
) -> None:
7477
super().__init__()
7578
self.flags = flags
7679
self.history = dp.History(obs_history, actions, memories, thoughts, flags.obs)
80+
goal = obs_history[-1]["goal_object"]
7781
if self.flags.enable_chat:
7882
self.instructions = dp.ChatInstructions(
7983
obs_history[-1]["chat_messages"], extra_instructions=flags.extra_instructions
@@ -84,7 +88,7 @@ def __init__(
8488
"Agent is in goal mode, but multiple user messages are present in the chat. Consider switching to `enable_chat=True`."
8589
)
8690
self.instructions = dp.GoalInstructions(
87-
obs_history[-1]["goal_object"], extra_instructions=flags.extra_instructions
91+
goal, extra_instructions=flags.extra_instructions
8892
)
8993

9094
self.obs = dp.Observation(
@@ -105,7 +109,10 @@ def time_for_caution():
105109
self.hints = dp.Hints(visible=lambda: flags.use_hints)
106110
self.task_hint = TaskHint(
107111
use_task_hint=flags.use_task_hint,
108-
hint_db_path=flags.hint_db_path
112+
hint_db_path=flags.hint_db_path,
113+
goal=goal,
114+
hint_retrieval_mode=flags.task_hint_retrieval_mode,
115+
llm=llm,
109116
)
110117
self.plan = Plan(previous_plan, step, lambda: flags.use_plan) # TODO add previous plan
111118
self.criticise = Criticise(visible=lambda: flags.use_criticise)
@@ -114,12 +121,12 @@ def time_for_caution():
114121
@property
115122
def _prompt(self) -> HumanMessage:
116123
prompt = HumanMessage(self.instructions.prompt)
117-
124+
118125
# Add task hints if enabled
119126
task_hints_text = ""
120-
if self.flags.use_task_hint and hasattr(self, 'task_name'):
127+
if self.flags.use_task_hint and hasattr(self, "task_name"):
121128
task_hints_text = self.task_hint.get_hints_for_task(self.task_name)
122-
129+
123130
prompt.add_text(
124131
f"""\
125132
{self.obs.prompt}\
@@ -286,11 +293,21 @@ def _parse_answer(self, text_answer):
286293

287294

288295
class TaskHint(dp.PromptElement):
289-
def __init__(self, use_task_hint: bool = True, hint_db_path: str = None) -> None:
296+
def __init__(
297+
self,
298+
use_task_hint: bool,
299+
hint_db_path: str,
300+
goal: str,
301+
hint_retrieval_mode: Literal["direct", "llm", "emb"],
302+
llm: ChatModel,
303+
) -> None:
290304
super().__init__(visible=use_task_hint)
291305
self.use_task_hint = use_task_hint
292306
self.hint_db_rel_path = "hint_db.csv"
293307
self.hint_db_path = hint_db_path # Allow external path override
308+
self.hint_retrieval_mode: Literal["direct", "llm", "emb"] = hint_retrieval_mode
309+
self.goal = goal
310+
self.llm = llm
294311
self._init()
295312

296313
_prompt = "" # Task hints are added dynamically in MainPrompt
@@ -316,39 +333,49 @@ def _init(self):
316333
hint_db_path = Path(self.hint_db_path)
317334
else:
318335
hint_db_path = Path(__file__).parent / self.hint_db_rel_path
319-
336+
320337
if hint_db_path.exists():
321338
self.hint_db = pd.read_csv(hint_db_path, header=0, index_col=None, dtype=str)
322339
# Verify the expected columns exist
323340
if "task_name" not in self.hint_db.columns or "hint" not in self.hint_db.columns:
324-
print(f"Warning: Hint database missing expected columns. Found: {list(self.hint_db.columns)}")
341+
print(
342+
f"Warning: Hint database missing expected columns. Found: {list(self.hint_db.columns)}"
343+
)
325344
self.hint_db = pd.DataFrame(columns=["task_name", "hint"])
326345
else:
327346
print(f"Warning: Hint database not found at {hint_db_path}")
328347
self.hint_db = pd.DataFrame(columns=["task_name", "hint"])
348+
self.hints_source = HintsSource(
349+
hint_db_path=self.hint_db_rel_path,
350+
hint_retrieval_mode=self.hint_retrieval_mode,
351+
)
329352
except Exception as e:
330353
# Fallback to empty database on any error
331354
print(f"Warning: Could not load hint database: {e}")
332355
self.hint_db = pd.DataFrame(columns=["task_name", "hint"])
333356

334-
335357
def get_hints_for_task(self, task_name: str) -> str:
336358
"""Get hints for a specific task."""
337359
if not self.use_task_hint:
338360
return ""
339361

340362
# Ensure hint_db is initialized
341-
if not hasattr(self, 'hint_db'):
363+
if not hasattr(self, "hint_db"):
342364
self._init()
343365

344366
# Check if hint_db has the expected structure
345-
if self.hint_db.empty or "task_name" not in self.hint_db.columns or "hint" not in self.hint_db.columns:
367+
if (
368+
self.hint_db.empty
369+
or "task_name" not in self.hint_db.columns
370+
or "hint" not in self.hint_db.columns
371+
):
346372
return ""
347373

348374
try:
349-
task_hints = self.hint_db[
350-
self.hint_db["task_name"].apply(lambda x: fnmatch.fnmatch(x, task_name))
351-
]
375+
# task_hints = self.hint_db[
376+
# self.hint_db["task_name"].apply(lambda x: fnmatch.fnmatch(x, task_name))
377+
# ]
378+
task_hints = self.hints_source.choose_hints(self.llm, task_name, self.goal)
352379

353380
hints = []
354381
for hint in task_hints["hint"]:
@@ -364,5 +391,5 @@ def get_hints_for_task(self, task_name: str) -> str:
364391
return hints_str
365392
except Exception as e:
366393
print(f"Warning: Error getting hints for task {task_name}: {e}")
367-
394+
368395
return ""

src/agentlab/agents/tool_use_agent/tool_use_agent.py

Lines changed: 61 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from agentlab.benchmarks.abstract_env import AbstractBenchmark as AgentLabBenchmark
2929
from agentlab.benchmarks.osworld import OSWorldActionSet
3030
from agentlab.llm.base_api import BaseModelArgs
31+
from agentlab.llm.chat_api import ChatModel
3132
from agentlab.llm.llm_utils import image_to_png_base64_url
3233
from agentlab.llm.response_api import (
3334
APIPayload,
@@ -316,39 +317,21 @@ class TaskHint(Block):
316317

317318
def _init(self):
318319
"""Initialize the block."""
319-
if Path(self.hint_db_rel_path).is_absolute():
320-
hint_db_path = Path(self.hint_db_rel_path)
321-
else:
322-
hint_db_path = Path(__file__).parent / self.hint_db_rel_path
323-
self.hint_db = pd.read_csv(hint_db_path, header=0, index_col=None, dtype=str)
324-
if self.hint_retrieval_mode == "emb":
325-
self.encode_hints()
326-
327-
def oai_embed(self, text: str):
328-
response = self._oai_emb.create(input=text, model="text-embedding-3-small")
329-
return response.data[0].embedding
330-
331-
def encode_hints(self):
332-
self.uniq_hints = self.hint_db.drop_duplicates(subset=["hint"], keep="first")
333-
logger.info(
334-
f"Encoding {len(self.uniq_hints)} unique hints with semantic keys using {self.embedder_model} model."
320+
self.hints_source = HintsSource(
321+
hint_db_path=self.hint_db_rel_path,
322+
hint_retrieval_mode=self.hint_retrieval_mode,
323+
top_n=self.top_n,
324+
embedder_model=self.embedder_model,
325+
embedder_server=self.embedder_server,
326+
llm_prompt=self.llm_prompt,
335327
)
336-
hints = self.uniq_hints["hint"].tolist()
337-
semantic_keys = self.uniq_hints["semantic_keys"].tolist()
338-
lines = [f"{k}: {h}" for h, k in zip(hints, semantic_keys)]
339-
emb_path = f"{self.hint_db_rel_path}.embs.npy"
340-
assert os.path.exists(emb_path), f"Embedding file not found: {emb_path}"
341-
logger.info(f"Loading hint embeddings from: {emb_path}")
342-
emb_dict = np.load(emb_path, allow_pickle=True).item()
343-
self.hint_embeddings = np.array([emb_dict[k] for k in lines])
344-
logger.info(f"Loaded hint embeddings shape: {self.hint_embeddings.shape}")
345328

346329
def apply(self, llm, discussion: StructuredDiscussion, task_name: str) -> dict:
347330
if not self.use_task_hint:
348331
return {}
349332

350333
goal = "\n".join([c.get("text", "") for c in discussion.groups[0].messages[1].content])
351-
task_hints = self.choose_hints(llm, task_name, goal)
334+
task_hints = self.hints_source.choose_hints(llm, task_name, goal)
352335

353336
hints = []
354337
for hint in task_hints:
@@ -365,6 +348,49 @@ def apply(self, llm, discussion: StructuredDiscussion, task_name: str) -> dict:
365348

366349
discussion.append(msg)
367350

351+
352+
class HintsSource:
353+
def __init__(
354+
self,
355+
hint_db_path: str,
356+
hint_retrieval_mode: Literal["direct", "llm", "emb"] = "direct",
357+
top_n: int = 4,
358+
embedder_model: str = "Qwen/Qwen3-Embedding-0.6B",
359+
embedder_server: str = "http://localhost:5000",
360+
llm_prompt: str = """We're choosing hints to help solve the following task:\n{goal}.\n
361+
You need to choose the most relevant hints topic from the following list:\n\nHint topics:\n{topics}\n
362+
Choose hint topic for the task and return only its number, e.g. 1. If you don't know the answer, return -1.""",
363+
) -> None:
364+
self.hint_db_path = hint_db_path
365+
self.hint_retrieval_mode = hint_retrieval_mode
366+
self.top_n = top_n
367+
self.embedder_model = embedder_model
368+
self.embedder_server = embedder_server
369+
self.llm_prompt = llm_prompt
370+
371+
if Path(hint_db_path).is_absolute():
372+
self.hint_db_path = Path(hint_db_path).as_posix()
373+
else:
374+
self.hint_db_path = (Path(__file__).parent / self.hint_db_path).as_posix()
375+
self.hint_db = pd.read_csv(self.hint_db_path, header=0, index_col=None, dtype=str)
376+
if self.hint_retrieval_mode == "emb":
377+
self.load_hint_vectors()
378+
379+
def load_hint_vectors(self):
380+
self.uniq_hints = self.hint_db.drop_duplicates(subset=["hint"], keep="first")
381+
logger.info(
382+
f"Encoding {len(self.uniq_hints)} unique hints with semantic keys using {self.embedder_model} model."
383+
)
384+
hints = self.uniq_hints["hint"].tolist()
385+
semantic_keys = self.uniq_hints["semantic_keys"].tolist()
386+
lines = [f"{k}: {h}" for h, k in zip(hints, semantic_keys)]
387+
emb_path = f"{self.hint_db_path}.embs.npy"
388+
assert os.path.exists(emb_path), f"Embedding file not found: {emb_path}"
389+
logger.info(f"Loading hint embeddings from: {emb_path}")
390+
emb_dict = np.load(emb_path, allow_pickle=True).item()
391+
self.hint_embeddings = np.array([emb_dict[k] for k in lines])
392+
logger.info(f"Loaded hint embeddings shape: {self.hint_embeddings.shape}")
393+
368394
def choose_hints(self, llm, task_name: str, goal: str) -> list[str]:
369395
"""Choose hints based on the task name."""
370396
if self.hint_retrieval_mode == "llm":
@@ -384,11 +410,14 @@ def choose_hints_llm(self, llm, goal: str) -> list[str]:
384410
hint_topics = list(topic_to_hints.keys())
385411
topics = "\n".join([f"{i}. {h}" for i, h in enumerate(hint_topics)])
386412
prompt = self.llm_prompt.format(goal=goal, topics=topics)
387-
response = llm(APIPayload(messages=[llm.msg.user().add_text(prompt)]))
413+
if isinstance(llm, ChatModel):
414+
response: str = llm(messages=[dict(role="user", content=prompt)])["content"]
415+
else:
416+
response: str = llm(APIPayload(messages=[llm.msg.user().add_text(prompt)])).think
388417
try:
389-
hint_topic_idx = json.loads(response.think)
418+
hint_topic_idx = json.loads(response)
390419
if hint_topic_idx < 0 or hint_topic_idx >= len(hint_topics):
391-
logger.error(f"Wrong LLM hint id response: {response.think}, no hints")
420+
logger.error(f"Wrong LLM hint id response: {response}, no hints")
392421
return []
393422
hint_topic = hint_topics[hint_topic_idx]
394423
hint_indices = topic_to_hints[hint_topic]
@@ -397,7 +426,7 @@ def choose_hints_llm(self, llm, goal: str) -> list[str]:
397426
hints = df["hint"].tolist()
398427
logger.debug(f"LLM hint topic {hint_topic_idx}, chosen hints: {df['hint'].tolist()}")
399428
except json.JSONDecodeError:
400-
logger.error(f"Failed to parse LLM hint id response: {response.think}, no hints")
429+
logger.error(f"Failed to parse LLM hint id response: {response}, no hints")
401430
hints = []
402431
return hints
403432

@@ -427,6 +456,7 @@ def _encode(self, texts: list[str], prompt: str = "", timeout: int = 10, max_ret
427456
raise e
428457
time.sleep(random.uniform(1, timeout))
429458
continue
459+
raise ValueError("Failed to encode hints")
430460

431461
def _similarity(
432462
self, texts1: list[str], texts2: list[str], timeout: int = 2, max_retries: int = 5
@@ -446,6 +476,7 @@ def _similarity(
446476
raise e
447477
time.sleep(random.uniform(1, timeout))
448478
continue
479+
raise ValueError("Failed to compute similarity")
449480

450481
def choose_hints_direct(self, task_name: str) -> list[str]:
451482
hints = self.hint_db[

0 commit comments

Comments
 (0)