Skip to content

Commit d2166b3

Browse files
move HintsSource to separate hinting file
1 parent 66b9692 commit d2166b3

File tree

3 files changed

+190
-173
lines changed

3 files changed

+190
-173
lines changed

src/agentlab/agents/tool_use_agent/tool_use_agent.py

Lines changed: 1 addition & 173 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
ToolCalls,
4242
)
4343
from agentlab.llm.tracking import cost_tracker_decorator
44+
from agentlab.utils.hinting import HintsSource
4445

4546
logger = logging.getLogger(__name__)
4647

@@ -349,179 +350,6 @@ def apply(self, llm, discussion: StructuredDiscussion, task_name: str) -> dict:
349350
discussion.append(msg)
350351

351352

352-
class HintsSource:
353-
def __init__(
354-
self,
355-
hint_db_path: str,
356-
hint_retrieval_mode: Literal["direct", "llm", "emb"] = "direct",
357-
skip_hints_for_current_task: bool = False,
358-
top_n: int = 4,
359-
embedder_model: str = "Qwen/Qwen3-Embedding-0.6B",
360-
embedder_server: str = "http://localhost:5000",
361-
llm_prompt: str = """We're choosing hints to help solve the following task:\n{goal}.\n
362-
You need to choose the most relevant hints topic from the following list:\n\nHint topics:\n{topics}\n
363-
Choose hint topic for the task and return only its number, e.g. 1. If you don't know the answer, return -1.""",
364-
) -> None:
365-
self.hint_db_path = hint_db_path
366-
self.hint_retrieval_mode = hint_retrieval_mode
367-
self.skip_hints_for_current_task = skip_hints_for_current_task
368-
self.top_n = top_n
369-
self.embedder_model = embedder_model
370-
self.embedder_server = embedder_server
371-
self.llm_prompt = llm_prompt
372-
373-
if Path(hint_db_path).is_absolute():
374-
self.hint_db_path = Path(hint_db_path).as_posix()
375-
else:
376-
self.hint_db_path = (Path(__file__).parent / self.hint_db_path).as_posix()
377-
self.hint_db = pd.read_csv(self.hint_db_path, header=0, index_col=None, dtype=str)
378-
logger.info(f"Loaded {len(self.hint_db)} hints from database {self.hint_db_path}")
379-
if self.hint_retrieval_mode == "emb":
380-
self.load_hint_vectors()
381-
382-
def load_hint_vectors(self):
383-
self.uniq_hints = self.hint_db.drop_duplicates(subset=["hint"], keep="first")
384-
logger.info(
385-
f"Encoding {len(self.uniq_hints)} unique hints with semantic keys using {self.embedder_model} model."
386-
)
387-
hints = self.uniq_hints["hint"].tolist()
388-
semantic_keys = self.uniq_hints["semantic_keys"].tolist()
389-
lines = [f"{k}: {h}" for h, k in zip(hints, semantic_keys)]
390-
emb_path = f"{self.hint_db_path}.embs.npy"
391-
assert os.path.exists(emb_path), f"Embedding file not found: {emb_path}"
392-
logger.info(f"Loading hint embeddings from: {emb_path}")
393-
emb_dict = np.load(emb_path, allow_pickle=True).item()
394-
self.hint_embeddings = np.array([emb_dict[k] for k in lines])
395-
logger.info(f"Loaded hint embeddings shape: {self.hint_embeddings.shape}")
396-
397-
def choose_hints(self, llm, task_name: str, goal: str) -> list[str]:
398-
"""Choose hints based on the task name."""
399-
logger.info(
400-
f"Choosing hints for task: {task_name}, goal: {goal} from db: {self.hint_db_path} using mode: {self.hint_retrieval_mode}"
401-
)
402-
if self.hint_retrieval_mode == "llm":
403-
return self.choose_hints_llm(llm, goal, task_name)
404-
elif self.hint_retrieval_mode == "direct":
405-
return self.choose_hints_direct(task_name)
406-
elif self.hint_retrieval_mode == "emb":
407-
return self.choose_hints_emb(goal, task_name)
408-
else:
409-
raise ValueError(f"Unknown hint retrieval mode: {self.hint_retrieval_mode}")
410-
411-
def choose_hints_llm(self, llm, goal: str, task_name: str) -> list[str]:
412-
"""Choose hints using LLM to filter the hints."""
413-
topic_to_hints = defaultdict(list)
414-
skip_hints = []
415-
if self.skip_hints_for_current_task:
416-
skip_hints = self.get_current_task_hints(task_name)
417-
for _, row in self.hint_db.iterrows():
418-
hint = row["hint"]
419-
if hint in skip_hints:
420-
continue
421-
topic_to_hints[row["semantic_keys"]].append(hint)
422-
logger.info(f"Collected {len(topic_to_hints)} hint topics")
423-
hint_topics = list(topic_to_hints.keys())
424-
topics = "\n".join([f"{i}. {h}" for i, h in enumerate(hint_topics)])
425-
prompt = self.llm_prompt.format(goal=goal, topics=topics)
426-
427-
if isinstance(llm, ChatModel):
428-
response: str = llm(messages=[dict(role="user", content=prompt)])["content"]
429-
else:
430-
response: str = llm(APIPayload(messages=[llm.msg.user().add_text(prompt)])).think
431-
try:
432-
topic_number = json.loads(response)
433-
if topic_number < 0 or topic_number >= len(hint_topics):
434-
logger.error(f"Wrong LLM hint id response: {response}, no hints")
435-
return []
436-
hint_topic = hint_topics[topic_number]
437-
hints = list(set(topic_to_hints[hint_topic]))
438-
logger.info(f"LLM hint topic {topic_number}:'{hint_topic}', chosen hints: {hints}")
439-
except Exception as e:
440-
logger.exception(f"Failed to parse LLM hint id response: {response}:\n{e}")
441-
hints = []
442-
return hints
443-
444-
def choose_hints_emb(self, goal: str, task_name: str) -> list[str]:
445-
"""Choose hints using embeddings to filter the hints."""
446-
try:
447-
goal_embeddings = self._encode([goal], prompt="task description")
448-
hint_embeddings = self.hint_embeddings.copy()
449-
all_hints = self.uniq_hints["hint"].tolist()
450-
skip_hints = []
451-
if self.skip_hints_for_current_task:
452-
skip_hints = self.get_current_task_hints(task_name)
453-
hint_embeddings = []
454-
id_to_hint = {}
455-
for hint, emb in zip(all_hints, self.hint_embeddings):
456-
if hint in skip_hints:
457-
continue
458-
hint_embeddings.append(emb.tolist())
459-
id_to_hint[len(hint_embeddings) - 1] = hint
460-
logger.info(f"Prepared hint embeddings for {len(hint_embeddings)} hints")
461-
similarities = self._similarity(goal_embeddings.tolist(), hint_embeddings)
462-
top_indices = similarities.argsort()[0][-self.top_n :].tolist()
463-
logger.info(f"Top hint indices based on embedding similarity: {top_indices}")
464-
hints = [id_to_hint[idx] for idx in top_indices]
465-
logger.info(f"Embedding-based hints chosen: {hints}")
466-
except Exception as e:
467-
logger.exception(f"Failed to choose hints using embeddings: {e}")
468-
hints = []
469-
return hints
470-
471-
def _encode(self, texts: list[str], prompt: str = "", timeout: int = 10, max_retries: int = 5):
472-
"""Call the encode API endpoint with timeout and retries"""
473-
for attempt in range(max_retries):
474-
try:
475-
response = requests.post(
476-
f"{self.embedder_server}/encode",
477-
json={"texts": texts, "prompt": prompt},
478-
timeout=timeout,
479-
)
480-
embs = response.json()["embeddings"]
481-
return np.asarray(embs)
482-
except (requests.exceptions.RequestException, requests.exceptions.Timeout) as e:
483-
if attempt == max_retries - 1:
484-
raise e
485-
time.sleep(random.uniform(1, timeout))
486-
continue
487-
raise ValueError("Failed to encode hints")
488-
489-
def _similarity(
490-
self,
491-
texts1: list,
492-
texts2: list,
493-
timeout: int = 2,
494-
max_retries: int = 5,
495-
):
496-
"""Call the similarity API endpoint with timeout and retries"""
497-
for attempt in range(max_retries):
498-
try:
499-
response = requests.post(
500-
f"{self.embedder_server}/similarity",
501-
json={"texts1": texts1, "texts2": texts2},
502-
timeout=timeout,
503-
)
504-
similarities = response.json()["similarities"]
505-
return np.asarray(similarities)
506-
except (requests.exceptions.RequestException, requests.exceptions.Timeout) as e:
507-
if attempt == max_retries - 1:
508-
raise e
509-
time.sleep(random.uniform(1, timeout))
510-
continue
511-
raise ValueError("Failed to compute similarity")
512-
513-
def choose_hints_direct(self, task_name: str) -> list[str]:
514-
hints = self.get_current_task_hints(task_name)
515-
logger.info(f"Direct hints chosen: {hints}")
516-
return hints
517-
518-
def get_current_task_hints(self, task_name):
519-
hints_df = self.hint_db[
520-
self.hint_db["task_name"].apply(lambda x: fnmatch.fnmatch(x, task_name))
521-
]
522-
return hints_df["hint"].tolist()
523-
524-
525353
@dataclass
526354
class PromptConfig:
527355
tag_screenshot: bool = True # Whether to tag the screenshot with the last action.

