Skip to content

Commit 60ad8e4

Browse files
update hinter agent and prompt
1 parent d2166b3 commit 60ad8e4

File tree

2 files changed

+101
-142
lines changed

2 files changed

+101
-142
lines changed

src/agentlab/agents/generic_agent_hinter/generic_agent.py

Lines changed: 94 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,18 @@
1010

1111
from copy import deepcopy
1212
from dataclasses import asdict, dataclass
13-
from functools import partial
13+
from pathlib import Path
1414
from warnings import warn
1515

16-
import bgym
17-
from bgym import Benchmark
18-
from browsergym.experiments.agent import Agent, AgentInfo
1916
import pandas as pd
20-
from pathlib import Path
2117
from agentlab.agents import dynamic_prompting as dp
2218
from agentlab.agents.agent_args import AgentArgs
2319
from agentlab.llm.chat_api import BaseModelArgs
2420
from agentlab.llm.llm_utils import Discussion, ParseError, SystemMessage, retry
2521
from agentlab.llm.tracking import cost_tracker_decorator
26-
from agentlab.agents.tool_use_agent.tool_use_agent import HintsSource
22+
from agentlab.utils.hinting import HintsSource
23+
from bgym import Benchmark
24+
from browsergym.experiments.agent import Agent, AgentInfo
2725

