Skip to content

Commit 26f0abb

Browse files
committed
pass new flag and fix db path passing issue
1 parent 5315f14 commit 26f0abb

File tree

2 files changed

+14
-5
lines changed

2 files changed

+14
-5
lines changed

src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ class GenericPromptFlags(dp.Flags):
5151
use_hints: bool = False
5252
use_task_hint: bool = False
5353
task_hint_retrieval_mode: Literal["direct", "llm", "emb"] = "direct"
54+
skip_hints_for_current_task: bool = False
5455
hint_db_path: str = None
5556
enable_chat: bool = False
5657
max_prompt_tokens: int = None
@@ -113,6 +114,7 @@ def time_for_caution():
113114
goal=goal,
114115
hint_retrieval_mode=flags.task_hint_retrieval_mode,
115116
llm=llm,
117+
skip_hints_for_current_task=flags.skip_hints_for_current_task,
116118
)
117119
self.plan = Plan(previous_plan, step, lambda: flags.use_plan) # TODO add previous plan
118120
self.criticise = Criticise(visible=lambda: flags.use_criticise)
@@ -299,13 +301,15 @@ def __init__(
299301
hint_db_path: str,
300302
goal: str,
301303
hint_retrieval_mode: Literal["direct", "llm", "emb"],
304+
skip_hints_for_current_task: bool,
302305
llm: ChatModel,
303306
) -> None:
304307
super().__init__(visible=use_task_hint)
305308
self.use_task_hint = use_task_hint
306309
self.hint_db_rel_path = "hint_db.csv"
307310
self.hint_db_path = hint_db_path # Allow external path override
308311
self.hint_retrieval_mode: Literal["direct", "llm", "emb"] = hint_retrieval_mode
312+
self.skip_hints_for_current_task = skip_hints_for_current_task
309313
self.goal = goal
310314
self.llm = llm
311315
self._init()
@@ -346,8 +350,9 @@ def _init(self):
346350
print(f"Warning: Hint database not found at {hint_db_path}")
347351
self.hint_db = pd.DataFrame(columns=["task_name", "hint"])
348352
self.hints_source = HintsSource(
349-
hint_db_path=self.hint_db_rel_path,
353+
hint_db_path=hint_db_path.as_posix(),
350354
hint_retrieval_mode=self.hint_retrieval_mode,
355+
skip_hints_for_current_task=self.skip_hints_for_current_task,
351356
)
352357
except Exception as e:
353358
# Fallback to empty database on any error

src/agentlab/agents/tool_use_agent/tool_use_agent.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,7 @@ def __init__(
375375
else:
376376
self.hint_db_path = (Path(__file__).parent / self.hint_db_path).as_posix()
377377
self.hint_db = pd.read_csv(self.hint_db_path, header=0, index_col=None, dtype=str)
378+
logger.info(f"Loaded {len(self.hint_db)} hints from database {self.hint_db_path}")
378379
if self.hint_retrieval_mode == "emb":
379380
self.load_hint_vectors()
380381

@@ -395,16 +396,19 @@ def load_hint_vectors(self):
395396

396397
def choose_hints(self, llm, task_name: str, goal: str) -> list[str]:
397398
"""Choose hints based on the task name."""
399+
logger.info(
400+
f"Choosing hints for task: {task_name}, goal: {goal} from db: {self.hint_db_path} using mode: {self.hint_retrieval_mode}"
401+
)
398402
if self.hint_retrieval_mode == "llm":
399-
return self.choose_hints_llm(llm, goal)
403+
return self.choose_hints_llm(llm, goal, task_name)
400404
elif self.hint_retrieval_mode == "direct":
401405
return self.choose_hints_direct(task_name)
402406
elif self.hint_retrieval_mode == "emb":
403-
return self.choose_hints_emb(goal)
407+
return self.choose_hints_emb(goal, task_name)
404408
else:
405409
raise ValueError(f"Unknown hint retrieval mode: {self.hint_retrieval_mode}")
406410

407-
def choose_hints_llm(self, llm, goal: str) -> list[str]:
411+
def choose_hints_llm(self, llm, goal: str, task_name: str) -> list[str]:
408412
"""Choose hints using LLM to filter the hints."""
409413
topic_to_hints = defaultdict(list)
410414
hints_df = self.hint_db
@@ -439,7 +443,7 @@ def choose_hints_llm(self, llm, goal: str) -> list[str]:
439443
hints = []
440444
return hints
441445

442-
def choose_hints_emb(self, goal: str) -> list[str]:
446+
def choose_hints_emb(self, goal: str, task_name: str) -> list[str]:
443447
"""Choose hints using embeddings to filter the hints."""
444448
goal_embeddings = self._encode([goal], prompt="task description")
445449
hint_embeddings = self.hint_embeddings

0 commit comments

Comments
 (0)