Skip to content

Commit 69048c4

Browse files
update hinting agent retrieval
1 parent a43e54d commit 69048c4

File tree

2 files changed

+93
-23
lines changed

2 files changed

+93
-23
lines changed

src/agentlab/agents/generic_agent_hinter/generic_agent.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,7 @@ class GenericAgentArgs(AgentArgs):
3434

3535
def __post_init__(self):
3636
try: # some attributes might be temporarily args.CrossProd for hyperparameter generation
37-
self.agent_name = f"GenericAgent-{self.chat_model_args.model_name}".replace("/", "_")
37+
self.agent_name = f"GenericAgent-hinter-{self.chat_model_args.model_name}".replace("/", "_")
3838
except AttributeError:
3939
pass
4040

src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py

Lines changed: 92 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -60,6 +60,13 @@ class GenericPromptFlags(dp.Flags):
6060
add_missparsed_messages: bool = True
6161
max_trunc_itr: int = 20
6262
flag_group: str = None
63+
# hint flags
64+
hint_type: Literal["human", "llm", "docs"] = "human"
65+
hint_index_type: Literal["sparse", "dense"] = "sparse"
66+
hint_query_type: Literal["direct", "llm", "emb"] = "direct"
67+
hint_index_path: str = None
68+
hint_retriever_path: str = None
69+
hint_num_results: int = 5
6370

6471