2826
from .generic_agent_prompt import (
2927
GenericPromptFlags,
@@ -40,7 +38,9 @@ class GenericAgentArgs(AgentArgs):
4038

4139
def __post_init__(self):
4240
try: # some attributes might be temporarily args.CrossProd for hyperparameter generation
43-
self.agent_name = f"GenericAgent-hinter-{self.chat_model_args.model_name}".replace("/", "_")
41+
self.agent_name = f"GenericAgent-hinter-{self.chat_model_args.model_name}".replace(
42+
"/", "_"
43+
)
4444
except AttributeError:
4545
pass
4646

@@ -116,7 +116,9 @@ def get_action(self, obs):
116116
queries, think_queries = self._get_queries()
117117

118118
# use those queries to retrieve from the database and pass to prompt if step-level
119-
queries_for_hints = queries if getattr(self.flags, "hint_level", "episode") == "step" else None
119+
queries_for_hints = (
120+
queries if getattr(self.flags, "hint_level", "episode") == "step" else None
121+
)
120122

121123
main_prompt = MainPrompt(
122124
action_set=self.action_set,
@@ -257,12 +259,16 @@ def _init_hints_index(self):
257259
if self.flags.hint_type == "docs":
258260
if self.flags.hint_index_type == "sparse":
259261
import bm25s
262+
260263
self.hint_index = bm25s.BM25.load(self.flags.hint_index_path, load_corpus=True)
261264
elif self.flags.hint_index_type == "dense":
262265
from datasets import load_from_disk
263266
from sentence_transformers import SentenceTransformer
267+
264268
self.hint_index = load_from_disk(self.flags.hint_index_path)
265-
self.hint_index.load_faiss_index("embeddings", self.flags.hint_index_path.removesuffix("/") + ".faiss")
269+
self.hint_index.load_faiss_index(
270+
"embeddings", self.flags.hint_index_path.removesuffix("/") + ".faiss"
271+
)
266272
self.hint_retriever = SentenceTransformer(self.flags.hint_retriever_path)
267273
else:
268274
raise ValueError(f"Unknown hint index type: {self.flags.hint_index_type}")
@@ -276,7 +282,10 @@ def _init_hints_index(self):
276282
if hint_db_path.exists():
277283
self.hint_db = pd.read_csv(hint_db_path, header=0, index_col=None, dtype=str)
278284
# Verify the expected columns exist
279-
if "task_name" not in self.hint_db.columns or "hint" not in self.hint_db.columns:
285+
if (
286+
"task_name" not in self.hint_db.columns
287+
or "hint" not in self.hint_db.columns
288+
):
280289
print(
281290
f"Warning: Hint database missing expected columns. Found: {list(self.hint_db.columns)}"
282291
)
@@ -292,4 +301,78 @@ def _init_hints_index(self):
292301
except Exception as e:
293302
# Fallback to empty database on any error
294303
print(f"Warning: Could not load hint database: {e}")
295-
self.hint_db = pd.DataFrame(columns=["task_name", "hint"])
304+
self.hint_db = pd.DataFrame(columns=["task_name", "hint"])
305+
306+
def get_hints_for_task(self, task_name: str) -> str:
307+
"""Get hints for a specific task."""
308+
if not self.use_task_hint:
309+
return ""
310+
311+
if self.hint_type == "docs":
312+
if not hasattr(self, "hint_index"):
313+
print("Initializing hint index new time")
314+
self._init()
315+
if self.hint_query_type == "goal":
316+
query = self.goal
317+
elif self.hint_query_type == "llm":
318+
query = self.llm.generate(self._prompt + self._abstract_ex + self._concrete_ex)
319+
else:
320+
raise ValueError(f"Unknown hint query type: {self.hint_query_type}")
321+
322+
if self.hint_index_type == "sparse":
323+
import bm25s
324+
query_tokens = bm25s.tokenize(query)
325+
docs, _ = self.hint_index.retrieve(query_tokens, k=self.hint_num_results)
326+
docs = [elem["text"] for elem in docs[0]]
327+
# HACK: truncate to 20k characters (should cover >99% of the cases)
328+
for doc in docs:
329+
if len(doc) > 20000:
330+
doc = doc[:20000]
331+
doc += " ...[truncated]"
332+
elif self.hint_index_type == "dense":
333+
query_embedding = self.hint_retriever.encode(query)
334+
_, docs = self.hint_index.get_nearest_examples("embeddings", query_embedding, k=self.hint_num_results)
335+
docs = docs["text"]
336+
337+
hints_str = (
338+
"# Hints:\nHere are some hints for the task you are working on:\n"
339+
+ "\n".join(docs)
340+
)
341+
return hints_str
342+
343+
# Check if hint_db has the expected structure
344+
if (
345+
self.hint_db.empty
346+
or "task_name" not in self.hint_db.columns
347+
or "hint" not in self.hint_db.columns
348+
):
349+
return ""
350+
351+
try:
352+
# When step-level, pass queries as goal string to fit the llm_prompt
353+
goal_or_queries = self.goal
354+
if self.hint_level == "step" and self.queries:
355+
goal_or_queries = "\n".join(self.queries)
356+
357+
task_hints = self.hints_source.choose_hints(
358+
self.llm,
359+
task_name,
360+
goal_or_queries,
361+
)
362+
363+
hints = []
364+
for hint in task_hints:
365+
hint = hint.strip()
366+
if hint:
367+
hints.append(f"- {hint}")
368+
369+
if len(hints) > 0:
370+
hints_str = (
371+
"# Hints:\nHere are some hints for the task you are working on:\n"
372+
+ "\n".join(hints)
373+
)
374+
return hints_str
375+
except Exception as e:
376+
print(f"Warning: Error getting hints for task {task_name}: {e}")
377+
378+
return ""

src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py

Lines changed: 7 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def __init__(
8585
step: int,
8686
flags: GenericPromptFlags,
8787
llm: ChatModel,
88-
queries: list[str] | None = None,
88+
task_hints: list[str] = [],
8989
) -> None:
9090
super().__init__()
9191
self.flags = flags
@@ -120,25 +120,7 @@ def time_for_caution():
120120
self.be_cautious = dp.BeCautious(visible=time_for_caution)
121121
self.think = dp.Think(visible=lambda: flags.use_thinking)
122122
self.hints = dp.Hints(visible=lambda: flags.use_hints)
123-
goal_str: str = goal[0]["text"]
124-
# TODO: This design is not very good as we will instantiate the loop up at every step
125-
self.task_hint = TaskHint(
126-
use_task_hint=flags.use_task_hint,
127-
hint_db_path=flags.hint_db_path,
128-
goal=goal_str,
129-
hint_retrieval_mode=flags.task_hint_retrieval_mode,
130-
llm=llm,
131-
skip_hints_for_current_task=flags.skip_hints_for_current_task,
132-
# hint related
133-
hint_type=flags.hint_type,
134-
hint_index_type=flags.hint_index_type,
135-
hint_query_type=flags.hint_query_type,
136-
hint_index_path=flags.hint_index_path,
137-
hint_retriever_path=flags.hint_retriever_path,
138-
hint_num_results=flags.hint_num_results,
139-
hint_level=flags.hint_level,
140-
queries=queries,
141-
)
123+
self.task_hints = TaskHint(visible=lambda: flags.use_task_hint, task_hints=task_hints)
142124
self.plan = Plan(previous_plan, step, lambda: flags.use_plan) # TODO add previous plan
143125
self.criticise = Criticise(visible=lambda: flags.use_criticise)
144126
self.memory = Memory(visible=lambda: flags.use_memory)
@@ -147,19 +129,13 @@ def time_for_caution():
147129
def _prompt(self) -> HumanMessage:
148130
prompt = HumanMessage(self.instructions.prompt)
149131

150-
# Add task hints if enabled
151-
task_hints_text = ""
152-
# if self.flags.use_task_hint and hasattr(self, "task_name"):
153-
if self.flags.use_task_hint:
154-
task_hints_text = self.task_hint.get_hints_for_task(self.task_name)
155-
156132
prompt.add_text(
157133
f"""\
158134
{self.obs.prompt}\
159135
{self.history.prompt}\
160136
{self.action_prompt.prompt}\
161137
{self.hints.prompt}\
162-
{task_hints_text}\
138+
{self.task_hint.prompt}\
163139
{self.be_cautious.prompt}\
164140
{self.think.prompt}\
165141
{self.plan.prompt}\
@@ -321,37 +297,11 @@ def _parse_answer(self, text_answer):
321297
class TaskHint(dp.PromptElement):
322298
def __init__(
323299
self,
324-
use_task_hint: bool,
325-
hint_db_path: str,
326-
goal: str,
327-
llm: ChatModel,
328-
hint_type: Literal["human", "llm", "docs"] = "human",
329-
hint_index_type: Literal["sparse", "dense"] = "sparse",
330-
hint_query_type: Literal["direct", "llm", "emb"] = "direct",
331-
hint_index_path: str = None,
332-
hint_retriever_path: str = None,
333-
hint_num_results: int = 5,
334-
skip_hints_for_current_task: bool = False,
335-
hint_retrieval_mode: Literal["direct", "llm", "emb"] = "direct",
336-
hint_level: Literal["episode", "step"] = "episode",
337-
queries: list[str] | None = None,
300+
visible: bool,
301+
task_hints: list[str]
338302
) -> None:
339-
super().__init__(visible=use_task_hint)
340-
self.use_task_hint = use_task_hint
341-
self.hint_type = hint_type
342-
self.hint_index_type = hint_index_type
343-
self.hint_query_type = hint_query_type
344-
self.hint_index_path = hint_index_path
345-
self.hint_retriever_path = hint_retriever_path
346-
self.hint_num_results = hint_num_results
347-
self.hint_db_rel_path = "hint_db.csv"
348-
self.hint_db_path = hint_db_path # Allow external path override
349-
self.hint_retrieval_mode: Literal["direct", "llm", "emb"] = hint_retrieval_mode
350-
self.skip_hints_for_current_task = skip_hints_for_current_task
351-
self.goal = goal
352-
self.llm = llm
353-
self.hint_level: Literal["episode", "step"] = hint_level
354-
self.queries: list[str] | None = queries
303+
super().__init__(visible=visible)
304+
self.task_hints = task_hints
355305

356306
_prompt = "" # Task hints are added dynamically in MainPrompt
357307

@@ -368,80 +318,6 @@ def __init__(
368318
</task_hint>
369319
"""
370320

371-
def get_hints_for_task(self, task_name: str) -> str:
372-
"""Get hints for a specific task."""
373-
if not self.use_task_hint:
374-
return ""
375-
376-
if self.hint_type == "docs":
377-
if not hasattr(self, "hint_index"):
378-
print("Initializing hint index new time")
379-
self._init()
380-
if self.hint_query_type == "goal":
381-
query = self.goal
382-
elif self.hint_query_type == "llm":
383-
query = self.llm.generate(self._prompt + self._abstract_ex + self._concrete_ex)
384-
else:
385-
raise ValueError(f"Unknown hint query type: {self.hint_query_type}")
386-
387-
if self.hint_index_type == "sparse":
388-
import bm25s
389-
query_tokens = bm25s.tokenize(query)
390-
docs, _ = self.hint_index.retrieve(query_tokens, k=self.hint_num_results)
391-
docs = [elem["text"] for elem in docs[0]]
392-
# HACK: truncate to 20k characters (should cover >99% of the cases)
393-
for doc in docs:
394-
if len(doc) > 20000:
395-
doc = doc[:20000]
396-
doc += " ...[truncated]"
397-
elif self.hint_index_type == "dense":
398-
query_embedding = self.hint_retriever.encode(query)
399-
_, docs = self.hint_index.get_nearest_examples("embeddings", query_embedding, k=self.hint_num_results)
400-
docs = docs["text"]
401-
402-
hints_str = (
403-
"# Hints:\nHere are some hints for the task you are working on:\n"
404-
+ "\n".join(docs)
405-
)
406-
return hints_str
407-
408-
# Check if hint_db has the expected structure
409-
if (
410-
self.hint_db.empty
411-
or "task_name" not in self.hint_db.columns
412-
or "hint" not in self.hint_db.columns
413-
):
414-
return ""
415-
416-
try:
417-
# When step-level, pass queries as goal string to fit the llm_prompt
418-
goal_or_queries = self.goal
419-
if self.hint_level == "step" and self.queries:
420-
goal_or_queries = "\n".join(self.queries)
421-
422-
task_hints = self.hints_source.choose_hints(
423-
self.llm,
424-
task_name,
425-
goal_or_queries,
426-
)
427-
428-
hints = []
429-
for hint in task_hints:
430-
hint = hint.strip()
431-
if hint:
432-
hints.append(f"- {hint}")
433-
434-
if len(hints) > 0:
435-
hints_str = (
436-
"# Hints:\nHere are some hints for the task you are working on:\n"
437-
+ "\n".join(hints)
438-
)
439-
return hints_str
440-
except Exception as e:
441-
print(f"Warning: Error getting hints for task {task_name}: {e}")
442-
443-
return ""
444-
445321

446322
class StepWiseContextIdentificationPrompt(dp.Shrinkable):
447323
def __init__(

0 commit comments

Comments
 (0)