2828from agentlab .benchmarks .abstract_env import AbstractBenchmark as AgentLabBenchmark
2929from agentlab .benchmarks .osworld import OSWorldActionSet
3030from agentlab .llm .base_api import BaseModelArgs
31+ from agentlab .llm .chat_api import ChatModel
3132from agentlab .llm .llm_utils import image_to_png_base64_url
3233from agentlab .llm .response_api import (
3334 APIPayload ,
@@ -316,39 +317,21 @@ class TaskHint(Block):
316317
317318 def _init (self ):
318319 """Initialize the block."""
319- if Path (self .hint_db_rel_path ).is_absolute ():
320- hint_db_path = Path (self .hint_db_rel_path )
321- else :
322- hint_db_path = Path (__file__ ).parent / self .hint_db_rel_path
323- self .hint_db = pd .read_csv (hint_db_path , header = 0 , index_col = None , dtype = str )
324- if self .hint_retrieval_mode == "emb" :
325- self .encode_hints ()
326-
327- def oai_embed (self , text : str ):
328- response = self ._oai_emb .create (input = text , model = "text-embedding-3-small" )
329- return response .data [0 ].embedding
330-
331- def encode_hints (self ):
332- self .uniq_hints = self .hint_db .drop_duplicates (subset = ["hint" ], keep = "first" )
333- logger .info (
334- f"Encoding { len (self .uniq_hints )} unique hints with semantic keys using { self .embedder_model } model."
320+ self .hints_source = HintsSource (
321+ hint_db_path = self .hint_db_rel_path ,
322+ hint_retrieval_mode = self .hint_retrieval_mode ,
323+ top_n = self .top_n ,
324+ embedder_model = self .embedder_model ,
325+ embedder_server = self .embedder_server ,
326+ llm_prompt = self .llm_prompt ,
335327 )
336- hints = self .uniq_hints ["hint" ].tolist ()
337- semantic_keys = self .uniq_hints ["semantic_keys" ].tolist ()
338- lines = [f"{ k } : { h } " for h , k in zip (hints , semantic_keys )]
339- emb_path = f"{ self .hint_db_rel_path } .embs.npy"
340- assert os .path .exists (emb_path ), f"Embedding file not found: { emb_path } "
341- logger .info (f"Loading hint embeddings from: { emb_path } " )
342- emb_dict = np .load (emb_path , allow_pickle = True ).item ()
343- self .hint_embeddings = np .array ([emb_dict [k ] for k in lines ])
344- logger .info (f"Loaded hint embeddings shape: { self .hint_embeddings .shape } " )
345328
346329 def apply (self , llm , discussion : StructuredDiscussion , task_name : str ) -> dict :
347330 if not self .use_task_hint :
348331 return {}
349332
350333 goal = "\n " .join ([c .get ("text" , "" ) for c in discussion .groups [0 ].messages [1 ].content ])
351- task_hints = self .choose_hints (llm , task_name , goal )
334+ task_hints = self .hints_source . choose_hints (llm , task_name , goal )
352335
353336 hints = []
354337 for hint in task_hints :
@@ -365,6 +348,49 @@ def apply(self, llm, discussion: StructuredDiscussion, task_name: str) -> dict:
365348
366349 discussion .append (msg )
367350
351+
352+ class HintsSource :
353+ def __init__ (
354+ self ,
355+ hint_db_path : str ,
356+ hint_retrieval_mode : Literal ["direct" , "llm" , "emb" ] = "direct" ,
357+ top_n : int = 4 ,
358+ embedder_model : str = "Qwen/Qwen3-Embedding-0.6B" ,
359+ embedder_server : str = "http://localhost:5000" ,
360+ llm_prompt : str = """We're choosing hints to help solve the following task:\n {goal}.\n
361+ You need to choose the most relevant hints topic from the following list:\n \n Hint topics:\n {topics}\n
362+ Choose hint topic for the task and return only its number, e.g. 1. If you don't know the answer, return -1.""" ,
363+ ) -> None :
364+ self .hint_db_path = hint_db_path
365+ self .hint_retrieval_mode = hint_retrieval_mode
366+ self .top_n = top_n
367+ self .embedder_model = embedder_model
368+ self .embedder_server = embedder_server
369+ self .llm_prompt = llm_prompt
370+
371+ if Path (hint_db_path ).is_absolute ():
372+ self .hint_db_path = Path (hint_db_path ).as_posix ()
373+ else :
374+ self .hint_db_path = (Path (__file__ ).parent / self .hint_db_path ).as_posix ()
375+ self .hint_db = pd .read_csv (self .hint_db_path , header = 0 , index_col = None , dtype = str )
376+ if self .hint_retrieval_mode == "emb" :
377+ self .load_hint_vectors ()
378+
379+ def load_hint_vectors (self ):
380+ self .uniq_hints = self .hint_db .drop_duplicates (subset = ["hint" ], keep = "first" )
381+ logger .info (
382+ f"Encoding { len (self .uniq_hints )} unique hints with semantic keys using { self .embedder_model } model."
383+ )
384+ hints = self .uniq_hints ["hint" ].tolist ()
385+ semantic_keys = self .uniq_hints ["semantic_keys" ].tolist ()
386+ lines = [f"{ k } : { h } " for h , k in zip (hints , semantic_keys )]
387+ emb_path = f"{ self .hint_db_path } .embs.npy"
388+ assert os .path .exists (emb_path ), f"Embedding file not found: { emb_path } "
389+ logger .info (f"Loading hint embeddings from: { emb_path } " )
390+ emb_dict = np .load (emb_path , allow_pickle = True ).item ()
391+ self .hint_embeddings = np .array ([emb_dict [k ] for k in lines ])
392+ logger .info (f"Loaded hint embeddings shape: { self .hint_embeddings .shape } " )
393+
368394 def choose_hints (self , llm , task_name : str , goal : str ) -> list [str ]:
369395 """Choose hints based on the task name."""
370396 if self .hint_retrieval_mode == "llm" :
@@ -384,11 +410,14 @@ def choose_hints_llm(self, llm, goal: str) -> list[str]:
384410 hint_topics = list (topic_to_hints .keys ())
385411 topics = "\n " .join ([f"{ i } . { h } " for i , h in enumerate (hint_topics )])
386412 prompt = self .llm_prompt .format (goal = goal , topics = topics )
387- response = llm (APIPayload (messages = [llm .msg .user ().add_text (prompt )]))
413+ if isinstance (llm , ChatModel ):
414+ response : str = llm (messages = [dict (role = "user" , content = prompt )])["content" ]
415+ else :
416+ response : str = llm (APIPayload (messages = [llm .msg .user ().add_text (prompt )])).think
388417 try :
389- hint_topic_idx = json .loads (response . think )
418+ hint_topic_idx = json .loads (response )
390419 if hint_topic_idx < 0 or hint_topic_idx >= len (hint_topics ):
391- logger .error (f"Wrong LLM hint id response: { response . think } , no hints" )
420+ logger .error (f"Wrong LLM hint id response: { response } , no hints" )
392421 return []
393422 hint_topic = hint_topics [hint_topic_idx ]
394423 hint_indices = topic_to_hints [hint_topic ]
@@ -397,7 +426,7 @@ def choose_hints_llm(self, llm, goal: str) -> list[str]:
397426 hints = df ["hint" ].tolist ()
398427 logger .debug (f"LLM hint topic { hint_topic_idx } , chosen hints: { df ['hint' ].tolist ()} " )
399428 except json .JSONDecodeError :
400- logger .error (f"Failed to parse LLM hint id response: { response . think } , no hints" )
429+ logger .error (f"Failed to parse LLM hint id response: { response } , no hints" )
401430 hints = []
402431 return hints
403432
@@ -427,6 +456,7 @@ def _encode(self, texts: list[str], prompt: str = "", timeout: int = 10, max_ret
427456 raise e
428457 time .sleep (random .uniform (1 , timeout ))
429458 continue
459+ raise ValueError ("Failed to encode hints" )
430460
431461 def _similarity (
432462 self , texts1 : list [str ], texts2 : list [str ], timeout : int = 2 , max_retries : int = 5
@@ -446,6 +476,7 @@ def _similarity(
446476 raise e
447477 time .sleep (random .uniform (1 , timeout ))
448478 continue
479+ raise ValueError ("Failed to compute similarity" )
449480
450481 def choose_hints_direct (self , task_name : str ) -> list [str ]:
451482 hints = self .hint_db [
0 commit comments