Skip to content

Commit b80d731

Browse files
Hints retrieval in tool use agent (#277)
* add new dependency group called 'hint' * use external embedding service in task hints retrieval * gpt5 fixes * Obtain pricing for ChatModel used by generic agent using litellm --------- Co-authored-by: Aman Jaiswal <[email protected]>
1 parent c48e40b commit b80d731

File tree

5 files changed

+263
-18
lines changed

5 files changed

+263
-18
lines changed

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,9 @@ dev = [
102102
"ipykernel>=6.30.1",
103103
"pip>=25.2",
104104
]
105+
hint = [
106+
"sentence-transformers>=5.0.0",
107+
]
105108

106109

107110
[project.scripts]

src/agentlab/agents/tool_use_agent/tool_use_agent.py

Lines changed: 177 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,20 @@
11
import fnmatch
22
import json
3+
import logging
4+
import os
5+
import random
6+
import time
37
from abc import ABC, abstractmethod
8+
from collections import defaultdict
49
from copy import copy
510
from dataclasses import asdict, dataclass, field
611
from pathlib import Path
7-
from typing import Any
12+
from typing import Any, Literal
813

914
import bgym
15+
import numpy as np
1016
import pandas as pd
17+
import requests
1118
from bgym import Benchmark as BgymBenchmark
1219
from browsergym.core.observation import extract_screenshot
1320
from browsergym.utils.obs import (
@@ -34,6 +41,8 @@
3441
)
3542
from agentlab.llm.tracking import cost_tracker_decorator
3643

44+
logger = logging.getLogger(__name__)
45+
3746

3847
@dataclass
3948
class Block(ABC):
@@ -176,7 +185,6 @@ class Obs(Block):
176185
def apply(
177186
self, llm, discussion: StructuredDiscussion, obs: dict, last_llm_output: LLMOutput
178187
) -> dict:
179-
180188
obs_msg = llm.msg.user()
181189
tool_calls = last_llm_output.tool_calls
182190
if self.use_last_error:
@@ -298,22 +306,52 @@ def apply_init(self, llm, discussion: StructuredDiscussion) -> dict:
298306
class TaskHint(Block):
299307
use_task_hint: bool = True
300308
hint_db_rel_path: str = "hint_db.csv"
309+
hint_retrieval_mode: Literal["direct", "llm", "emb"] = "direct"
310+
top_n: int = 4 # Number of top hints to return when using embedding retrieval
311+
embedder_model: str = "Qwen/Qwen3-Embedding-0.6B" # Model for embedding hints
312+
embedder_server: str = "http://localhost:5000"
313+
llm_prompt: str = """We're choosing hints to help solve the following task:\n{goal}.\n
314+
You need to choose the most relevant hints topic from the following list:\n\nHint topics:\n{topics}\n
315+
Choose hint topic for the task and return only its number, e.g. 1. If you don't know the answer, return -1."""
301316

302317
def _init(self):
303318
"""Initialize the block."""
304-
hint_db_path = Path(__file__).parent / self.hint_db_rel_path
319+
if Path(self.hint_db_rel_path).is_absolute():
320+
hint_db_path = Path(self.hint_db_rel_path)
321+
else:
322+
hint_db_path = Path(__file__).parent / self.hint_db_rel_path
305323
self.hint_db = pd.read_csv(hint_db_path, header=0, index_col=None, dtype=str)
324+
if self.hint_retrieval_mode == "emb":
325+
self.encode_hints()
326+
327+
def oai_embed(self, text: str):
328+
response = self._oai_emb.create(input=text, model="text-embedding-3-small")
329+
return response.data[0].embedding
330+
331+
def encode_hints(self):
332+
self.uniq_hints = self.hint_db.drop_duplicates(subset=["hint"], keep="first")
333+
logger.info(
334+
f"Encoding {len(self.uniq_hints)} unique hints with semantic keys using {self.embedder_model} model."
335+
)
336+
hints = self.uniq_hints["hint"].tolist()
337+
semantic_keys = self.uniq_hints["semantic_keys"].tolist()
338+
lines = [f"{k}: {h}" for h, k in zip(hints, semantic_keys)]
339+
emb_path = f"{self.hint_db_rel_path}.embs.npy"
340+
assert os.path.exists(emb_path), f"Embedding file not found: {emb_path}"
341+
logger.info(f"Loading hint embeddings from: {emb_path}")
342+
emb_dict = np.load(emb_path, allow_pickle=True).item()
343+
self.hint_embeddings = np.array([emb_dict[k] for k in lines])
344+
logger.info(f"Loaded hint embeddings shape: {self.hint_embeddings.shape}")
306345

307346
def apply(self, llm, discussion: StructuredDiscussion, task_name: str) -> dict:
308347
if not self.use_task_hint:
309-
return
348+
return {}
310349

311-
task_hints = self.hint_db[
312-
self.hint_db["task_name"].apply(lambda x: fnmatch.fnmatch(x, task_name))
313-
]
350+
goal = "\n".join([c.get("text", "") for c in discussion.groups[0].messages[1].content])
351+
task_hints = self.choose_hints(llm, task_name, goal)
314352

315353
hints = []
316-
for hint in task_hints["hint"]:
354+
for hint in task_hints:
317355
hint = hint.strip()
318356
if hint:
319357
hints.append(f"- {hint}")
@@ -327,6 +365,94 @@ def apply(self, llm, discussion: StructuredDiscussion, task_name: str) -> dict:
327365

328366
discussion.append(msg)
329367

368+
def choose_hints(self, llm, task_name: str, goal: str) -> list[str]:
369+
"""Choose hints based on the task name."""
370+
if self.hint_retrieval_mode == "llm":
371+
return self.choose_hints_llm(llm, goal)
372+
elif self.hint_retrieval_mode == "direct":
373+
return self.choose_hints_direct(task_name)
374+
elif self.hint_retrieval_mode == "emb":
375+
return self.choose_hints_emb(goal)
376+
else:
377+
raise ValueError(f"Unknown hint retrieval mode: {self.hint_retrieval_mode}")
378+
379+
def choose_hints_llm(self, llm, goal: str) -> list[str]:
380+
"""Choose hints using LLM to filter the hints."""
381+
topic_to_hints = defaultdict(list)
382+
for i, row in self.hint_db.iterrows():
383+
topic_to_hints[row["semantic_keys"]].append(i)
384+
hint_topics = list(topic_to_hints.keys())
385+
topics = "\n".join([f"{i}. {h}" for i, h in enumerate(hint_topics)])
386+
prompt = self.llm_prompt.format(goal=goal, topics=topics)
387+
response = llm(APIPayload(messages=[llm.msg.user().add_text(prompt)]))
388+
try:
389+
hint_topic_idx = json.loads(response.think)
390+
if hint_topic_idx < 0 or hint_topic_idx >= len(hint_topics):
391+
logger.error(f"Wrong LLM hint id response: {response.think}, no hints")
392+
return []
393+
hint_topic = hint_topics[hint_topic_idx]
394+
hint_indices = topic_to_hints[hint_topic]
395+
df = self.hint_db.iloc[hint_indices].copy()
396+
df = df.drop_duplicates(subset=["hint"], keep="first") # leave only unique hints
397+
hints = df["hint"].tolist()
398+
logger.debug(f"LLM hint topic {hint_topic_idx}, chosen hints: {df['hint'].tolist()}")
399+
except json.JSONDecodeError:
400+
logger.error(f"Failed to parse LLM hint id response: {response.think}, no hints")
401+
hints = []
402+
return hints
403+
404+
def choose_hints_emb(self, goal: str) -> list[str]:
405+
"""Choose hints using embeddings to filter the hints."""
406+
goal_embeddings = self._encode([goal], prompt="task description")
407+
similarities = self._similarity(goal_embeddings.tolist(), self.hint_embeddings.tolist())
408+
top_indices = similarities.argsort()[0][-self.top_n :].tolist()
409+
logger.info(f"Top hint indices based on embedding similarity: {top_indices}")
410+
hints = self.uniq_hints.iloc[top_indices]
411+
logger.info(f"Embedding-based hints chosen: {hints}")
412+
return hints["hint"].tolist()
413+
414+
def _encode(self, texts: list[str], prompt: str = "", timeout: int = 10, max_retries: int = 5):
415+
"""Call the encode API endpoint with timeout and retries"""
416+
for attempt in range(max_retries):
417+
try:
418+
response = requests.post(
419+
f"{self.embedder_server}/encode",
420+
json={"texts": texts, "prompt": prompt},
421+
timeout=timeout,
422+
)
423+
embs = response.json()["embeddings"]
424+
return np.asarray(embs)
425+
except (requests.exceptions.RequestException, requests.exceptions.Timeout) as e:
426+
if attempt == max_retries - 1:
427+
raise e
428+
time.sleep(random.uniform(1, timeout))
429+
continue
430+
431+
def _similarity(
432+
self, texts1: list[str], texts2: list[str], timeout: int = 2, max_retries: int = 5
433+
):
434+
"""Call the similarity API endpoint with timeout and retries"""
435+
for attempt in range(max_retries):
436+
try:
437+
response = requests.post(
438+
f"{self.embedder_server}/similarity",
439+
json={"texts1": texts1, "texts2": texts2},
440+
timeout=timeout,
441+
)
442+
similarities = response.json()["similarities"]
443+
return np.asarray(similarities)
444+
except (requests.exceptions.RequestException, requests.exceptions.Timeout) as e:
445+
if attempt == max_retries - 1:
446+
raise e
447+
time.sleep(random.uniform(1, timeout))
448+
continue
449+
450+
def choose_hints_direct(self, task_name: str) -> list[str]:
451+
hints = self.hint_db[
452+
self.hint_db["task_name"].apply(lambda x: fnmatch.fnmatch(x, task_name))
453+
]
454+
return hints["hint"].tolist()
455+
330456

331457
@dataclass
332458
class PromptConfig:
@@ -386,7 +512,8 @@ def __init__(
386512
self.model_args = model_args
387513
self.config = config
388514
self.action_set: bgym.AbstractActionSet = action_set or bgym.HighLevelActionSet(
389-
self.config.action_subsets, multiaction=self.config.multiaction # type: ignore
515+
self.config.action_subsets,
516+
multiaction=self.config.multiaction, # type: ignore
390517
)
391518
self.tools = self.action_set.to_tool_description(api=model_args.api)
392519

@@ -510,6 +637,15 @@ def get_action(self, obs: Any) -> float:
510637
vision_support=True,
511638
)
512639

640+
GPT_4_1_CC_API = OpenAIChatModelArgs(
641+
model_name="gpt-4.1",
642+
max_total_tokens=200_000,
643+
max_input_tokens=200_000,
644+
max_new_tokens=2_000,
645+
temperature=0.1,
646+
vision_support=True,
647+
)
648+
513649
GPT_5_mini = OpenAIChatModelArgs(
514650
model_name="gpt-5-mini-2025-08-07",
515651
max_total_tokens=400_000,
@@ -548,7 +684,7 @@ def get_action(self, obs: Any) -> float:
548684
vision_support=True,
549685
)
550686

551-
CLAUDE_MODEL_CONFIG = ClaudeResponseModelArgs(
687+
CLAUDE_SONNET_37 = ClaudeResponseModelArgs(
552688
model_name="claude-3-7-sonnet-20250219",
553689
max_total_tokens=200_000,
554690
max_input_tokens=200_000,
@@ -557,6 +693,15 @@ def get_action(self, obs: Any) -> float:
557693
vision_support=True,
558694
)
559695

696+
CLAUDE_SONNET_4 = ClaudeResponseModelArgs(
697+
model_name="claude-sonnet-4-20250514",
698+
max_total_tokens=200_000,
699+
max_input_tokens=200_000,
700+
max_new_tokens=2_000,
701+
temperature=0.1,
702+
vision_support=True,
703+
)
704+
560705
O3_RESPONSE_MODEL = OpenAIResponseModelArgs(
561706
model_name="o3-2025-04-16",
562707
max_total_tokens=200_000,
@@ -574,6 +719,25 @@ def get_action(self, obs: Any) -> float:
574719
vision_support=True,
575720
)
576721

722+
GPT_5 = OpenAIChatModelArgs(
723+
model_name="gpt-5",
724+
max_total_tokens=200_000,
725+
max_input_tokens=200_000,
726+
max_new_tokens=8_000,
727+
temperature=None,
728+
vision_support=True,
729+
)
730+
731+
732+
GPT_5_MINI = OpenAIChatModelArgs(
733+
model_name="gpt-5-mini-2025-08-07",
734+
max_total_tokens=200_000,
735+
max_input_tokens=200_000,
736+
max_new_tokens=2_000,
737+
temperature=1.0,
738+
vision_support=True,
739+
)
740+
577741
GPT4_1_OPENROUTER_MODEL = OpenRouterModelArgs(
578742
model_name="openai/gpt-4.1",
579743
max_total_tokens=200_000,
@@ -600,12 +764,12 @@ def get_action(self, obs: Any) -> float:
600764
keep_last_n_obs=None,
601765
multiaction=False, # whether to use multi-action or not
602766
# action_subsets=("bid",),
603-
action_subsets=("coord"),
767+
action_subsets=("coord",),
604768
# action_subsets=("coord", "bid"),
605769
)
606770

607771
AGENT_CONFIG = ToolUseAgentArgs(
608-
model_args=CLAUDE_MODEL_CONFIG,
772+
model_args=CLAUDE_SONNET_37,
609773
config=DEFAULT_PROMPT_CONFIG,
610774
)
611775

@@ -633,7 +797,7 @@ def get_action(self, obs: Any) -> float:
633797
)
634798

635799
OSWORLD_CLAUDE = ToolUseAgentArgs(
636-
model_args=CLAUDE_MODEL_CONFIG,
800+
model_args=CLAUDE_SONNET_37,
637801
config=PromptConfig(
638802
tag_screenshot=True,
639803
goal=Goal(goal_as_system_msg=True),

src/agentlab/llm/chat_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,7 @@ def __init__(
359359
min_retry_wait_time=min_retry_wait_time,
360360
api_key_env_var="OPENAI_API_KEY",
361361
client_class=OpenAI,
362-
pricing_func=tracking.get_pricing_openai,
362+
pricing_func=partial(tracking.get_pricing_litellm, model_name=model_name),
363363
log_probs=log_probs,
364364
)
365365

src/agentlab/llm/tracking.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -178,9 +178,9 @@ def __call__(self, *args, **kwargs):
178178
# 'self' here calls ._call_api() method of the subclass
179179
response = self._call_api(*args, **kwargs)
180180
usage = dict(getattr(response, "usage", {}))
181-
if "prompt_tokens_details" in usage:
181+
if "prompt_tokens_details" in usage and usage["prompt_tokens_details"]:
182182
usage["cached_tokens"] = usage["prompt_tokens_details"].cached_tokens
183-
if "input_tokens_details" in usage:
183+
if "input_tokens_details" in usage and usage["input_tokens_details"]:
184184
usage["cached_tokens"] = usage["input_tokens_details"].cached_tokens
185185
usage = {f"usage_{k}": v for k, v in usage.items() if isinstance(v, (int, float))}
186186
usage |= {"n_api_calls": 1}
@@ -338,12 +338,16 @@ def get_effective_cost_from_openai_api(self, response) -> float:
338338
if api_type == "chatcompletion":
339339
total_input_tokens = usage.prompt_tokens # (cache read tokens + new input tokens)
340340
output_tokens = usage.completion_tokens
341-
cached_input_tokens = usage.prompt_tokens_details.cached_tokens
341+
cached_input_tokens = (
342+
usage.prompt_tokens_details.cached_tokens if usage.prompt_tokens_details else 0
343+
)
342344
new_input_tokens = total_input_tokens - cached_input_tokens
343345
elif api_type == "response":
344346
total_input_tokens = usage.input_tokens # (cache read tokens + new input tokens)
345347
output_tokens = usage.output_tokens
346-
cached_input_tokens = usage.input_tokens_details.cached_tokens
348+
cached_input_tokens = (
349+
usage.input_tokens_details.cached_tokens if usage.input_tokens_details else 0
350+
)
347351
new_input_tokens = total_input_tokens - cached_input_tokens
348352
else:
349353
logging.warning(f"Unsupported API type: {api_type}. Defaulting cost to 0.0.")

0 commit comments

Comments
 (0)