Skip to content

Commit cad1209

Browse files
committed
filter out current task hints if needed
1 parent 55ce26a commit cad1209

File tree

1 file changed

+33
-8
lines changed

1 file changed

+33
-8
lines changed

src/agentlab/agents/tool_use_agent/tool_use_agent.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,7 @@ def apply(self, llm, discussion: StructuredDiscussion, task_name: str) -> dict:
341341

342342
if len(hints) > 0:
343343
hints_str = (
344-
"# Hints:\nHere are some hints for the task you are working on:\n"
344+
"\n# Hints:\nHere are some hints for the task you are working on:\n"
345345
+ "\n".join(hints)
346346
)
347347
msg = llm.msg.user().add_text(hints_str)
@@ -354,6 +354,7 @@ def __init__(
354354
self,
355355
hint_db_path: str,
356356
hint_retrieval_mode: Literal["direct", "llm", "emb"] = "direct",
357+
skip_hints_for_current_task: bool = False,
357358
top_n: int = 4,
358359
embedder_model: str = "Qwen/Qwen3-Embedding-0.6B",
359360
embedder_server: str = "http://localhost:5000",
@@ -363,6 +364,7 @@ def __init__(
363364
) -> None:
364365
self.hint_db_path = hint_db_path
365366
self.hint_retrieval_mode = hint_retrieval_mode
367+
self.skip_hints_for_current_task = skip_hints_for_current_task
366368
self.top_n = top_n
367369
self.embedder_model = embedder_model
368370
self.embedder_server = embedder_server
@@ -405,7 +407,14 @@ def choose_hints(self, llm, task_name: str, goal: str) -> list[str]:
405407
def choose_hints_llm(self, llm, goal: str) -> list[str]:
406408
"""Choose hints using LLM to filter the hints."""
407409
topic_to_hints = defaultdict(list)
408-
for i, row in self.hint_db.iterrows():
410+
hints_df = self.hint_db
411+
if self.skip_hints_for_current_task:
412+
current_task_hints = self.get_current_task_hints(task_name)
413+
hints_df = hints_df[~hints_df["hint"].isin(current_task_hints)]
414+
logger.info(
415+
f"Filtered out current task hints, remaining hints: {hints_df.shape[0]} out of {self.hint_db.shape[0]}"
416+
)
417+
for i, row in hints_df.iterrows():
409418
topic_to_hints[row["semantic_keys"]].append(i)
410419
hint_topics = list(topic_to_hints.keys())
411420
topics = "\n".join([f"{i}. {h}" for i, h in enumerate(hint_topics)])
@@ -421,10 +430,10 @@ def choose_hints_llm(self, llm, goal: str) -> list[str]:
421430
return []
422431
hint_topic = hint_topics[hint_topic_idx]
423432
hint_indices = topic_to_hints[hint_topic]
424-
df = self.hint_db.iloc[hint_indices].copy()
433+
df = hints_df.iloc[hint_indices].copy()
425434
df = df.drop_duplicates(subset=["hint"], keep="first") # leave only unique hints
426435
hints = df["hint"].tolist()
427-
logger.debug(f"LLM hint topic {hint_topic_idx}, chosen hints: {df['hint'].tolist()}")
436+
logger.info(f"LLM hint topic {hint_topic_idx}, chosen hints: {df['hint'].tolist()}")
428437
except json.JSONDecodeError:
429438
logger.error(f"Failed to parse LLM hint id response: {response}, no hints")
430439
hints = []
@@ -433,10 +442,21 @@ def choose_hints_llm(self, llm, goal: str) -> list[str]:
433442
def choose_hints_emb(self, goal: str) -> list[str]:
434443
"""Choose hints using embeddings to filter the hints."""
435444
goal_embeddings = self._encode([goal], prompt="task description")
436-
similarities = self._similarity(goal_embeddings.tolist(), self.hint_embeddings.tolist())
445+
hint_embeddings = self.hint_embeddings
446+
hints_df = self.uniq_hints
447+
if self.skip_hints_for_current_task:
448+
current_task_hints = self.get_current_task_hints(task_name)
449+
mask = ~hints_df["hint"].isin(current_task_hints)
450+
hints_df = hints_df[mask]
451+
filtered_indices = hints_df.index.tolist()
452+
hint_embeddings = hint_embeddings[filtered_indices]
453+
logger.info(
454+
f"Filtered same task hint, remained: {len(hint_embeddings)} out of {len(self.hint_embeddings)} embeddings"
455+
)
456+
similarities = self._similarity(goal_embeddings.tolist(), hint_embeddings.tolist())
437457
top_indices = similarities.argsort()[0][-self.top_n :].tolist()
438458
logger.info(f"Top hint indices based on embedding similarity: {top_indices}")
439-
hints = self.uniq_hints.iloc[top_indices]
459+
hints = hints_df.iloc[top_indices]
440460
logger.info(f"Embedding-based hints chosen: {hints}")
441461
return hints["hint"].tolist()
442462

@@ -479,10 +499,15 @@ def _similarity(
479499
raise ValueError("Failed to compute similarity")
480500

481501
def choose_hints_direct(self, task_name: str) -> list[str]:
482-
hints = self.hint_db[
502+
hints = self.get_current_task_hints(task_name)
503+
logger.info(f"Direct hints chosen: {hints}")
504+
return hints
505+
506+
def get_current_task_hints(self, task_name):
507+
hints_df = self.hint_db[
483508
self.hint_db["task_name"].apply(lambda x: fnmatch.fnmatch(x, task_name))
484509
]
485-
return hints["hint"].tolist()
510+
return hints_df["hint"].tolist()
486511

487512

488513
@dataclass

0 commit comments

Comments
 (0)