src/agentlab/utils/__init__.py

Whitespace-only changes.

src/agentlab/utils/hinting.py

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
import fnmatch
2+
import json
3+
import logging
4+
import os
5+
import random
6+
import time
7+
from collections import defaultdict
8+
from pathlib import Path
9+
from typing import Literal
10+
11+
import numpy as np
12+
import pandas as pd
13+
import requests
14+
from agentlab.llm.chat_api import ChatModel
15+
16+
logger = logging.getLogger(__name__)
17+
18+
19+
class HintsSource:
    """Retrieve task hints from a CSV hint database.

    The database is read with pandas and must provide the columns ``task_name``,
    ``hint`` and ``semantic_keys`` (these are the columns this class reads).

    Three retrieval modes are available, selected by ``hint_retrieval_mode``:

    * ``"direct"``: return the hints whose ``task_name`` matches the task.
    * ``"llm"``: ask an LLM to pick one hint topic (a ``semantic_keys`` value)
      and return the hints filed under that topic.
    * ``"emb"``: rank the unique hints by similarity to the goal, using the
      ``/encode`` and ``/similarity`` endpoints of ``embedder_server``.
    """

    def __init__(
        self,
        hint_db_path: str,
        hint_retrieval_mode: Literal["direct", "llm", "emb"] = "direct",
        skip_hints_for_current_task: bool = False,
        top_n: int = 4,
        embedder_model: str = "Qwen/Qwen3-Embedding-0.6B",
        embedder_server: str = "http://localhost:5000",
        llm_prompt: str = """We're choosing hints to help solve the following task:\n{goal}.\n
You need to choose the most relevant hints topic from the following list:\n\nHint topics:\n{topics}\n
Choose hint topic for the task and return only its number, e.g. 1. If you don't know the answer, return -1.""",
    ) -> None:
        """Load the hint database and, in ``"emb"`` mode, the hint embeddings.

        Args:
            hint_db_path: Path to the CSV hint database. A relative path is
                resolved against the directory containing this module.
            hint_retrieval_mode: One of ``"direct"``, ``"llm"`` or ``"emb"``.
            skip_hints_for_current_task: In ``"llm"`` and ``"emb"`` modes,
                exclude the hints registered for the current task.
            top_n: Maximum number of hints returned in ``"emb"`` mode.
            embedder_model: Embedding model name; within this class it is only
                used in a log message.
            embedder_server: Base URL of the embedding server.
            llm_prompt: Prompt template for ``"llm"`` mode; formatted with
                ``goal`` and ``topics``.
        """
        self.hint_retrieval_mode = hint_retrieval_mode
        self.skip_hints_for_current_task = skip_hints_for_current_task
        self.top_n = top_n
        self.embedder_model = embedder_model
        self.embedder_server = embedder_server
        self.llm_prompt = llm_prompt

        db_path = Path(hint_db_path)
        if not db_path.is_absolute():
            # NOTE(review): relative paths are anchored at this module's
            # directory; confirm callers that relied on the previous module
            # location still point at the right file.
            db_path = Path(__file__).parent / db_path
        self.hint_db_path = db_path.as_posix()

        self.hint_db = pd.read_csv(self.hint_db_path, header=0, index_col=None, dtype=str)
        logger.info(f"Loaded {len(self.hint_db)} hints from database {self.hint_db_path}")
        if self.hint_retrieval_mode == "emb":
            self.load_hint_vectors()

    def load_hint_vectors(self):
        """Load precomputed embeddings for the unique hints of the database.

        Sets ``self.uniq_hints`` (the database with duplicate hints dropped)
        and ``self.hint_embeddings`` (one embedding per unique hint, in the
        same order). Embeddings are read from ``<hint_db_path>.embs.npy``,
        which must hold a dict keyed by ``"<semantic_keys>: <hint>"``.

        Raises:
            AssertionError: If the embedding file does not exist.
            KeyError: If a hint has no entry in the embedding file.
        """
        self.uniq_hints = self.hint_db.drop_duplicates(subset=["hint"], keep="first")
        logger.info(
            f"Encoding {len(self.uniq_hints)} unique hints with semantic keys using {self.embedder_model} model."
        )
        hints = self.uniq_hints["hint"].tolist()
        semantic_keys = self.uniq_hints["semantic_keys"].tolist()
        lines = [f"{k}: {h}" for h, k in zip(hints, semantic_keys)]
        emb_path = f"{self.hint_db_path}.embs.npy"
        assert os.path.exists(emb_path), f"Embedding file not found: {emb_path}"
        logger.info(f"Loading hint embeddings from: {emb_path}")
        # SECURITY: allow_pickle=True unpickles arbitrary objects; only load
        # embedding files that come from a trusted source.
        emb_dict = np.load(emb_path, allow_pickle=True).item()
        self.hint_embeddings = np.array([emb_dict[k] for k in lines])
        logger.info(f"Loaded hint embeddings shape: {self.hint_embeddings.shape}")

    def choose_hints(self, llm, task_name: str, goal: str) -> list[str]:
        """Choose hints for a task using the configured retrieval mode.

        Args:
            llm: Model used in ``"llm"`` mode; ignored by the other modes.
            task_name: Name of the current task.
            goal: Goal text of the current task.

        Returns:
            The selected hints (possibly empty).

        Raises:
            ValueError: If ``hint_retrieval_mode`` is not a known mode.
        """
        logger.info(
            f"Choosing hints for task: {task_name}, goal: {goal} from db: {self.hint_db_path} using mode: {self.hint_retrieval_mode}"
        )
        if self.hint_retrieval_mode == "llm":
            return self.choose_hints_llm(llm, goal, task_name)
        elif self.hint_retrieval_mode == "direct":
            return self.choose_hints_direct(task_name)
        elif self.hint_retrieval_mode == "emb":
            return self.choose_hints_emb(goal, task_name)
        else:
            raise ValueError(f"Unknown hint retrieval mode: {self.hint_retrieval_mode}")

    def choose_hints_llm(self, llm, goal: str, task_name: str) -> list[str]:
        """Let an LLM pick one hint topic and return the hints of that topic.

        Hints are grouped by their ``semantic_keys`` value; the numbered topic
        list is inserted into ``llm_prompt`` and the reply is parsed as a JSON
        integer index into that list.

        Returns:
            The de-duplicated hints of the chosen topic, in database order, or
            an empty list when there are no topics or the reply is not a valid
            topic index.
        """
        skip_hints = (
            set(self.get_current_task_hints(task_name))
            if self.skip_hints_for_current_task
            else set()
        )
        topic_to_hints = defaultdict(list)
        for hint, topic in zip(self.hint_db["hint"], self.hint_db["semantic_keys"]):
            if hint in skip_hints:
                continue
            topic_to_hints[topic].append(hint)
        logger.info(f"Collected {len(topic_to_hints)} hint topics")
        hint_topics = list(topic_to_hints.keys())
        if not hint_topics:
            # Nothing to choose from: any reply would be out of range, so
            # skip the LLM call altogether.
            return []
        topics = "\n".join([f"{i}. {h}" for i, h in enumerate(hint_topics)])
        prompt = self.llm_prompt.format(goal=goal, topics=topics)

        if isinstance(llm, ChatModel):
            response: str = llm(messages=[dict(role="user", content=prompt)])["content"]
        else:
            # APIPayload is not imported at module level, so bring it into
            # scope here. NOTE(review): assumes it lives in
            # agentlab.llm.response_api - confirm.
            from agentlab.llm.response_api import APIPayload

            response: str = llm(APIPayload(messages=[llm.msg.user().add_text(prompt)])).think

        try:
            topic_number = json.loads(response)
        except (TypeError, ValueError) as e:
            logger.exception(f"Failed to parse LLM hint id response: {response}:\n{e}")
            return []
        # bool is a subclass of int, so reject it explicitly; floats, strings
        # and containers are not valid indices either.
        is_index = isinstance(topic_number, int) and not isinstance(topic_number, bool)
        if not is_index or not 0 <= topic_number < len(hint_topics):
            logger.error(f"Wrong LLM hint id response: {response}, no hints")
            return []
        hint_topic = hint_topics[topic_number]
        # dict.fromkeys de-duplicates while keeping a deterministic order.
        hints = list(dict.fromkeys(topic_to_hints[hint_topic]))
        logger.info(f"LLM hint topic {topic_number}:'{hint_topic}', chosen hints: {hints}")
        return hints

    def choose_hints_emb(self, goal: str, task_name: str) -> list[str]:
        """Return up to ``top_n`` hints most similar to the goal.

        Requires ``load_hint_vectors`` to have run. This is best effort: any
        failure (including embedding-server errors) is logged and results in
        an empty list.

        Returns:
            The selected hints ordered from least to most similar, or an empty
            list when nothing can be ranked.
        """
        try:
            skip_hints = (
                set(self.get_current_task_hints(task_name))
                if self.skip_hints_for_current_task
                else set()
            )
            candidates = [
                (hint, emb.tolist())
                for hint, emb in zip(self.uniq_hints["hint"].tolist(), self.hint_embeddings)
                if hint not in skip_hints
            ]
            logger.info(f"Prepared hint embeddings for {len(candidates)} hints")
            if not candidates or self.top_n <= 0:
                # Nothing to rank or nothing requested: avoid the server calls.
                return []
            goal_embeddings = self._encode([goal], prompt="task description")
            similarities = self._similarity(
                goal_embeddings.tolist(), [emb for _, emb in candidates]
            )
            top_indices = self._top_indices(similarities, self.top_n)
            logger.info(f"Top hint indices based on embedding similarity: {top_indices}")
            hints = [candidates[idx][0] for idx in top_indices]
            logger.info(f"Embedding-based hints chosen: {hints}")
        except Exception as e:
            logger.exception(f"Failed to choose hints using embeddings: {e}")
            hints = []
        return hints

    @staticmethod
    def _top_indices(similarities, top_n: int) -> list[int]:
        """Return the column indices of the ``top_n`` largest values of row 0.

        Indices are ordered by ascending similarity. A non-positive ``top_n``
        yields an empty list (a plain ``[-top_n:]`` slice would return every
        index for ``top_n == 0``).
        """
        if top_n <= 0:
            return []
        ranked = np.asarray(similarities).argsort()[0]
        return ranked[-top_n:].tolist()

    def _post_json(
        self,
        endpoint: str,
        payload: dict,
        result_key: str,
        timeout: int,
        max_retries: int,
        failure_message: str,
    ):
        """POST ``payload`` to the embedder server and return one field as an array.

        Request failures (including HTTP error statuses) are retried up to
        ``max_retries`` times with a random pause between attempts; the last
        failure is re-raised.

        Raises:
            requests.exceptions.RequestException: If the final attempt fails.
            ValueError: If no attempt was made (``max_retries <= 0``).
        """
        for attempt in range(max_retries):
            try:
                response = requests.post(
                    f"{self.embedder_server}/{endpoint}",
                    json=payload,
                    timeout=timeout,
                )
                # Turn HTTP error statuses into retryable exceptions instead
                # of failing later on a missing key in the error body.
                response.raise_for_status()
                return np.asarray(response.json()[result_key])
            except requests.exceptions.RequestException:
                if attempt == max_retries - 1:
                    raise
                time.sleep(random.uniform(1, timeout))
        raise ValueError(failure_message)

    def _encode(self, texts: list[str], prompt: str = "", timeout: int = 10, max_retries: int = 5):
        """Call the encode API endpoint with timeout and retries.

        Returns:
            The ``embeddings`` field of the server reply as a numpy array.
        """
        return self._post_json(
            "encode",
            {"texts": texts, "prompt": prompt},
            "embeddings",
            timeout,
            max_retries,
            "Failed to encode hints",
        )

    def _similarity(
        self,
        texts1: list,
        texts2: list,
        timeout: int = 2,
        max_retries: int = 5,
    ):
        """Call the similarity API endpoint with timeout and retries.

        Returns:
            The ``similarities`` field of the server reply as a numpy array.
        """
        return self._post_json(
            "similarity",
            {"texts1": texts1, "texts2": texts2},
            "similarities",
            timeout,
            max_retries,
            "Failed to compute similarity",
        )

    def choose_hints_direct(self, task_name: str) -> list[str]:
        """Return the hints registered for ``task_name`` in the database."""
        hints = self.get_current_task_hints(task_name)
        logger.info(f"Direct hints chosen: {hints}")
        return hints

    def get_current_task_hints(self, task_name):
        """Return the hints of every row whose ``task_name`` matches ``task_name``.

        Each database ``task_name`` value is tested with ``fnmatch`` using the
        ``task_name`` argument as the pattern, so the argument may contain
        shell-style wildcards. Rows whose ``task_name`` is missing (not a
        string) never match.
        """
        mask = [
            isinstance(name, str) and fnmatch.fnmatch(name, task_name)
            for name in self.hint_db["task_name"]
        ]
        return self.hint_db.loc[mask, "hint"].tolist()

0 commit comments

Comments
 (0)