Skip to content

Commit c86873b

Browse files
(wip) refactor hinting index
1 parent dabddcf commit c86873b

File tree

2 files changed

+61
-10
lines changed

2 files changed

+61
-10
lines changed

src/agentlab/agents/generic_agent_hinter/generic_agent.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,14 @@
1616
import bgym
1717
from bgym import Benchmark
1818
from browsergym.experiments.agent import Agent, AgentInfo
19-
19+
import pandas as pd
20+
from pathlib import Path
2021
from agentlab.agents import dynamic_prompting as dp
2122
from agentlab.agents.agent_args import AgentArgs
2223
from agentlab.llm.chat_api import BaseModelArgs
2324
from agentlab.llm.llm_utils import Discussion, ParseError, SystemMessage, retry
2425
from agentlab.llm.tracking import cost_tracker_decorator
26+
from agentlab.agents.tool_use_agent.tool_use_agent import HintsSource
2527

2628
from .generic_agent_prompt import (
2729
GenericPromptFlags,
@@ -92,6 +94,8 @@ def __init__(
9294
self.action_set = self.flags.action.action_set.make_action_set()
9395
self._obs_preprocessor = dp.make_obs_preprocessor(flags.obs)
9496

97+
self._init_hints_index()
98+
9599
self._check_flag_constancy()
96100
self.reset(seed=None)
97101

@@ -246,3 +250,46 @@ def _get_maxes(self):
246250
else 20 # dangerous to change the default value here?
247251
)
248252
return max_prompt_tokens, max_trunc_itr
253+
254+
def _init_hints_index(self):
255+
"""Initialize the block."""
256+
try:
257+
if self.flags.hint_type == "docs":
258+
if self.flags.hint_index_type == "sparse":
259+
import bm25s
260+
self.hint_index = bm25s.BM25.load(self.flags.hint_index_path, load_corpus=True)
261+
elif self.flags.hint_index_type == "dense":
262+
from datasets import load_from_disk
263+
from sentence_transformers import SentenceTransformer
264+
self.hint_index = load_from_disk(self.flags.hint_index_path)
265+
self.hint_index.load_faiss_index("embeddings", self.flags.hint_index_path.removesuffix("/") + ".faiss")
266+
self.hint_retriever = SentenceTransformer(self.flags.hint_retriever_path)
267+
else:
268+
raise ValueError(f"Unknown hint index type: {self.flags.hint_index_type}")
269+
else:
270+
# Use external path if provided, otherwise fall back to relative path
271+
if self.flags.hint_db_path and Path(self.flags.hint_db_path).exists():
272+
hint_db_path = Path(self.flags.hint_db_path)
273+
else:
274+
hint_db_path = Path(__file__).parent / self.flags.hint_db_rel_path
275+
276+
if hint_db_path.exists():
277+
self.hint_db = pd.read_csv(hint_db_path, header=0, index_col=None, dtype=str)
278+
# Verify the expected columns exist
279+
if "task_name" not in self.hint_db.columns or "hint" not in self.hint_db.columns:
280+
print(
281+
f"Warning: Hint database missing expected columns. Found: {list(self.hint_db.columns)}"
282+
)
283+
self.hint_db = pd.DataFrame(columns=["task_name", "hint"])
284+
else:
285+
print(f"Warning: Hint database not found at {hint_db_path}")
286+
self.hint_db = pd.DataFrame(columns=["task_name", "hint"])
287+
self.hints_source = HintsSource(
288+
hint_db_path=hint_db_path.as_posix(),
289+
hint_retrieval_mode=self.flags.hint_retrieval_mode,
290+
skip_hints_for_current_task=self.flags.skip_hints_for_current_task,
291+
)
292+
except Exception as e:
293+
# Fallback to empty database on any error
294+
print(f"Warning: Could not load hint database: {e}")
295+
self.hint_db = pd.DataFrame(columns=["task_name", "hint"])

src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def __init__(
8080
actions: list[str],
8181
memories: list[str],
8282
thoughts: list[str],
83+
hints: list[str],
8384
previous_plan: str,
8485
step: int,
8586
flags: GenericPromptFlags,
@@ -120,6 +121,7 @@ def time_for_caution():
120121
self.think = dp.Think(visible=lambda: flags.use_thinking)
121122
self.hints = dp.Hints(visible=lambda: flags.use_hints)
122123
goal_str: str = goal[0]["text"]
124+
# TODO: This design is not very good as we will instantiate the loop up at every step
123125
self.task_hint = TaskHint(
124126
use_task_hint=flags.use_task_hint,
125127
hint_db_path=flags.hint_db_path,
@@ -147,7 +149,8 @@ def _prompt(self) -> HumanMessage:
147149

148150
# Add task hints if enabled
149151
task_hints_text = ""
150-
if self.flags.use_task_hint and hasattr(self, "task_name"):
152+
# if self.flags.use_task_hint and hasattr(self, "task_name"):
153+
if self.flags.use_task_hint:
151154
task_hints_text = self.task_hint.get_hints_for_task(self.task_name)
152155

153156
prompt.add_text(
@@ -371,19 +374,14 @@ def _init(self):
371374
try:
372375
if self.hint_type == "docs":
373376
if self.hint_index_type == "sparse":
374-
print("Loading sparse hint index")
375377
import bm25s
376378
self.hint_index = bm25s.BM25.load(self.hint_index_path, load_corpus=True)
377-
print("Sparse hint index loaded successfully")
378379
elif self.hint_index_type == "dense":
379-
print("Loading dense hint index and retriever")
380380
from datasets import load_from_disk
381381
from sentence_transformers import SentenceTransformer
382382
self.hint_index = load_from_disk(self.hint_index_path)
383383
self.hint_index.load_faiss_index("embeddings", self.hint_index_path.removesuffix("/") + ".faiss")
384-
print("Dense hint index loaded successfully")
385384
self.hint_retriever = SentenceTransformer(self.hint_retriever_path)
386-
print("Hint retriever loaded successfully")
387385
else:
388386
raise ValueError(f"Unknown hint index type: {self.hint_index_type}")
389387
else:
@@ -422,8 +420,8 @@ def get_hints_for_task(self, task_name: str) -> str:
422420

423421
if self.hint_type == "docs":
424422
if not hasattr(self, "hint_index"):
423+
print("Initializing hint index new time")
425424
self._init()
426-
427425
if self.hint_query_type == "goal":
428426
query = self.goal
429427
elif self.hint_query_type == "llm":
@@ -432,9 +430,15 @@ def get_hints_for_task(self, task_name: str) -> str:
432430
raise ValueError(f"Unknown hint query type: {self.hint_query_type}")
433431

434432
if self.hint_index_type == "sparse":
433+
import bm25s
435434
query_tokens = bm25s.tokenize(query)
436-
docs = self.hint_index.search(query_tokens, k=self.hint_num_results)
437-
docs = docs["text"]
435+
docs, _ = self.hint_index.retrieve(query_tokens, k=self.hint_num_results)
436+
docs = [elem["text"] for elem in docs[0]]
437+
# HACK: truncate to 20k characters (should cover >99% of the cases)
438+
for doc in docs:
439+
if len(doc) > 20000:
440+
doc = doc[:20000]
441+
doc += " ...[truncated]"
438442
elif self.hint_index_type == "dense":
439443
query_embedding = self.hint_retriever.encode(query)
440444
_, docs = self.hint_index.get_nearest_examples("embeddings", query_embedding, k=self.hint_num_results)

0 commit comments

Comments
 (0)