Skip to content

Commit 7af2d15

Browse files
committed
fixes
1 parent 219e467 commit 7af2d15

File tree

3 files changed

+137
-16
lines changed

3 files changed

+137
-16
lines changed

src/agentlab/agents/tool_use_agent/tool_use_agent.py

Lines changed: 128 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import fnmatch
22
import json
3+
import logging
34
from abc import ABC, abstractmethod
5+
from collections import defaultdict
46
from copy import copy
57
from dataclasses import asdict, dataclass, field
68
from pathlib import Path
7-
from typing import Any
9+
from typing import Any, Literal
810

911
import bgym
1012
import pandas as pd
@@ -16,6 +18,7 @@
1618
overlay_som,
1719
prune_html,
1820
)
21+
from sentence_transformers import SentenceTransformer
1922

2023
from agentlab.agents.agent_args import AgentArgs
2124
from agentlab.benchmarks.abstract_env import AbstractBenchmark as AgentLabBenchmark
@@ -34,6 +37,8 @@
3437
)
3538
from agentlab.llm.tracking import cost_tracker_decorator
3639

40+
logger = logging.getLogger(__name__)
41+
3742

3843
@dataclass
3944
class Block(ABC):
@@ -298,22 +303,45 @@ def apply_init(self, llm, discussion: StructuredDiscussion) -> dict:
298303
class TaskHint(Block):
299304
use_task_hint: bool = True
300305
hint_db_rel_path: str = "hint_db.csv"
306+
hint_retrieval_mode: Literal["direct", "llm", "emb"] = "direct"
307+
top_n: int = 4 # Number of top hints to return when using embedding retrieval
308+
embedder_model: str = "Qwen/Qwen3-Embedding-0.6B" # Model for embedding hints
309+
llm_prompt: str = """We're choosing hints to help solve the following task:\n{goal}.\n
310+
You need to choose the most relevant hints topic from the following list:\n\nHint topics:\n{topics}\n
311+
Choose hint topic for the task and return only its number, e.g. 1. If you don't know the answer, return -1."""
301312

302313
def _init(self):
303314
"""Initialize the block."""
304-
hint_db_path = Path(__file__).parent / self.hint_db_rel_path
315+
if Path(self.hint_db_rel_path).is_absolute():
316+
hint_db_path = Path(self.hint_db_rel_path)
317+
else:
318+
hint_db_path = Path(__file__).parent / self.hint_db_rel_path
305319
self.hint_db = pd.read_csv(hint_db_path, header=0, index_col=None, dtype=str)
320+
if self.hint_retrieval_mode == "emb":
321+
logger.info("Load sentence transformer model for hint embeddings.")
322+
self.emb_model = SentenceTransformer(
323+
"Qwen/Qwen3-Embedding-0.6B", model_kwargs={"torch_dtype": "bfloat16"}
324+
)
325+
self.encode_hints()
326+
327+
def encode_hints(self):
328+
self.uniq_hints = self.hint_db.drop_duplicates(subset=["hint"], keep="first")
329+
logger.info(
330+
f"Encoding {len(self.uniq_hints)} unique hints using {self.embedder_model} model."
331+
)
332+
self.hint_embeddings = self.emb_model.encode(
333+
self.uniq_hints["hint"].tolist(), prompt="task hint"
334+
)
306335

307336
def apply(self, llm, discussion: StructuredDiscussion, task_name: str) -> dict:
308337
if not self.use_task_hint:
309-
return
338+
return {}
310339

311-
task_hints = self.hint_db[
312-
self.hint_db["task_name"].apply(lambda x: fnmatch.fnmatch(x, task_name))
313-
]
340+
goal = "\n".join([c.get("text", "") for c in discussion.groups[0].messages[1].content])
341+
task_hints = self.choose_hints(llm, task_name, goal)
314342

315343
hints = []
316-
for hint in task_hints["hint"]:
344+
for hint in task_hints:
317345
hint = hint.strip()
318346
if hint:
319347
hints.append(f"- {hint}")
@@ -327,6 +355,58 @@ def apply(self, llm, discussion: StructuredDiscussion, task_name: str) -> dict:
327355

328356
discussion.append(msg)
329357

