11import fnmatch
22import json
3+ import logging
4+ import os
5+ import random
6+ import time
37from abc import ABC , abstractmethod
8+ from collections import defaultdict
49from copy import copy
510from dataclasses import asdict , dataclass , field
611from pathlib import Path
7- from typing import Any
12+ from typing import Any , Literal
813
914import bgym
15+ import numpy as np
1016import pandas as pd
17+ import requests
1118from bgym import Benchmark as BgymBenchmark
1219from browsergym .core .observation import extract_screenshot
1320from browsergym .utils .obs import (
3441)
3542from agentlab .llm .tracking import cost_tracker_decorator
3643
44+ logger = logging .getLogger (__name__ )
45+
3746
3847@dataclass
3948class Block (ABC ):
@@ -176,7 +185,6 @@ class Obs(Block):
176185 def apply (
177186 self , llm , discussion : StructuredDiscussion , obs : dict , last_llm_output : LLMOutput
178187 ) -> dict :
179-
180188 obs_msg = llm .msg .user ()
181189 tool_calls = last_llm_output .tool_calls
182190 if self .use_last_error :
@@ -298,22 +306,52 @@ def apply_init(self, llm, discussion: StructuredDiscussion) -> dict:
298306class TaskHint (Block ):
299307 use_task_hint : bool = True
300308 hint_db_rel_path : str = "hint_db.csv"
309+ hint_retrieval_mode : Literal ["direct" , "llm" , "emb" ] = "direct"
310+ top_n : int = 4 # Number of top hints to return when using embedding retrieval
311+ embedder_model : str = "Qwen/Qwen3-Embedding-0.6B" # Model for embedding hints
312+ embedder_server : str = "http://localhost:5000"
313+ llm_prompt : str = """We're choosing hints to help solve the following task:\n {goal}.\n
314+ You need to choose the most relevant hints topic from the following list:\n \n Hint topics:\n {topics}\n
315+ Choose hint topic for the task and return only its number, e.g. 1. If you don't know the answer, return -1."""
301316
302317 def _init (self ):
303318 """Initialize the block."""
304- hint_db_path = Path (__file__ ).parent / self .hint_db_rel_path
319+ if Path (self .hint_db_rel_path ).is_absolute ():
320+ hint_db_path = Path (self .hint_db_rel_path )
321+ else :
322+ hint_db_path = Path (__file__ ).parent / self .hint_db_rel_path
305323 self .hint_db = pd .read_csv (hint_db_path , header = 0 , index_col = None , dtype = str )
324+ if self .hint_retrieval_mode == "emb" :
325+ self .encode_hints ()
326+
327+ def oai_embed (self , text : str ):
328+ response = self ._oai_emb .create (input = text , model = "text-embedding-3-small" )
329+ return response .data [0 ].embedding
330+
331+ def encode_hints (self ):
332+ self .uniq_hints = self .hint_db .drop_duplicates (subset = ["hint" ], keep = "first" )
333+ logger .info (
334+ f"Encoding { len (self .uniq_hints )} unique hints with semantic keys using { self .embedder_model } model."
335+ )
336+ hints = self .uniq_hints ["hint" ].tolist ()
337+ semantic_keys = self .uniq_hints ["semantic_keys" ].tolist ()
338+ lines = [f"{ k } : { h } " for h , k in zip (hints , semantic_keys )]
339+ emb_path = f"{ self .hint_db_rel_path } .embs.npy"
340+ assert os .path .exists (emb_path ), f"Embedding file not found: { emb_path } "
341+ logger .info (f"Loading hint embeddings from: { emb_path } " )
342+ emb_dict = np .load (emb_path , allow_pickle = True ).item ()
343+ self .hint_embeddings = np .array ([emb_dict [k ] for k in lines ])
344+ logger .info (f"Loaded hint embeddings shape: { self .hint_embeddings .shape } " )
306345
307346 def apply (self , llm , discussion : StructuredDiscussion , task_name : str ) -> dict :
308347 if not self .use_task_hint :
309- return
348+ return {}
310349
311- task_hints = self .hint_db [
312- self .hint_db ["task_name" ].apply (lambda x : fnmatch .fnmatch (x , task_name ))
313- ]
350+ goal = "\n " .join ([c .get ("text" , "" ) for c in discussion .groups [0 ].messages [1 ].content ])
351+ task_hints = self .choose_hints (llm , task_name , goal )
314352
315353 hints = []
316- for hint in task_hints [ "hint" ] :
354+ for hint in task_hints :
317355 hint = hint .strip ()
318356 if hint :
319357 hints .append (f"- { hint } " )
@@ -327,6 +365,94 @@ def apply(self, llm, discussion: StructuredDiscussion, task_name: str) -> dict:
327365
328366 discussion .append (msg )
329367
368+ def choose_hints (self , llm , task_name : str , goal : str ) -> list [str ]:
369+ """Choose hints based on the task name."""
370+ if self .hint_retrieval_mode == "llm" :
371+ return self .choose_hints_llm (llm , goal )
372+ elif self .hint_retrieval_mode == "direct" :
373+ return self .choose_hints_direct (task_name )
374+ elif self .hint_retrieval_mode == "emb" :
375+ return self .choose_hints_emb (goal )
376+ else :
377+ raise ValueError (f"Unknown hint retrieval mode: { self .hint_retrieval_mode } " )
378+
379+ def choose_hints_llm (self , llm , goal : str ) -> list [str ]:
380+ """Choose hints using LLM to filter the hints."""
381+ topic_to_hints = defaultdict (list )
382+ for i , row in self .hint_db .iterrows ():
383+ topic_to_hints [row ["semantic_keys" ]].append (i )
384+ hint_topics = list (topic_to_hints .keys ())
385+ topics = "\n " .join ([f"{ i } . { h } " for i , h in enumerate (hint_topics )])
386+ prompt = self .llm_prompt .format (goal = goal , topics = topics )
387+ response = llm (APIPayload (messages = [llm .msg .user ().add_text (prompt )]))
388+ try :
389+ hint_topic_idx = json .loads (response .think )
390+ if hint_topic_idx < 0 or hint_topic_idx >= len (hint_topics ):
391+ logger .error (f"Wrong LLM hint id response: { response .think } , no hints" )
392+ return []
393+ hint_topic = hint_topics [hint_topic_idx ]
394+ hint_indices = topic_to_hints [hint_topic ]
395+ df = self .hint_db .iloc [hint_indices ].copy ()
396+ df = df .drop_duplicates (subset = ["hint" ], keep = "first" ) # leave only unique hints
397+ hints = df ["hint" ].tolist ()
398+ logger .debug (f"LLM hint topic { hint_topic_idx } , chosen hints: { df ['hint' ].tolist ()} " )
399+ except json .JSONDecodeError :
400+ logger .error (f"Failed to parse LLM hint id response: { response .think } , no hints" )
401+ hints = []
402+ return hints
403+
404+ def choose_hints_emb (self , goal : str ) -> list [str ]:
405+ """Choose hints using embeddings to filter the hints."""
406+ goal_embeddings = self ._encode ([goal ], prompt = "task description" )
407+ similarities = self ._similarity (goal_embeddings .tolist (), self .hint_embeddings .tolist ())
408+ top_indices = similarities .argsort ()[0 ][- self .top_n :].tolist ()
409+ logger .info (f"Top hint indices based on embedding similarity: { top_indices } " )
410+ hints = self .uniq_hints .iloc [top_indices ]
411+ logger .info (f"Embedding-based hints chosen: { hints } " )
412+ return hints ["hint" ].tolist ()
413+
414+ def _encode (self , texts : list [str ], prompt : str = "" , timeout : int = 10 , max_retries : int = 5 ):
415+ """Call the encode API endpoint with timeout and retries"""
416+ for attempt in range (max_retries ):
417+ try :
418+ response = requests .post (
419+ f"{ self .embedder_server } /encode" ,
420+ json = {"texts" : texts , "prompt" : prompt },
421+ timeout = timeout ,
422+ )
423+ embs = response .json ()["embeddings" ]
424+ return np .asarray (embs )
425+ except (requests .exceptions .RequestException , requests .exceptions .Timeout ) as e :
426+ if attempt == max_retries - 1 :
427+ raise e
428+ time .sleep (random .uniform (1 , timeout ))
429+ continue
430+
431+ def _similarity (
432+ self , texts1 : list [str ], texts2 : list [str ], timeout : int = 2 , max_retries : int = 5
433+ ):
434+ """Call the similarity API endpoint with timeout and retries"""
435+ for attempt in range (max_retries ):
436+ try :
437+ response = requests .post (
438+ f"{ self .embedder_server } /similarity" ,
439+ json = {"texts1" : texts1 , "texts2" : texts2 },
440+ timeout = timeout ,
441+ )
442+ similarities = response .json ()["similarities" ]
443+ return np .asarray (similarities )
444+ except (requests .exceptions .RequestException , requests .exceptions .Timeout ) as e :
445+ if attempt == max_retries - 1 :
446+ raise e
447+ time .sleep (random .uniform (1 , timeout ))
448+ continue
449+
450+ def choose_hints_direct (self , task_name : str ) -> list [str ]:
451+ hints = self .hint_db [
452+ self .hint_db ["task_name" ].apply (lambda x : fnmatch .fnmatch (x , task_name ))
453+ ]
454+ return hints ["hint" ].tolist ()
455+
330456
331457@dataclass
332458class PromptConfig :
@@ -386,7 +512,8 @@ def __init__(
386512 self .model_args = model_args
387513 self .config = config
388514 self .action_set : bgym .AbstractActionSet = action_set or bgym .HighLevelActionSet (
389- self .config .action_subsets , multiaction = self .config .multiaction # type: ignore
515+ self .config .action_subsets ,
516+ multiaction = self .config .multiaction , # type: ignore
390517 )
391518 self .tools = self .action_set .to_tool_description (api = model_args .api )
392519
@@ -510,6 +637,15 @@ def get_action(self, obs: Any) -> float:
510637 vision_support = True ,
511638)
512639
640+ GPT_4_1_CC_API = OpenAIChatModelArgs (
641+ model_name = "gpt-4.1" ,
642+ max_total_tokens = 200_000 ,
643+ max_input_tokens = 200_000 ,
644+ max_new_tokens = 2_000 ,
645+ temperature = 0.1 ,
646+ vision_support = True ,
647+ )
648+
513649GPT_5_mini = OpenAIChatModelArgs (
514650 model_name = "gpt-5-mini-2025-08-07" ,
515651 max_total_tokens = 400_000 ,
@@ -548,7 +684,7 @@ def get_action(self, obs: Any) -> float:
548684 vision_support = True ,
549685)
550686
551- CLAUDE_MODEL_CONFIG = ClaudeResponseModelArgs (
687+ CLAUDE_SONNET_37 = ClaudeResponseModelArgs (
552688 model_name = "claude-3-7-sonnet-20250219" ,
553689 max_total_tokens = 200_000 ,
554690 max_input_tokens = 200_000 ,
@@ -557,6 +693,15 @@ def get_action(self, obs: Any) -> float:
557693 vision_support = True ,
558694)
559695
696+ CLAUDE_SONNET_4 = ClaudeResponseModelArgs (
697+ model_name = "claude-sonnet-4-20250514" ,
698+ max_total_tokens = 200_000 ,
699+ max_input_tokens = 200_000 ,
700+ max_new_tokens = 2_000 ,
701+ temperature = 0.1 ,
702+ vision_support = True ,
703+ )
704+
560705O3_RESPONSE_MODEL = OpenAIResponseModelArgs (
561706 model_name = "o3-2025-04-16" ,
562707 max_total_tokens = 200_000 ,
@@ -574,6 +719,25 @@ def get_action(self, obs: Any) -> float:
574719 vision_support = True ,
575720)
576721
722+ GPT_5 = OpenAIChatModelArgs (
723+ model_name = "gpt-5" ,
724+ max_total_tokens = 200_000 ,
725+ max_input_tokens = 200_000 ,
726+ max_new_tokens = 8_000 ,
727+ temperature = None ,
728+ vision_support = True ,
729+ )
730+
731+
732+ GPT_5_MINI = OpenAIChatModelArgs (
733+ model_name = "gpt-5-mini-2025-08-07" ,
734+ max_total_tokens = 200_000 ,
735+ max_input_tokens = 200_000 ,
736+ max_new_tokens = 2_000 ,
737+ temperature = 1.0 ,
738+ vision_support = True ,
739+ )
740+
577741GPT4_1_OPENROUTER_MODEL = OpenRouterModelArgs (
578742 model_name = "openai/gpt-4.1" ,
579743 max_total_tokens = 200_000 ,
@@ -600,12 +764,12 @@ def get_action(self, obs: Any) -> float:
600764 keep_last_n_obs = None ,
601765 multiaction = False , # whether to use multi-action or not
602766 # action_subsets=("bid",),
603- action_subsets = ("coord" ),
767+ action_subsets = ("coord" , ),
604768 # action_subsets=("coord", "bid"),
605769)
606770
607771AGENT_CONFIG = ToolUseAgentArgs (
608- model_args = CLAUDE_MODEL_CONFIG ,
772+ model_args = CLAUDE_SONNET_37 ,
609773 config = DEFAULT_PROMPT_CONFIG ,
610774)
611775
@@ -633,7 +797,7 @@ def get_action(self, obs: Any) -> float:
633797)
634798
635799OSWORLD_CLAUDE = ToolUseAgentArgs (
636- model_args = CLAUDE_MODEL_CONFIG ,
800+ model_args = CLAUDE_SONNET_37 ,
637801 config = PromptConfig (
638802 tag_screenshot = True ,
639803 goal = Goal (goal_as_system_msg = True ),
0 commit comments