@@ -85,7 +85,7 @@ def __init__(
8585 step : int ,
8686 flags : GenericPromptFlags ,
8787 llm : ChatModel ,
88- queries : list [str ] | None = None ,
88+ task_hints : list [str ] = [] ,
8989 ) -> None :
9090 super ().__init__ ()
9191 self .flags = flags
@@ -120,25 +120,7 @@ def time_for_caution():
120120 self .be_cautious = dp .BeCautious (visible = time_for_caution )
121121 self .think = dp .Think (visible = lambda : flags .use_thinking )
122122 self .hints = dp .Hints (visible = lambda : flags .use_hints )
123- goal_str : str = goal [0 ]["text" ]
124- # TODO: This design is not very good as we will instantiate the loop up at every step
125- self .task_hint = TaskHint (
126- use_task_hint = flags .use_task_hint ,
127- hint_db_path = flags .hint_db_path ,
128- goal = goal_str ,
129- hint_retrieval_mode = flags .task_hint_retrieval_mode ,
130- llm = llm ,
131- skip_hints_for_current_task = flags .skip_hints_for_current_task ,
132- # hint related
133- hint_type = flags .hint_type ,
134- hint_index_type = flags .hint_index_type ,
135- hint_query_type = flags .hint_query_type ,
136- hint_index_path = flags .hint_index_path ,
137- hint_retriever_path = flags .hint_retriever_path ,
138- hint_num_results = flags .hint_num_results ,
139- hint_level = flags .hint_level ,
140- queries = queries ,
141- )
123+ self .task_hints = TaskHint (visible = lambda : flags .use_task_hint , task_hints = task_hints )
142124 self .plan = Plan (previous_plan , step , lambda : flags .use_plan ) # TODO add previous plan
143125 self .criticise = Criticise (visible = lambda : flags .use_criticise )
144126 self .memory = Memory (visible = lambda : flags .use_memory )
@@ -147,19 +129,13 @@ def time_for_caution():
147129 def _prompt (self ) -> HumanMessage :
148130 prompt = HumanMessage (self .instructions .prompt )
149131
150- # Add task hints if enabled
151- task_hints_text = ""
152- # if self.flags.use_task_hint and hasattr(self, "task_name"):
153- if self .flags .use_task_hint :
154- task_hints_text = self .task_hint .get_hints_for_task (self .task_name )
155-
156132 prompt .add_text (
157133 f"""\
158134{ self .obs .prompt } \
159135{ self .history .prompt } \
160136{ self .action_prompt .prompt } \
161137{ self .hints .prompt } \
162- { task_hints_text } \
138+ { self . task_hint . prompt } \
163139{ self .be_cautious .prompt } \
164140{ self .think .prompt } \
165141{ self .plan .prompt } \
@@ -321,37 +297,11 @@ def _parse_answer(self, text_answer):
321297class TaskHint (dp .PromptElement ):
322298 def __init__ (
323299 self ,
324- use_task_hint : bool ,
325- hint_db_path : str ,
326- goal : str ,
327- llm : ChatModel ,
328- hint_type : Literal ["human" , "llm" , "docs" ] = "human" ,
329- hint_index_type : Literal ["sparse" , "dense" ] = "sparse" ,
330- hint_query_type : Literal ["direct" , "llm" , "emb" ] = "direct" ,
331- hint_index_path : str = None ,
332- hint_retriever_path : str = None ,
333- hint_num_results : int = 5 ,
334- skip_hints_for_current_task : bool = False ,
335- hint_retrieval_mode : Literal ["direct" , "llm" , "emb" ] = "direct" ,
336- hint_level : Literal ["episode" , "step" ] = "episode" ,
337- queries : list [str ] | None = None ,
300+ visible : bool ,
301+ task_hints : list [str ]
338302 ) -> None :
339- super ().__init__ (visible = use_task_hint )
340- self .use_task_hint = use_task_hint
341- self .hint_type = hint_type
342- self .hint_index_type = hint_index_type
343- self .hint_query_type = hint_query_type
344- self .hint_index_path = hint_index_path
345- self .hint_retriever_path = hint_retriever_path
346- self .hint_num_results = hint_num_results
347- self .hint_db_rel_path = "hint_db.csv"
348- self .hint_db_path = hint_db_path # Allow external path override
349- self .hint_retrieval_mode : Literal ["direct" , "llm" , "emb" ] = hint_retrieval_mode
350- self .skip_hints_for_current_task = skip_hints_for_current_task
351- self .goal = goal
352- self .llm = llm
353- self .hint_level : Literal ["episode" , "step" ] = hint_level
354- self .queries : list [str ] | None = queries
303+ super ().__init__ (visible = visible )
304+ self .task_hints = task_hints
355305
356306 _prompt = "" # Task hints are added dynamically in MainPrompt
357307
@@ -368,80 +318,6 @@ def __init__(
368318</task_hint>
369319"""
370320
371- def get_hints_for_task (self , task_name : str ) -> str :
372- """Get hints for a specific task."""
373- if not self .use_task_hint :
374- return ""
375-
376- if self .hint_type == "docs" :
377- if not hasattr (self , "hint_index" ):
378- print ("Initializing hint index new time" )
379- self ._init ()
380- if self .hint_query_type == "goal" :
381- query = self .goal
382- elif self .hint_query_type == "llm" :
383- query = self .llm .generate (self ._prompt + self ._abstract_ex + self ._concrete_ex )
384- else :
385- raise ValueError (f"Unknown hint query type: { self .hint_query_type } " )
386-
387- if self .hint_index_type == "sparse" :
388- import bm25s
389- query_tokens = bm25s .tokenize (query )
390- docs , _ = self .hint_index .retrieve (query_tokens , k = self .hint_num_results )
391- docs = [elem ["text" ] for elem in docs [0 ]]
392- # HACK: truncate to 20k characters (should cover >99% of the cases)
393- for doc in docs :
394- if len (doc ) > 20000 :
395- doc = doc [:20000 ]
396- doc += " ...[truncated]"
397- elif self .hint_index_type == "dense" :
398- query_embedding = self .hint_retriever .encode (query )
399- _ , docs = self .hint_index .get_nearest_examples ("embeddings" , query_embedding , k = self .hint_num_results )
400- docs = docs ["text" ]
401-
402- hints_str = (
403- "# Hints:\n Here are some hints for the task you are working on:\n "
404- + "\n " .join (docs )
405- )
406- return hints_str
407-
408- # Check if hint_db has the expected structure
409- if (
410- self .hint_db .empty
411- or "task_name" not in self .hint_db .columns
412- or "hint" not in self .hint_db .columns
413- ):
414- return ""
415-
416- try :
417- # When step-level, pass queries as goal string to fit the llm_prompt
418- goal_or_queries = self .goal
419- if self .hint_level == "step" and self .queries :
420- goal_or_queries = "\n " .join (self .queries )
421-
422- task_hints = self .hints_source .choose_hints (
423- self .llm ,
424- task_name ,
425- goal_or_queries ,
426- )
427-
428- hints = []
429- for hint in task_hints :
430- hint = hint .strip ()
431- if hint :
432- hints .append (f"- { hint } " )
433-
434- if len (hints ) > 0 :
435- hints_str = (
436- "# Hints:\n Here are some hints for the task you are working on:\n "
437- + "\n " .join (hints )
438- )
439- return hints_str
440- except Exception as e :
441- print (f"Warning: Error getting hints for task { task_name } : { e } " )
442-
443- return ""
444-
445321
446322class StepWiseContextIdentificationPrompt (dp .Shrinkable ):
447323 def __init__ (
0 commit comments