358+
def choose_hints(self, llm, task_name: str, goal: str) -> list[str]:
359+
"""Choose hints based on the task name."""
360+
if self.hint_retrieval_mode == "llm":
361+
return self.choose_hints_llm(llm, goal)
362+
elif self.hint_retrieval_mode == "direct":
363+
return self.choose_hints_direct(task_name)
364+
elif self.hint_retrieval_mode == "emb":
365+
return self.choose_hints_emb(goal)
366+
else:
367+
raise ValueError(f"Unknown hint retrieval mode: {self.hint_retrieval_mode}")
368+
369+
def choose_hints_llm(self, llm, goal: str) -> list[str]:
370+
"""Choose hints using LLM to filter the hints."""
371+
topic_to_hints = defaultdict(list)
372+
for i, row in self.hint_db.iterrows():
373+
topic_to_hints[row["semantic_keys"]].append(i)
374+
hint_topics = list(topic_to_hints.keys())
375+
topics = "\n".join([f"{i}. {h}" for i, h in enumerate(hint_topics)])
376+
prompt = self.llm_prompt.format(goal=goal, topics=topics)
377+
response = llm(APIPayload(messages=[llm.msg.user().add_text(prompt)]))
378+
try:
379+
hint_topic_idx = json.loads(response.think)
380+
if hint_topic_idx < 0 or hint_topic_idx >= len(hint_topics):
381+
logger.error(f"Wrong LLM hint id response: {response.think}, no hints")
382+
return []
383+
hint_topic = hint_topics[hint_topic_idx]
384+
hint_indices = topic_to_hints[hint_topic]
385+
df = self.hint_db.iloc[hint_indices].copy()
386+
df = df.drop_duplicates(subset=["hint"], keep="first") # leave only unique hints
387+
hints = df["hint"].tolist()
388+
logger.debug(f"LLM hint topic {hint_topic_idx}, chosen hints: {df['hint'].tolist()}")
389+
except json.JSONDecodeError:
390+
logger.error(f"Failed to parse LLM hint id response: {response.think}, no hints")
391+
hints = []
392+
return hints
393+
394+
def choose_hints_emb(self, goal: str) -> list[str]:
395+
"""Choose hints using embeddings to filter the hints."""
396+
goal_embeddings = self.emb_model.encode([goal], prompt="task description")
397+
similarities = self.emb_model.similarity(goal_embeddings, self.hint_embeddings)
398+
top_indices = similarities.argsort()[0][-self.top_n :].tolist()
399+
logger.info(f"Top hint indices based on embedding similarity: {top_indices}")
400+
hints = self.uniq_hints.iloc[top_indices]
401+
logger.info(f"Embedding-based hints chosen: {hints}")
402+
return hints["hint"].tolist()
403+
404+
def choose_hints_direct(self, task_name: str) -> list[str]:
405+
hints = self.hint_db[
406+
self.hint_db["task_name"].apply(lambda x: fnmatch.fnmatch(x, task_name))
407+
]
408+
return hints["hint"].tolist()
409+
330410

331411
@dataclass
332412
class PromptConfig:
@@ -510,6 +590,15 @@ def get_action(self, obs: Any) -> float:
510590
vision_support=True,
511591
)
512592

593+
GPT_4_1_CC_API = OpenAIChatModelArgs(
594+
model_name="gpt-4.1",
595+
max_total_tokens=200_000,
596+
max_input_tokens=200_000,
597+
max_new_tokens=2_000,
598+
temperature=0.1,
599+
vision_support=True,
600+
)
601+
513602
GPT_4_1_MINI = OpenAIResponseModelArgs(
514603
model_name="gpt-4.1-mini",
515604
max_total_tokens=200_000,
@@ -528,7 +617,7 @@ def get_action(self, obs: Any) -> float:
528617
vision_support=True,
529618
)
530619

