@@ -120,6 +120,12 @@ def get_action(self, obs):
120120 queries if getattr (self .flags , "hint_level" , "episode" ) == "step" else None
121121 )
122122
123+ # get hints
124+ if self .flags .use_hints :
125+ task_hints = self ._get_task_hints ()
126+ else :
127+ task_hints = []
128+
123129 main_prompt = MainPrompt (
124130 action_set = self .action_set ,
125131 obs_history = self .obs_history ,
@@ -130,7 +136,7 @@ def get_action(self, obs):
130136 step = self .plan_step ,
131137 flags = self .flags ,
132138 llm = self .chat_llm ,
133- queries = queries_for_hints ,
139+ task_hints = task_hints ,
134140 )
135141
136142 # Set task name for task hints if available
@@ -303,60 +309,57 @@ def _init_hints_index(self):
303309 print (f"Warning: Could not load hint database: { e } " )
304310 self .hint_db = pd .DataFrame (columns = ["task_name" , "hint" ])
305311
306- def get_hints_for_task (self , task_name : str ) -> str :
312+ def _get_task_hints (self ) -> list [ str ] :
307313 """Get hints for a specific task."""
308- if not self .use_task_hint :
309- return ""
314+ if not self .flags . use_task_hint :
315+ return []
310316
311- if self .hint_type == "docs" :
317+ if self .flags . hint_type == "docs" :
312318 if not hasattr (self , "hint_index" ):
313319 print ("Initializing hint index new time" )
314320 self ._init ()
315- if self .hint_query_type == "goal" :
316- query = self .goal
317- elif self .hint_query_type == "llm" :
321+ if self .flags . hint_query_type == "goal" :
322+ query = self .obs_history [ - 1 ][ "goal_object" ][ 0 ][ "text" ]
323+ elif self .flags . hint_query_type == "llm" :
318324 query = self .llm .generate (self ._prompt + self ._abstract_ex + self ._concrete_ex )
319325 else :
320- raise ValueError (f"Unknown hint query type: { self .hint_query_type } " )
326+ raise ValueError (f"Unknown hint query type: { self .flags . hint_query_type } " )
321327
322- if self .hint_index_type == "sparse" :
328+ print (f"Query: { query } " )
329+ if self .flags .hint_index_type == "sparse" :
323330 import bm25s
324331 query_tokens = bm25s .tokenize (query )
325- docs , _ = self .hint_index .retrieve (query_tokens , k = self .hint_num_results )
332+ docs , _ = self .hint_index .retrieve (query_tokens , k = self .flags . hint_num_results )
326333 docs = [elem ["text" ] for elem in docs [0 ]]
327334 # HACK: truncate to 20k characters (should cover >99% of the cases)
328335 for doc in docs :
329336 if len (doc ) > 20000 :
330337 doc = doc [:20000 ]
331338 doc += " ...[truncated]"
332- elif self .hint_index_type == "dense" :
339+ elif self .flags . hint_index_type == "dense" :
333340 query_embedding = self .hint_retriever .encode (query )
334- _ , docs = self .hint_index .get_nearest_examples ("embeddings" , query_embedding , k = self .hint_num_results )
341+ _ , docs = self .hint_index .get_nearest_examples ("embeddings" , query_embedding , k = self .flags . hint_num_results )
335342 docs = docs ["text" ]
336343
337- hints_str = (
338- "# Hints:\n Here are some hints for the task you are working on:\n "
339- + "\n " .join (docs )
340- )
341- return hints_str
344+ return docs
342345
343346 # Check if hint_db has the expected structure
344347 if (
345348 self .hint_db .empty
346349 or "task_name" not in self .hint_db .columns
347350 or "hint" not in self .hint_db .columns
348351 ):
349- return ""
352+ return []
350353
351354 try :
352355 # When step-level, pass queries as goal string to fit the llm_prompt
353- goal_or_queries = self .goal
354- if self .hint_level == "step" and self .queries :
356+ goal_or_queries = self .obs_history [ - 1 ][ "goal_object" ][ 0 ][ "text" ]
357+ if self .flags . hint_level == "step" and self .queries :
355358 goal_or_queries = "\n " .join (self .queries )
356359
357360 task_hints = self .hints_source .choose_hints (
358361 self .llm ,
359- task_name ,
362+ self . task_name ,
360363 goal_or_queries ,
361364 )
362365
@@ -366,13 +369,8 @@ def get_hints_for_task(self, task_name: str) -> str:
366369 if hint :
367370 hints .append (f"- { hint } " )
368371
369- if len (hints ) > 0 :
370- hints_str = (
371- "# Hints:\n Here are some hints for the task you are working on:\n "
372- + "\n " .join (hints )
373- )
374- return hints_str
372+ return hints
375373 except Exception as e :
376- print (f"Warning: Error getting hints for task { task_name } : { e } " )
374+ print (f"Warning: Error getting hints for task { self . task_name } : { e } " )
377375
378- return ""
376+ return []
0 commit comments