Skip to content

Commit deddc50

Browse files
committed
fix current task hints exclusion
1 parent 5393a34 commit deddc50

File tree

1 file changed

+45
-37
lines changed

1 file changed

+45
-37
lines changed

src/agentlab/agents/tool_use_agent/tool_use_agent.py

Lines changed: 45 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -411,58 +411,62 @@ def choose_hints(self, llm, task_name: str, goal: str) -> list[str]:
411411
def choose_hints_llm(self, llm, goal: str, task_name: str) -> list[str]:
412412
"""Choose hints using LLM to filter the hints."""
413413
topic_to_hints = defaultdict(list)
414-
hints_df = self.hint_db
414+
skip_hints = []
415415
if self.skip_hints_for_current_task:
416-
current_task_hints = self.get_current_task_hints(task_name)
417-
hints_df = hints_df[~hints_df["hint"].isin(current_task_hints)]
418-
logger.info(
419-
f"Filtered out current task hints, remaining hints: {hints_df.shape[0]} out of {self.hint_db.shape[0]}"
420-
)
421-
for i, row in hints_df.iterrows():
422-
topic_to_hints[row["semantic_keys"]].append(i)
416+
skip_hints = self.get_current_task_hints(task_name)
417+
for _, row in self.hint_db.iterrows():
418+
hint = row["hint"]
419+
if hint in skip_hints:
420+
continue
421+
topic_to_hints[row["semantic_keys"]].append(hint)
422+
logger.info(f"Collected {len(topic_to_hints)} hint topics")
423423
hint_topics = list(topic_to_hints.keys())
424424
topics = "\n".join([f"{i}. {h}" for i, h in enumerate(hint_topics)])
425425
prompt = self.llm_prompt.format(goal=goal, topics=topics)
426+
426427
if isinstance(llm, ChatModel):
427428
response: str = llm(messages=[dict(role="user", content=prompt)])["content"]
428429
else:
429430
response: str = llm(APIPayload(messages=[llm.msg.user().add_text(prompt)])).think
430431
try:
431-
hint_topic_idx = json.loads(response)
432-
if hint_topic_idx < 0 or hint_topic_idx >= len(hint_topics):
432+
topic_number = json.loads(response)
433+
if topic_number < 0 or topic_number >= len(hint_topics):
433434
logger.error(f"Wrong LLM hint id response: {response}, no hints")
434435
return []
435-
hint_topic = hint_topics[hint_topic_idx]
436-
hint_indices = topic_to_hints[hint_topic]
437-
df = hints_df.iloc[hint_indices].copy()
438-
df = df.drop_duplicates(subset=["hint"], keep="first") # leave only unique hints
439-
hints = df["hint"].tolist()
440-
logger.info(f"LLM hint topic {hint_topic_idx}, chosen hints: {df['hint'].tolist()}")
441-
except json.JSONDecodeError:
442-
logger.error(f"Failed to parse LLM hint id response: {response}, no hints")
436+
hint_topic = hint_topics[topic_number]
437+
hints = list(set(topic_to_hints[hint_topic]))
438+
logger.info(f"LLM hint topic {topic_number}:'{hint_topic}', chosen hints: {hints}")
439+
except Exception as e:
440+
logger.exception(f"Failed to parse LLM hint id response: {response}:\n{e}")
443441
hints = []
444442
return hints
445443

446444
def choose_hints_emb(self, goal: str, task_name: str) -> list[str]:
447445
"""Choose hints using embeddings to filter the hints."""
448-
goal_embeddings = self._encode([goal], prompt="task description")
449-
hint_embeddings = self.hint_embeddings
450-
hints_df = self.uniq_hints
451-
if self.skip_hints_for_current_task:
452-
current_task_hints = self.get_current_task_hints(task_name)
453-
mask = ~hints_df["hint"].isin(current_task_hints)
454-
hints_df = hints_df[mask]
455-
filtered_indices = hints_df.index.tolist()
456-
hint_embeddings = hint_embeddings[filtered_indices]
457-
logger.info(
458-
f"Filtered same task hint, remained: {len(hint_embeddings)} out of {len(self.hint_embeddings)} embeddings"
459-
)
460-
similarities = self._similarity(goal_embeddings.tolist(), hint_embeddings.tolist())
461-
top_indices = similarities.argsort()[0][-self.top_n :].tolist()
462-
logger.info(f"Top hint indices based on embedding similarity: {top_indices}")
463-
hints = hints_df.iloc[top_indices]
464-
logger.info(f"Embedding-based hints chosen: {hints}")
465-
return hints["hint"].tolist()
446+
try:
447+
goal_embeddings = self._encode([goal], prompt="task description")
448+
hint_embeddings = self.hint_embeddings.copy()
449+
all_hints = self.uniq_hints["hint"].tolist()
450+
skip_hints = []
451+
if self.skip_hints_for_current_task:
452+
skip_hints = self.get_current_task_hints(task_name)
453+
hint_embeddings = []
454+
id_to_hint = {}
455+
for hint, emb in zip(all_hints, self.hint_embeddings):
456+
if hint in skip_hints:
457+
continue
458+
hint_embeddings.append(emb.tolist())
459+
id_to_hint[len(hint_embeddings) - 1] = hint
460+
logger.info(f"Prepared hint embeddings for {len(hint_embeddings)} hints")
461+
similarities = self._similarity(goal_embeddings.tolist(), hint_embeddings)
462+
top_indices = similarities.argsort()[0][-self.top_n :].tolist()
463+
logger.info(f"Top hint indices based on embedding similarity: {top_indices}")
464+
hints = [id_to_hint[idx] for idx in top_indices]
465+
logger.info(f"Embedding-based hints chosen: {hints}")
466+
except Exception as e:
467+
logger.exception(f"Failed to choose hints using embeddings: {e}")
468+
hints = []
469+
return hints
466470

467471
def _encode(self, texts: list[str], prompt: str = "", timeout: int = 10, max_retries: int = 5):
468472
"""Call the encode API endpoint with timeout and retries"""
@@ -483,7 +487,11 @@ def _encode(self, texts: list[str], prompt: str = "", timeout: int = 10, max_ret
483487
raise ValueError("Failed to encode hints")
484488

485489
def _similarity(
486-
self, texts1: list[str], texts2: list[str], timeout: int = 2, max_retries: int = 5
490+
self,
491+
texts1: list,
492+
texts2: list,
493+
timeout: int = 2,
494+
max_retries: int = 5,
487495
):
488496
"""Call the similarity API endpoint with timeout and retries"""
489497
for attempt in range(max_retries):

0 commit comments

Comments
 (0)