@@ -1,6 +1,9 @@
 import fnmatch
 import json
 import logging
+import os
+import random
+import time
 from abc import ABC, abstractmethod
 from collections import defaultdict
 from copy import copy
@@ -9,7 +12,9 @@
 from typing import Any, Literal
 
 import bgym
+import numpy as np
 import pandas as pd
+import requests
 from bgym import Benchmark as BgymBenchmark
 from browsergym.core.observation import extract_screenshot
 from browsergym.utils.obs import (
@@ -18,7 +23,6 @@
     overlay_som,
     prune_html,
 )
-from sentence_transformers import SentenceTransformer
 
 from agentlab.agents.agent_args import AgentArgs
 from agentlab.benchmarks.abstract_env import AbstractBenchmark as AgentLabBenchmark
@@ -181,7 +185,6 @@ class Obs(Block):
     def apply(
         self, llm, discussion: StructuredDiscussion, obs: dict, last_llm_output: LLMOutput
     ) -> dict:
-
         obs_msg = llm.msg.user()
         tool_calls = last_llm_output.tool_calls
         if self.use_last_error:
@@ -306,6 +309,7 @@ class TaskHint(Block):
     hint_retrieval_mode: Literal["direct", "llm", "emb"] = "direct"
     top_n: int = 4  # Number of top hints to return when using embedding retrieval
     embedder_model: str = "Qwen/Qwen3-Embedding-0.6B"  # Model for embedding hints
+    embedder_server: str = "http://localhost:5000"
     llm_prompt: str = """We're choosing hints to help solve the following task:\n{goal}.\n
 You need to choose the most relevant hints topic from the following list:\n\nHint topics:\n{topics}\n
 Choose hint topic for the task and return only its number, e.g. 1. If you don't know the answer, return -1."""
@@ -318,20 +322,26 @@ def _init(self):
         hint_db_path = Path(__file__).parent / self.hint_db_rel_path
         self.hint_db = pd.read_csv(hint_db_path, header=0, index_col=None, dtype=str)
         if self.hint_retrieval_mode == "emb":
-            logger.info("Load sentence transformer model for hint embeddings.")
-            self.emb_model = SentenceTransformer(
-                "Qwen/Qwen3-Embedding-0.6B", model_kwargs={"torch_dtype": "bfloat16"}
-            )
             self.encode_hints()
 
+    def oai_embed(self, text: str):
+        response = self._oai_emb.create(input=text, model="text-embedding-3-small")
+        return response.data[0].embedding
+
     def encode_hints(self):
         self.uniq_hints = self.hint_db.drop_duplicates(subset=["hint"], keep="first")
         logger.info(
-            f"Encoding {len(self.uniq_hints)} unique hints using {self.embedder_model} model."
-        )
-        self.hint_embeddings = self.emb_model.encode(
-            self.uniq_hints["hint"].tolist(), prompt="task hint"
+            f"Encoding {len(self.uniq_hints)} unique hints with semantic keys using {self.embedder_model} model."
         )
+        hints = self.uniq_hints["hint"].tolist()
+        semantic_keys = self.uniq_hints["semantic_keys"].tolist()
+        lines = [f"{k}: {h}" for h, k in zip(hints, semantic_keys)]
+        emb_path = f"{self.hint_db_rel_path}.embs.npy"
+        assert os.path.exists(emb_path), f"Embedding file not found: {emb_path}"
+        logger.info(f"Loading hint embeddings from: {emb_path}")
+        emb_dict = np.load(emb_path, allow_pickle=True).item()
+        self.hint_embeddings = np.array([emb_dict[k] for k in lines])
+        logger.info(f"Loaded hint embeddings shape: {self.hint_embeddings.shape}")
 
     def apply(self, llm, discussion: StructuredDiscussion, task_name: str) -> dict:
         if not self.use_task_hint:
@@ -393,14 +403,50 @@ def choose_hints_llm(self, llm, goal: str) -> list[str]:
 
     def choose_hints_emb(self, goal: str) -> list[str]:
         """Choose hints using embeddings to filter the hints."""
-        goal_embeddings = self.emb_model.encode([goal], prompt="task description")
-        similarities = self.emb_model.similarity(goal_embeddings, self.hint_embeddings)
+        goal_embeddings = self._encode([goal], prompt="task description")
+        similarities = self._similarity(goal_embeddings.tolist(), self.hint_embeddings.tolist())
         top_indices = similarities.argsort()[0][-self.top_n :].tolist()
         logger.info(f"Top hint indices based on embedding similarity: {top_indices}")
         hints = self.uniq_hints.iloc[top_indices]
         logger.info(f"Embedding-based hints chosen: {hints}")
         return hints["hint"].tolist()
 
+    def _encode(self, texts: list[str], prompt: str = "", timeout: int = 10, max_retries: int = 5):
+        """Call the encode API endpoint with timeout and retries"""
+        for attempt in range(max_retries):
+            try:
+                response = requests.post(
+                    f"{self.embedder_server}/encode",
+                    json={"texts": texts, "prompt": prompt},
+                    timeout=timeout,
+                )
+                embs = response.json()["embeddings"]
+                return np.asarray(embs)
+            except (requests.exceptions.RequestException, requests.exceptions.Timeout) as e:
+                if attempt == max_retries - 1:
+                    raise e
+                time.sleep(random.uniform(1, timeout))
+                continue
+
+    def _similarity(
+        self, texts1: list[str], texts2: list[str], timeout: int = 2, max_retries: int = 5
+    ):
+        """Call the similarity API endpoint with timeout and retries"""
+        for attempt in range(max_retries):
+            try:
+                response = requests.post(
+                    f"{self.embedder_server}/similarity",
+                    json={"texts1": texts1, "texts2": texts2},
+                    timeout=timeout,
+                )
+                similarities = response.json()["similarities"]
+                return np.asarray(similarities)
+            except (requests.exceptions.RequestException, requests.exceptions.Timeout) as e:
+                if attempt == max_retries - 1:
+                    raise e
+                time.sleep(random.uniform(1, timeout))
+                continue
+
     def choose_hints_direct(self, task_name: str) -> list[str]:
         hints = self.hint_db[
             self.hint_db["task_name"].apply(lambda x: fnmatch.fnmatch(x, task_name))
@@ -466,7 +512,8 @@ def __init__(
         self.model_args = model_args
         self.config = config
         self.action_set: bgym.AbstractActionSet = action_set or bgym.HighLevelActionSet(
-            self.config.action_subsets, multiaction=self.config.multiaction  # type: ignore
+            self.config.action_subsets,
+            multiaction=self.config.multiaction,  # type: ignore
         )
         self.tools = self.action_set.to_tool_description(api=model_args.api)
 
@@ -656,7 +703,7 @@ def get_action(self, obs: Any) -> float:
     model_name="gpt-5",
     max_total_tokens=200_000,
     max_input_tokens=200_000,
-    max_new_tokens=2_000,
+    max_new_tokens=8_000,
     temperature=None,
     vision_support=True,
 )
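
The `_encode` and `_similarity` helpers added in the TaskHint hunk above talk to an external embedding service at `embedder_server` (default `http://localhost:5000`) exposing `/encode` and `/similarity` endpoints. That service is not part of this diff. Below is a minimal sketch of what it could look like, assuming Flask and a sentence-transformers release that provides `SentenceTransformer.similarity` (>= 3.0); only the endpoint paths, JSON field names, model name, and dtype are taken from the diff, everything else is an assumption.

# Hypothetical embedding server matching the /encode and /similarity calls above.
# Not part of this change; a sketch assuming Flask and sentence-transformers >= 3.0.
import numpy as np
from flask import Flask, jsonify, request
from sentence_transformers import SentenceTransformer

app = Flask(__name__)
# Same model and dtype as the SentenceTransformer setup removed from _init.
model = SentenceTransformer("Qwen/Qwen3-Embedding-0.6B", model_kwargs={"torch_dtype": "bfloat16"})


@app.post("/encode")
def encode():
    payload = request.get_json()
    # "texts" and "prompt" mirror the JSON body sent by TaskHint._encode.
    embeddings = model.encode(payload["texts"], prompt=payload.get("prompt", ""))
    return jsonify({"embeddings": np.asarray(embeddings).tolist()})


@app.post("/similarity")
def similarity():
    payload = request.get_json()
    # TaskHint._similarity sends embeddings (lists of floats) despite the "texts" names.
    sims = model.similarity(np.asarray(payload["texts1"]), np.asarray(payload["texts2"]))
    return jsonify({"similarities": sims.tolist()})


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=5000)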
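
The reworked `encode_hints()` also expects a precomputed embedding file at `{hint_db_rel_path}.embs.npy`: a pickled dict keyed by the same `"semantic_key: hint"` strings it builds. The script that produces that file is not included in the diff; here is a rough sketch under stated assumptions (hypothetical paths, the `/encode` endpoint sketched above, and a `"task hint"` prompt carried over from the removed in-process encoder).

# Hypothetical one-off script to build the .embs.npy file that encode_hints() loads.
# Paths and the "task hint" prompt are assumptions; the key format mirrors encode_hints().
import numpy as np
import pandas as pd
import requests

HINT_DB_PATH = "hint_db.csv"               # placeholder for hint_db_rel_path
EMBEDDER_SERVER = "http://localhost:5000"  # default embedder_server in this diff

hint_db = pd.read_csv(HINT_DB_PATH, header=0, index_col=None, dtype=str)
uniq_hints = hint_db.drop_duplicates(subset=["hint"], keep="first")
lines = [f"{k}: {h}" for h, k in zip(uniq_hints["hint"], uniq_hints["semantic_keys"])]

response = requests.post(
    f"{EMBEDDER_SERVER}/encode",
    json={"texts": lines, "prompt": "task hint"},
    timeout=120,
)
embeddings = np.asarray(response.json()["embeddings"])

# encode_hints() reads this back with np.load(..., allow_pickle=True).item()
np.save(f"{HINT_DB_PATH}.embs.npy", dict(zip(lines, embeddings)))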