Skip to content

Commit 0bc669e

Browse files
committed
LLM-based retrieval of hints, chosen by the relevance of each hint's semantic key to the given goal
1 parent ed81011 commit 0bc669e

File tree

1 file changed

+51
-5
lines changed

1 file changed

+51
-5
lines changed

src/agentlab/agents/tool_use_agent/tool_use_agent.py

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
import fnmatch
22
import json
3+
import logging
34
from abc import ABC, abstractmethod
5+
from collections import defaultdict
46
from copy import copy
57
from dataclasses import asdict, dataclass, field
8+
from email.policy import default
69
from pathlib import Path
7-
from typing import Any
10+
from typing import Any, Literal
811

912
import bgym
1013
import pandas as pd
@@ -34,6 +37,8 @@
3437
)
3538
from agentlab.llm.tracking import cost_tracker_decorator
3639

40+
logger = logging.getLogger(__name__)
41+
3742

3843
@dataclass
3944
class Block(ABC):
@@ -296,6 +301,10 @@ def apply_init(self, llm, discussion: StructuredDiscussion) -> dict:
296301
class TaskHint(Block):
297302
use_task_hint: bool = True
298303
hint_db_rel_path: str = "hint_db.csv"
304+
hint_retrieval_mode: Literal["direct", "llm"] = "direct" # direct or retrieval
305+
llm_prompt: str = """We're choosing hints to help solve the following task:\n{goal}.\n
306+
You need to choose the most relevant hints topic from the following list:\n\nHint topics:\n{topics}\n
307+
Choose hint for the task and return only its numbers, e.g. 1. If you don't know the answer, return -1."""
299308

300309
def _init(self):
301310
"""Initialize the block."""
@@ -306,9 +315,8 @@ def apply(self, llm, discussion: StructuredDiscussion, task_name: str) -> dict:
306315
if not self.use_task_hint:
307316
return
308317

309-
task_hints = self.hint_db[
310-
self.hint_db["task_name"].apply(lambda x: fnmatch.fnmatch(x, task_name))
311-
]
318+
goal = "\n".join([c.get("text", "") for c in discussion.groups[0].messages[1].content])
319+
task_hints = self.choose_hints(llm, task_name, goal)
312320

313321
hints = []
314322
for hint in task_hints["hint"]:
@@ -325,6 +333,44 @@ def apply(self, llm, discussion: StructuredDiscussion, task_name: str) -> dict:
325333

326334
discussion.append(msg)
327335

336+
def choose_hints(self, llm, task_name: str, goal: str) -> pd.DataFrame:
    """Select hint rows using the strategy named by ``self.hint_retrieval_mode``.

    Args:
        llm: Model handle; only forwarded to ``choose_hints_llm`` in "llm" mode.
        task_name: Task name; only forwarded to ``choose_hints_direct`` in "direct" mode.
        goal: Goal text; only forwarded to ``choose_hints_llm`` in "llm" mode.

    Returns:
        Whatever the selected strategy method returns.

    Raises:
        ValueError: If ``self.hint_retrieval_mode`` is neither "llm" nor "direct".
    """
    mode = self.hint_retrieval_mode
    if mode == "llm":
        return self.choose_hints_llm(llm, goal)
    if mode == "direct":
        return self.choose_hints_direct(task_name)
    # Any other value is a configuration error; fail loudly rather than return no hints.
    raise ValueError(f"Unknown hint retrieval mode: {self.hint_retrieval_mode}")
344+
345+
def choose_hints_llm(self, llm, goal: str) -> pd.DataFrame:
    """Ask the LLM to pick the hint topic most relevant to ``goal`` and return its hints.

    Rows of ``self.hint_db`` are grouped by their ``semantic_keys`` value. The
    distinct topics are numbered from 0 and inserted, together with ``goal``,
    into ``self.llm_prompt``. The reply is read from ``response.think`` and is
    expected to be a single JSON integer: the chosen topic number.

    Args:
        llm: Callable model; invoked with an ``APIPayload`` holding one user message.
        goal: Text of the task goal.

    Returns:
        The rows of ``self.hint_db`` that belong to the chosen topic, de-duplicated
        on the ``hint`` column. An empty frame with the same columns is returned
        when there are no topics, or when the reply cannot be parsed, is not an
        integer, or is out of range (this includes the ``-1`` "don't know" answer
        that the prompt allows).
    """
    no_hints = pd.DataFrame(columns=self.hint_db.columns)

    # Group rows by topic using *positions*, so the later ``.iloc`` lookup is
    # correct even when hint_db has a non-default or duplicated index.
    topic_to_hints = defaultdict(list)
    for position, topic in enumerate(self.hint_db["semantic_keys"]):
        topic_to_hints[topic].append(position)
    hint_topics = list(topic_to_hints)
    if not hint_topics:
        # Nothing to choose from; skip the LLM call entirely.
        return no_hints

    topics = "\n".join([f"{i}. {h}" for i, h in enumerate(hint_topics)])
    prompt = self.llm_prompt.format(goal=goal, topics=topics)
    response = llm(APIPayload(messages=[llm.msg.user().add_text(prompt)]))

    try:
        hint_topic_idx = json.loads(response.think)
    except (json.JSONDecodeError, TypeError):
        # TypeError covers a missing (None) or non-string reply.
        logger.error(f"Failed to parse LLM hint id response: {response.think}, return no hints")
        return no_hints

    # json.loads can yield any JSON value (list, str, float, bool); only a
    # plain in-range integer is a usable topic number.
    is_int = isinstance(hint_topic_idx, int) and not isinstance(hint_topic_idx, bool)
    if not is_int or not 0 <= hint_topic_idx < len(hint_topics):
        logger.error(f"Wrong LLM hint id response: {response.think}, return no hints")
        return no_hints

    hint_topic = hint_topics[hint_topic_idx]
    df = self.hint_db.iloc[topic_to_hints[hint_topic]].copy()
    df = df.drop_duplicates(subset=["hint"], keep="first")  # leave only unique hints
    logger.debug(f"LLM hint topic {hint_topic_idx}, chosen hints: {df['hint'].tolist()}")
    return df
368+
369+
def choose_hints_direct(self, task_name: str) -> pd.DataFrame:
    """Return the rows of ``self.hint_db`` whose ``task_name`` matches ``task_name``.

    Matching uses ``fnmatch.fnmatch(stored_name, task_name)``: each value stored
    in the ``task_name`` column is the name being tested and the ``task_name``
    argument is the glob pattern.

    Args:
        task_name: Glob pattern compared against each stored task name.

    Returns:
        The matching subset of ``self.hint_db`` (possibly empty).
    """
    stored_names = self.hint_db["task_name"]
    is_match = stored_names.apply(lambda stored_name: fnmatch.fnmatch(stored_name, task_name))
    return self.hint_db[is_match]
373+
328374

329375
@dataclass
330376
class PromptConfig:
@@ -583,7 +629,7 @@ def get_action(self, obs: Any) -> float:
583629
),
584630
summarizer=Summarizer(do_summary=True),
585631
general_hints=GeneralHints(use_hints=False),
586-
task_hint=TaskHint(use_task_hint=True),
632+
task_hint=TaskHint(use_task_hint=True, hint_retrieval_mode="llm"),
587633
keep_last_n_obs=None,
588634
multiaction=True, # whether to use multi-action or not
589635
# action_subsets=("bid",),

0 commit comments

Comments
 (0)