66
77import logging
88from dataclasses import dataclass
9+ from pathlib import Path
10+ from typing import Literal
911
10- from browsergym . core import action
12+ import pandas as pd
1113from browsergym .core .action .base import AbstractActionSet
1214
1315from agentlab .agents import dynamic_prompting as dp
16+ from agentlab .agents .tool_use_agent .tool_use_agent import HintsSource
17+ from agentlab .llm .chat_api import ChatModel
1418from agentlab .llm .llm_utils import HumanMessage , parse_html_tags_raise
15- import fnmatch
16- import pandas as pd
17- from pathlib import Path
1819
1920
2021@dataclass
@@ -49,6 +50,8 @@ class GenericPromptFlags(dp.Flags):
4950 use_abstract_example : bool = False
5051 use_hints : bool = False
5152 use_task_hint : bool = False
53+ task_hint_retrieval_mode : Literal ["direct" , "llm" , "emb" ] = "direct"
54+ skip_hints_for_current_task : bool = False
5255 hint_db_path : str = None
5356 enable_chat : bool = False
5457 max_prompt_tokens : int = None
@@ -70,10 +73,12 @@ def __init__(
7073 previous_plan : str ,
7174 step : int ,
7275 flags : GenericPromptFlags ,
76+ llm : ChatModel ,
7377 ) -> None :
7478 super ().__init__ ()
7579 self .flags = flags
7680 self .history = dp .History (obs_history , actions , memories , thoughts , flags .obs )
81+ goal = obs_history [- 1 ]["goal_object" ]
7782 if self .flags .enable_chat :
7883 self .instructions = dp .ChatInstructions (
7984 obs_history [- 1 ]["chat_messages" ], extra_instructions = flags .extra_instructions
@@ -84,7 +89,7 @@ def __init__(
8489 "Agent is in goal mode, but multiple user messages are present in the chat. Consider switching to `enable_chat=True`."
8590 )
8691 self .instructions = dp .GoalInstructions (
87- obs_history [ - 1 ][ "goal_object" ] , extra_instructions = flags .extra_instructions
92+ goal , extra_instructions = flags .extra_instructions
8893 )
8994
9095 self .obs = dp .Observation (
@@ -103,9 +108,14 @@ def time_for_caution():
103108 self .be_cautious = dp .BeCautious (visible = time_for_caution )
104109 self .think = dp .Think (visible = lambda : flags .use_thinking )
105110 self .hints = dp .Hints (visible = lambda : flags .use_hints )
111+ goal_str : str = goal [0 ]["text" ]
106112 self .task_hint = TaskHint (
107113 use_task_hint = flags .use_task_hint ,
108- hint_db_path = flags .hint_db_path
114+ hint_db_path = flags .hint_db_path ,
115+ goal = goal_str ,
116+ hint_retrieval_mode = flags .task_hint_retrieval_mode ,
117+ llm = llm ,
118+ skip_hints_for_current_task = flags .skip_hints_for_current_task ,
109119 )
110120 self .plan = Plan (previous_plan , step , lambda : flags .use_plan ) # TODO add previous plan
111121 self .criticise = Criticise (visible = lambda : flags .use_criticise )
@@ -114,12 +124,12 @@ def time_for_caution():
114124 @property
115125 def _prompt (self ) -> HumanMessage :
116126 prompt = HumanMessage (self .instructions .prompt )
117-
127+
118128 # Add task hints if enabled
119129 task_hints_text = ""
120- if self .flags .use_task_hint and hasattr (self , ' task_name' ):
130+ if self .flags .use_task_hint and hasattr (self , " task_name" ):
121131 task_hints_text = self .task_hint .get_hints_for_task (self .task_name )
122-
132+
123133 prompt .add_text (
124134 f"""\
125135{ self .obs .prompt } \
@@ -286,11 +296,23 @@ def _parse_answer(self, text_answer):
286296
287297
288298class TaskHint (dp .PromptElement ):
289- def __init__ (self , use_task_hint : bool = True , hint_db_path : str = None ) -> None :
299+ def __init__ (
300+ self ,
301+ use_task_hint : bool ,
302+ hint_db_path : str ,
303+ goal : str ,
304+ hint_retrieval_mode : Literal ["direct" , "llm" , "emb" ],
305+ skip_hints_for_current_task : bool ,
306+ llm : ChatModel ,
307+ ) -> None :
290308 super ().__init__ (visible = use_task_hint )
291309 self .use_task_hint = use_task_hint
292310 self .hint_db_rel_path = "hint_db.csv"
293311 self .hint_db_path = hint_db_path # Allow external path override
312+ self .hint_retrieval_mode : Literal ["direct" , "llm" , "emb" ] = hint_retrieval_mode
313+ self .skip_hints_for_current_task = skip_hints_for_current_task
314+ self .goal = goal
315+ self .llm = llm
294316 self ._init ()
295317
296318 _prompt = "" # Task hints are added dynamically in MainPrompt
@@ -316,42 +338,50 @@ def _init(self):
316338 hint_db_path = Path (self .hint_db_path )
317339 else :
318340 hint_db_path = Path (__file__ ).parent / self .hint_db_rel_path
319-
341+
320342 if hint_db_path .exists ():
321343 self .hint_db = pd .read_csv (hint_db_path , header = 0 , index_col = None , dtype = str )
322344 # Verify the expected columns exist
323345 if "task_name" not in self .hint_db .columns or "hint" not in self .hint_db .columns :
324- print (f"Warning: Hint database missing expected columns. Found: { list (self .hint_db .columns )} " )
346+ print (
347+ f"Warning: Hint database missing expected columns. Found: { list (self .hint_db .columns )} "
348+ )
325349 self .hint_db = pd .DataFrame (columns = ["task_name" , "hint" ])
326350 else :
327351 print (f"Warning: Hint database not found at { hint_db_path } " )
328352 self .hint_db = pd .DataFrame (columns = ["task_name" , "hint" ])
353+ self .hints_source = HintsSource (
354+ hint_db_path = hint_db_path .as_posix (),
355+ hint_retrieval_mode = self .hint_retrieval_mode ,
356+ skip_hints_for_current_task = self .skip_hints_for_current_task ,
357+ )
329358 except Exception as e :
330359 # Fallback to empty database on any error
331360 print (f"Warning: Could not load hint database: { e } " )
332361 self .hint_db = pd .DataFrame (columns = ["task_name" , "hint" ])
333362
334-
335363 def get_hints_for_task (self , task_name : str ) -> str :
336364 """Get hints for a specific task."""
337365 if not self .use_task_hint :
338366 return ""
339367
340368 # Ensure hint_db is initialized
341- if not hasattr (self , ' hint_db' ):
369+ if not hasattr (self , " hint_db" ):
342370 self ._init ()
343371
344372 # Check if hint_db has the expected structure
345- if self .hint_db .empty or "task_name" not in self .hint_db .columns or "hint" not in self .hint_db .columns :
373+ if (
374+ self .hint_db .empty
375+ or "task_name" not in self .hint_db .columns
376+ or "hint" not in self .hint_db .columns
377+ ):
346378 return ""
347379
348380 try :
349- task_hints = self .hint_db [
350- self .hint_db ["task_name" ].apply (lambda x : fnmatch .fnmatch (x , task_name ))
351- ]
381+ task_hints = self .hints_source .choose_hints (self .llm , task_name , self .goal )
352382
353383 hints = []
354- for hint in task_hints [ "hint" ] :
384+ for hint in task_hints :
355385 hint = hint .strip ()
356386 if hint :
357387 hints .append (f"- { hint } " )
@@ -364,5 +394,5 @@ def get_hints_for_task(self, task_name: str) -> str:
364394 return hints_str
365395 except Exception as e :
366396 print (f"Warning: Error getting hints for task { task_name } : { e } " )
367-
397+
368398 return ""
0 commit comments