531-
CLAUDE_MODEL_CONFIG = ClaudeResponseModelArgs(
620+
CLAUDE_SONNET_37 = ClaudeResponseModelArgs(
532621
model_name="claude-3-7-sonnet-20250219",
533622
max_total_tokens=200_000,
534623
max_input_tokens=200_000,
@@ -537,6 +626,15 @@ def get_action(self, obs: Any) -> float:
537626
vision_support=True,
538627
)
539628

629+
CLAUDE_SONNET_4 = ClaudeResponseModelArgs(
630+
model_name="claude-sonnet-4-20250514",
631+
max_total_tokens=200_000,
632+
max_input_tokens=200_000,
633+
max_new_tokens=2_000,
634+
temperature=0.1,
635+
vision_support=True,
636+
)
637+
540638
O3_RESPONSE_MODEL = OpenAIResponseModelArgs(
541639
model_name="o3-2025-04-16",
542640
max_total_tokens=200_000,
@@ -554,6 +652,25 @@ def get_action(self, obs: Any) -> float:
554652
vision_support=True,
555653
)
556654

655+
GPT_5 = OpenAIChatModelArgs(
656+
model_name="gpt-5",
657+
max_total_tokens=200_000,
658+
max_input_tokens=200_000,
659+
max_new_tokens=2_000,
660+
temperature=None,
661+
vision_support=True,
662+
)
663+
664+
665+
GPT_5_MINI = OpenAIChatModelArgs(
666+
model_name="gpt-5-mini-2025-08-07",
667+
max_total_tokens=200_000,
668+
max_input_tokens=200_000,
669+
max_new_tokens=2_000,
670+
temperature=1.0,
671+
vision_support=True,
672+
)
673+
557674
GPT4_1_OPENROUTER_MODEL = OpenRouterModelArgs(
558675
model_name="openai/gpt-4.1",
559676
max_total_tokens=200_000,
@@ -580,12 +697,12 @@ def get_action(self, obs: Any) -> float:
580697
keep_last_n_obs=None,
581698
multiaction=True, # whether to use multi-action or not
582699
# action_subsets=("bid",),
583-
action_subsets=("coord"),
700+
action_subsets=("coord",),
584701
# action_subsets=("coord", "bid"),
585702
)
586703

587704
AGENT_CONFIG = ToolUseAgentArgs(
588-
model_args=CLAUDE_MODEL_CONFIG,
705+
model_args=CLAUDE_SONNET_37,
589706
config=DEFAULT_PROMPT_CONFIG,
590707
)
591708

@@ -605,7 +722,7 @@ def get_action(self, obs: Any) -> float:
605722
)
606723

607724
OSWORLD_CLAUDE = ToolUseAgentArgs(
608-
model_args=CLAUDE_MODEL_CONFIG,
725+
model_args=CLAUDE_SONNET_37,
609726
config=PromptConfig(
610727
tag_screenshot=True,
611728
goal=Goal(goal_as_system_msg=True),

src/agentlab/analyze/agent_xray.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -735,7 +735,7 @@ def dict_msg_to_markdown(d: dict):
735735
case _:
736736
parts.append(f"\n```\n{str(item)}\n```\n")
737737

738-
markdown = f"### {d["role"].capitalize()}\n"
738+
markdown = f"### {d['role'].capitalize()}\n"
739739
markdown += "\n".join(parts)
740740
return markdown
741741

src/agentlab/llm/tracking.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -178,9 +178,9 @@ def __call__(self, *args, **kwargs):
178178
# 'self' here calls ._call_api() method of the subclass
179179
response = self._call_api(*args, **kwargs)
180180
usage = dict(getattr(response, "usage", {}))
181-
if "prompt_tokens_details" in usage:
181+
if "prompt_tokens_details" in usage and usage["prompt_tokens_details"]:
182182
usage["cached_tokens"] = usage["prompt_tokens_details"].cached_tokens
183-
if "input_tokens_details" in usage:
183+
if "input_tokens_details" in usage and usage["input_tokens_details"]:
184184
usage["cached_tokens"] = usage["input_tokens_details"].cached_tokens
185185
usage = {f"usage_{k}": v for k, v in usage.items() if isinstance(v, (int, float))}
186186
usage |= {"n_api_calls": 1}
@@ -332,12 +332,16 @@ def get_effective_cost_from_openai_api(self, response) -> float:
332332
if api_type == "chatcompletion":
333333
total_input_tokens = usage.prompt_tokens # (cache read tokens + new input tokens)
334334
output_tokens = usage.completion_tokens
335-
cached_input_tokens = usage.prompt_tokens_details.cached_tokens
335+
cached_input_tokens = (
336+
usage.prompt_tokens_details.cached_tokens if usage.prompt_tokens_details else 0
337+
)
336338
new_input_tokens = total_input_tokens - cached_input_tokens
337339
elif api_type == "response":
338340
total_input_tokens = usage.input_tokens # (cache read tokens + new input tokens)
339341
output_tokens = usage.output_tokens
340-
cached_input_tokens = usage.input_tokens_details.cached_tokens
342+
cached_input_tokens = (
343+
usage.input_tokens_details.cached_tokens if usage.input_tokens_details else 0
344+
)
341345
new_input_tokens = total_input_tokens - cached_input_tokens
342346
else:
343347
logging.warning(f"Unsupported API type: {api_type}. Defaulting cost to 0.0.")

0 commit comments

Comments
 (0)