11import fnmatch
22import json
3+ import logging
34from abc import ABC , abstractmethod
5+ from collections import defaultdict
46from copy import copy
57from dataclasses import asdict , dataclass , field
68from pathlib import Path
7- from typing import Any
9+ from typing import Any , Literal
810
911import bgym
1012import pandas as pd
1618 overlay_som ,
1719 prune_html ,
1820)
21+ from sentence_transformers import SentenceTransformer
1922
2023from agentlab .agents .agent_args import AgentArgs
2124from agentlab .benchmarks .abstract_env import AbstractBenchmark as AgentLabBenchmark
3437)
3538from agentlab .llm .tracking import cost_tracker_decorator
3639
40+ logger = logging .getLogger (__name__ )
41+
3742
3843@dataclass
3944class Block (ABC ):
@@ -298,22 +303,45 @@ def apply_init(self, llm, discussion: StructuredDiscussion) -> dict:
298303class TaskHint (Block ):
299304 use_task_hint : bool = True
300305 hint_db_rel_path : str = "hint_db.csv"
306+ hint_retrieval_mode : Literal ["direct" , "llm" , "emb" ] = "direct"
307+ top_n : int = 4 # Number of top hints to return when using embedding retrieval
308+ embedder_model : str = "Qwen/Qwen3-Embedding-0.6B" # Model for embedding hints
309+ llm_prompt : str = """We're choosing hints to help solve the following task:\n {goal}.\n
310+ You need to choose the most relevant hints topic from the following list:\n \n Hint topics:\n {topics}\n
311+ Choose hint topic for the task and return only its number, e.g. 1. If you don't know the answer, return -1."""
301312
302313 def _init (self ):
303314 """Initialize the block."""
304- hint_db_path = Path (__file__ ).parent / self .hint_db_rel_path
315+ if Path (self .hint_db_rel_path ).is_absolute ():
316+ hint_db_path = Path (self .hint_db_rel_path )
317+ else :
318+ hint_db_path = Path (__file__ ).parent / self .hint_db_rel_path
305319 self .hint_db = pd .read_csv (hint_db_path , header = 0 , index_col = None , dtype = str )
320+ if self .hint_retrieval_mode == "emb" :
321+ logger .info ("Load sentence transformer model for hint embeddings." )
322+ self .emb_model = SentenceTransformer (
323+ "Qwen/Qwen3-Embedding-0.6B" , model_kwargs = {"torch_dtype" : "bfloat16" }
324+ )
325+ self .encode_hints ()
326+
327+ def encode_hints (self ):
328+ self .uniq_hints = self .hint_db .drop_duplicates (subset = ["hint" ], keep = "first" )
329+ logger .info (
330+ f"Encoding { len (self .uniq_hints )} unique hints using { self .embedder_model } model."
331+ )
332+ self .hint_embeddings = self .emb_model .encode (
333+ self .uniq_hints ["hint" ].tolist (), prompt = "task hint"
334+ )
306335
307336 def apply (self , llm , discussion : StructuredDiscussion , task_name : str ) -> dict :
308337 if not self .use_task_hint :
309- return
338+ return {}
310339
311- task_hints = self .hint_db [
312- self .hint_db ["task_name" ].apply (lambda x : fnmatch .fnmatch (x , task_name ))
313- ]
340+ goal = "\n " .join ([c .get ("text" , "" ) for c in discussion .groups [0 ].messages [1 ].content ])
341+ task_hints = self .choose_hints (llm , task_name , goal )
314342
315343 hints = []
316- for hint in task_hints [ "hint" ] :
344+ for hint in task_hints :
317345 hint = hint .strip ()
318346 if hint :
319347 hints .append (f"- { hint } " )
@@ -327,6 +355,58 @@ def apply(self, llm, discussion: StructuredDiscussion, task_name: str) -> dict:
327355
328356 discussion .append (msg )
329357
358+ def choose_hints (self , llm , task_name : str , goal : str ) -> list [str ]:
359+ """Choose hints based on the task name."""
360+ if self .hint_retrieval_mode == "llm" :
361+ return self .choose_hints_llm (llm , goal )
362+ elif self .hint_retrieval_mode == "direct" :
363+ return self .choose_hints_direct (task_name )
364+ elif self .hint_retrieval_mode == "emb" :
365+ return self .choose_hints_emb (goal )
366+ else :
367+ raise ValueError (f"Unknown hint retrieval mode: { self .hint_retrieval_mode } " )
368+
369+ def choose_hints_llm (self , llm , goal : str ) -> list [str ]:
370+ """Choose hints using LLM to filter the hints."""
371+ topic_to_hints = defaultdict (list )
372+ for i , row in self .hint_db .iterrows ():
373+ topic_to_hints [row ["semantic_keys" ]].append (i )
374+ hint_topics = list (topic_to_hints .keys ())
375+ topics = "\n " .join ([f"{ i } . { h } " for i , h in enumerate (hint_topics )])
376+ prompt = self .llm_prompt .format (goal = goal , topics = topics )
377+ response = llm (APIPayload (messages = [llm .msg .user ().add_text (prompt )]))
378+ try :
379+ hint_topic_idx = json .loads (response .think )
380+ if hint_topic_idx < 0 or hint_topic_idx >= len (hint_topics ):
381+ logger .error (f"Wrong LLM hint id response: { response .think } , no hints" )
382+ return []
383+ hint_topic = hint_topics [hint_topic_idx ]
384+ hint_indices = topic_to_hints [hint_topic ]
385+ df = self .hint_db .iloc [hint_indices ].copy ()
386+ df = df .drop_duplicates (subset = ["hint" ], keep = "first" ) # leave only unique hints
387+ hints = df ["hint" ].tolist ()
388+ logger .debug (f"LLM hint topic { hint_topic_idx } , chosen hints: { df ['hint' ].tolist ()} " )
389+ except json .JSONDecodeError :
390+ logger .error (f"Failed to parse LLM hint id response: { response .think } , no hints" )
391+ hints = []
392+ return hints
393+
394+ def choose_hints_emb (self , goal : str ) -> list [str ]:
395+ """Choose hints using embeddings to filter the hints."""
396+ goal_embeddings = self .emb_model .encode ([goal ], prompt = "task description" )
397+ similarities = self .emb_model .similarity (goal_embeddings , self .hint_embeddings )
398+ top_indices = similarities .argsort ()[0 ][- self .top_n :].tolist ()
399+ logger .info (f"Top hint indices based on embedding similarity: { top_indices } " )
400+ hints = self .uniq_hints .iloc [top_indices ]
401+ logger .info (f"Embedding-based hints chosen: { hints } " )
402+ return hints ["hint" ].tolist ()
403+
404+ def choose_hints_direct (self , task_name : str ) -> list [str ]:
405+ hints = self .hint_db [
406+ self .hint_db ["task_name" ].apply (lambda x : fnmatch .fnmatch (x , task_name ))
407+ ]
408+ return hints ["hint" ].tolist ()
409+
330410
331411@dataclass
332412class PromptConfig :
@@ -510,6 +590,15 @@ def get_action(self, obs: Any) -> float:
510590 vision_support = True ,
511591)
512592
593+ GPT_4_1_CC_API = OpenAIChatModelArgs (
594+ model_name = "gpt-4.1" ,
595+ max_total_tokens = 200_000 ,
596+ max_input_tokens = 200_000 ,
597+ max_new_tokens = 2_000 ,
598+ temperature = 0.1 ,
599+ vision_support = True ,
600+ )
601+
513602GPT_4_1_MINI = OpenAIResponseModelArgs (
514603 model_name = "gpt-4.1-mini" ,
515604 max_total_tokens = 200_000 ,
@@ -528,7 +617,7 @@ def get_action(self, obs: Any) -> float:
528617 vision_support = True ,
529618)
530619
531- CLAUDE_MODEL_CONFIG = ClaudeResponseModelArgs (
620+ CLAUDE_SONNET_37 = ClaudeResponseModelArgs (
532621 model_name = "claude-3-7-sonnet-20250219" ,
533622 max_total_tokens = 200_000 ,
534623 max_input_tokens = 200_000 ,
@@ -537,6 +626,15 @@ def get_action(self, obs: Any) -> float:
537626 vision_support = True ,
538627)
539628
629+ CLAUDE_SONNET_4 = ClaudeResponseModelArgs (
630+ model_name = "claude-sonnet-4-20250514" ,
631+ max_total_tokens = 200_000 ,
632+ max_input_tokens = 200_000 ,
633+ max_new_tokens = 2_000 ,
634+ temperature = 0.1 ,
635+ vision_support = True ,
636+ )
637+
540638O3_RESPONSE_MODEL = OpenAIResponseModelArgs (
541639 model_name = "o3-2025-04-16" ,
542640 max_total_tokens = 200_000 ,
@@ -554,6 +652,25 @@ def get_action(self, obs: Any) -> float:
554652 vision_support = True ,
555653)
556654
655+ GPT_5 = OpenAIChatModelArgs (
656+ model_name = "gpt-5" ,
657+ max_total_tokens = 200_000 ,
658+ max_input_tokens = 200_000 ,
659+ max_new_tokens = 2_000 ,
660+ temperature = None ,
661+ vision_support = True ,
662+ )
663+
664+
665+ GPT_5_MINI = OpenAIChatModelArgs (
666+ model_name = "gpt-5-mini-2025-08-07" ,
667+ max_total_tokens = 200_000 ,
668+ max_input_tokens = 200_000 ,
669+ max_new_tokens = 2_000 ,
670+ temperature = 1.0 ,
671+ vision_support = True ,
672+ )
673+
557674GPT4_1_OPENROUTER_MODEL = OpenRouterModelArgs (
558675 model_name = "openai/gpt-4.1" ,
559676 max_total_tokens = 200_000 ,
@@ -580,12 +697,12 @@ def get_action(self, obs: Any) -> float:
580697 keep_last_n_obs = None ,
581698 multiaction = True , # whether to use multi-action or not
582699 # action_subsets=("bid",),
583- action_subsets = ("coord" ),
700+ action_subsets = ("coord" , ),
584701 # action_subsets=("coord", "bid"),
585702)
586703
587704AGENT_CONFIG = ToolUseAgentArgs (
588- model_args = CLAUDE_MODEL_CONFIG ,
705+ model_args = CLAUDE_SONNET_37 ,
589706 config = DEFAULT_PROMPT_CONFIG ,
590707)
591708
@@ -605,7 +722,7 @@ def get_action(self, obs: Any) -> float:
605722)
606723
607724OSWORLD_CLAUDE = ToolUseAgentArgs (
608- model_args = CLAUDE_MODEL_CONFIG ,
725+ model_args = CLAUDE_SONNET_37 ,
609726 config = PromptConfig (
610727 tag_screenshot = True ,
611728 goal = Goal (goal_as_system_msg = True ),
0 commit comments