6572
class MainPrompt(dp.Shrinkable):
@@ -116,6 +123,13 @@ def time_for_caution():
116123
hint_retrieval_mode=flags.task_hint_retrieval_mode,
117124
llm=llm,
118125
skip_hints_for_current_task=flags.skip_hints_for_current_task,
126+
# hint related
127+
hint_type=flags.hint_type,
128+
hint_index_type=flags.hint_index_type,
129+
hint_query_type=flags.hint_query_type,
130+
hint_index_path=flags.hint_index_path,
131+
hint_retriever_path=flags.hint_retriever_path,
132+
hint_num_results=flags.hint_num_results,
119133
)
120134
self.plan = Plan(previous_plan, step, lambda: flags.use_plan) # TODO add previous plan
121135
self.criticise = Criticise(visible=lambda: flags.use_criticise)
@@ -301,12 +315,24 @@ def __init__(
301315
use_task_hint: bool,
302316
hint_db_path: str,
303317
goal: str,
304-
hint_retrieval_mode: Literal["direct", "llm", "emb"],
305-
skip_hints_for_current_task: bool,
306318
llm: ChatModel,
319+
hint_type: Literal["human", "llm", "docs"] = "human",
320+
hint_index_type: Literal["sparse", "dense"] = "sparse",
321+
hint_query_type: Literal["direct", "llm", "emb"] = "direct",
322+
hint_index_path: str = None,
323+
hint_retriever_path: str = None,
324+
hint_num_results: int = 5,
325+
skip_hints_for_current_task: bool = False,
326+
hint_retrieval_mode: Literal["direct", "llm", "emb"] = "direct",
307327
) -> None:
308328
super().__init__(visible=use_task_hint)
309329
self.use_task_hint = use_task_hint
330+
self.hint_type = hint_type
331+
self.hint_index_type = hint_index_type
332+
self.hint_query_type = hint_query_type
333+
self.hint_index_path = hint_index_path
334+
self.hint_retriever_path = hint_retriever_path
335+
self.hint_num_results = hint_num_results
310336
self.hint_db_rel_path = "hint_db.csv"
311337
self.hint_db_path = hint_db_path # Allow external path override
312338
self.hint_retrieval_mode: Literal["direct", "llm", "emb"] = hint_retrieval_mode
@@ -333,28 +359,46 @@ def __init__(
333359
def _init(self):
334360
"""Initialize the block."""
335361
try:
336-
# Use external path if provided, otherwise fall back to relative path
337-
if self.hint_db_path and Path(self.hint_db_path).exists():
338-
hint_db_path = Path(self.hint_db_path)
362+
if self.hint_type == "docs":
363+
if self.hint_index_type == "sparse":
364+
print("Loading sparse hint index")
365+
import bm25s
366+
self.hint_index = bm25s.BM25.load(self.hint_index_path, load_corpus=True)
367+
print("Sparse hint index loaded successfully")
368+
elif self.hint_index_type == "dense":
369+
print("Loading dense hint index and retriever")
370+
from datasets import load_from_disk
371+
from sentence_transformers import SentenceTransformer
372+
self.hint_index = load_from_disk(self.hint_index_path)
373+
self.hint_index.load_faiss_index("embeddings", self.hint_index_path.removesuffix("/") + ".faiss")
374+
print("Dense hint index loaded successfully")
375+
self.hint_retriever = SentenceTransformer(self.hint_retriever_path)
376+
print("Hint retriever loaded successfully")
377+
else:
378+
raise ValueError(f"Unknown hint index type: {self.hint_index_type}")
339379
else:
340-
hint_db_path = Path(__file__).parent / self.hint_db_rel_path
341-
342-
if hint_db_path.exists():
343-
self.hint_db = pd.read_csv(hint_db_path, header=0, index_col=None, dtype=str)
344-
# Verify the expected columns exist
345-
if "task_name" not in self.hint_db.columns or "hint" not in self.hint_db.columns:
346-
print(
347-
f"Warning: Hint database missing expected columns. Found: {list(self.hint_db.columns)}"
348-
)
380+
# Use external path if provided, otherwise fall back to relative path
381+
if self.hint_db_path and Path(self.hint_db_path).exists():
382+
hint_db_path = Path(self.hint_db_path)
383+
else:
384+
hint_db_path = Path(__file__).parent / self.hint_db_rel_path
385+
386+
if hint_db_path.exists():
387+
self.hint_db = pd.read_csv(hint_db_path, header=0, index_col=None, dtype=str)
388+
# Verify the expected columns exist
389+
if "task_name" not in self.hint_db.columns or "hint" not in self.hint_db.columns:
390+
print(
391+
f"Warning: Hint database missing expected columns. Found: {list(self.hint_db.columns)}"
392+
)
393+
self.hint_db = pd.DataFrame(columns=["task_name", "hint"])
394+
else:
395+
print(f"Warning: Hint database not found at {hint_db_path}")
349396
self.hint_db = pd.DataFrame(columns=["task_name", "hint"])
350-
else:
351-
print(f"Warning: Hint database not found at {hint_db_path}")
352-
self.hint_db = pd.DataFrame(columns=["task_name", "hint"])
353-
self.hints_source = HintsSource(
354-
hint_db_path=hint_db_path.as_posix(),
355-
hint_retrieval_mode=self.hint_retrieval_mode,
356-
skip_hints_for_current_task=self.skip_hints_for_current_task,
357-
)
397+
self.hints_source = HintsSource(
398+
hint_db_path=hint_db_path.as_posix(),
399+
hint_retrieval_mode=self.hint_retrieval_mode,
400+
skip_hints_for_current_task=self.skip_hints_for_current_task,
401+
)
358402
except Exception as e:
359403
# Fallback to empty database on any error
360404
print(f"Warning: Could not load hint database: {e}")
@@ -365,6 +409,32 @@ def get_hints_for_task(self, task_name: str) -> str:
365409
if not self.use_task_hint:
366410
return ""
367411

412+
if self.hint_type == "docs":
413+
if not hasattr(self, "hint_index"):
414+
self._init()
415+
416+
if self.hint_query_type == "goal":
417+
query = self.goal
418+
elif self.hint_query_type == "llm":
419+
query = self.llm.generate(self._prompt + self._abstract_ex + self._concrete_ex)
420+
else:
421+
raise ValueError(f"Unknown hint query type: {self.hint_query_type}")
422+
423+
if self.hint_index_type == "sparse":
424+
query_tokens = bm25s.tokenize(query)
425+
docs = self.hint_index.search(query_tokens, k=self.hint_num_results)
426+
docs = docs["text"]
427+
elif self.hint_index_type == "dense":
428+
query_embedding = self.hint_retriever.encode(query)
429+
_, docs = self.hint_index.get_nearest_examples("embeddings", query_embedding, k=self.hint_num_results)
430+
docs = docs["text"]
431+
432+
hints_str = (
433+
"# Hints:\nHere are some hints for the task you are working on:\n"
434+
+ "\n".join(docs)
435+
)
436+
return hints_str
437+
368438
# Ensure hint_db is initialized
369439
if not hasattr(self, "hint_db"):
370440
self._init()

0 commit comments

Comments
 (0)