@@ -61,6 +61,13 @@ class GenericPromptFlags(dp.Flags):
6161 add_missparsed_messages : bool = True
6262 max_trunc_itr : int = 20
6363 flag_group : str = None
64+ # hint flags
65+ hint_type : Literal ["human" , "llm" , "docs" ] = "human"
66+ hint_index_type : Literal ["sparse" , "dense" ] = "sparse"
67+ hint_query_type : Literal ["direct" , "llm" , "emb" ] = "direct"
68+ hint_index_path : str = None
69+ hint_retriever_path : str = None
70+ hint_num_results : int = 5
6471 n_retrieval_queries : int = 3
6572 hint_level : Literal ["episode" , "step" ] = "episode"
6673
@@ -120,6 +127,13 @@ def time_for_caution():
120127 hint_retrieval_mode = flags .task_hint_retrieval_mode ,
121128 llm = llm ,
122129 skip_hints_for_current_task = flags .skip_hints_for_current_task ,
130+ # hint related
131+ hint_type = flags .hint_type ,
132+ hint_index_type = flags .hint_index_type ,
133+ hint_query_type = flags .hint_query_type ,
134+ hint_index_path = flags .hint_index_path ,
135+ hint_retriever_path = flags .hint_retriever_path ,
136+ hint_num_results = flags .hint_num_results ,
123137 hint_level = flags .hint_level ,
124138 queries = queries ,
125139 )
@@ -307,14 +321,26 @@ def __init__(
307321 use_task_hint : bool ,
308322 hint_db_path : str ,
309323 goal : str ,
310- hint_retrieval_mode : Literal ["direct" , "llm" , "emb" ],
311- skip_hints_for_current_task : bool ,
312324 llm : ChatModel ,
325+ hint_type : Literal ["human" , "llm" , "docs" ] = "human" ,
326+ hint_index_type : Literal ["sparse" , "dense" ] = "sparse" ,
327+ hint_query_type : Literal ["direct" , "llm" , "emb" ] = "direct" ,
328+ hint_index_path : str = None ,
329+ hint_retriever_path : str = None ,
330+ hint_num_results : int = 5 ,
331+ skip_hints_for_current_task : bool = False ,
332+ hint_retrieval_mode : Literal ["direct" , "llm" , "emb" ] = "direct" ,
313333 hint_level : Literal ["episode" , "step" ] = "episode" ,
314334 queries : list [str ] | None = None ,
315335 ) -> None :
316336 super ().__init__ (visible = use_task_hint )
317337 self .use_task_hint = use_task_hint
338+ self .hint_type = hint_type
339+ self .hint_index_type = hint_index_type
340+ self .hint_query_type = hint_query_type
341+ self .hint_index_path = hint_index_path
342+ self .hint_retriever_path = hint_retriever_path
343+ self .hint_num_results = hint_num_results
318344 self .hint_db_rel_path = "hint_db.csv"
319345 self .hint_db_path = hint_db_path # Allow external path override
320346 self .hint_retrieval_mode : Literal ["direct" , "llm" , "emb" ] = hint_retrieval_mode
@@ -343,29 +369,47 @@ def __init__(
343369 def _init (self ):
344370 """Initialize the block."""
345371 try :
346- # Use external path if provided, otherwise fall back to relative path
347- if self .hint_db_path and Path (self .hint_db_path ).exists ():
348- hint_db_path = Path (self .hint_db_path )
372+ if self .hint_type == "docs" :
373+ if self .hint_index_type == "sparse" :
374+ print ("Loading sparse hint index" )
375+ import bm25s
376+ self .hint_index = bm25s .BM25 .load (self .hint_index_path , load_corpus = True )
377+ print ("Sparse hint index loaded successfully" )
378+ elif self .hint_index_type == "dense" :
379+ print ("Loading dense hint index and retriever" )
380+ from datasets import load_from_disk
381+ from sentence_transformers import SentenceTransformer
382+ self .hint_index = load_from_disk (self .hint_index_path )
383+ self .hint_index .load_faiss_index ("embeddings" , self .hint_index_path .removesuffix ("/" ) + ".faiss" )
384+ print ("Dense hint index loaded successfully" )
385+ self .hint_retriever = SentenceTransformer (self .hint_retriever_path )
386+ print ("Hint retriever loaded successfully" )
387+ else :
388+ raise ValueError (f"Unknown hint index type: { self .hint_index_type } " )
349389 else :
350- hint_db_path = Path (__file__ ).parent / self .hint_db_rel_path
351-
352- if hint_db_path .exists ():
353- self .hint_db = pd .read_csv (hint_db_path , header = 0 , index_col = None , dtype = str )
354- # Verify the expected columns exist
355- if "task_name" not in self .hint_db .columns or "hint" not in self .hint_db .columns :
356- print (
357- f"Warning: Hint database missing expected columns. Found: { list (self .hint_db .columns )} "
358- )
390+ # Use external path if provided, otherwise fall back to relative path
391+ if self .hint_db_path and Path (self .hint_db_path ).exists ():
392+ hint_db_path = Path (self .hint_db_path )
393+ else :
394+ hint_db_path = Path (__file__ ).parent / self .hint_db_rel_path
395+
396+ if hint_db_path .exists ():
397+ self .hint_db = pd .read_csv (hint_db_path , header = 0 , index_col = None , dtype = str )
398+ # Verify the expected columns exist
399+ if "task_name" not in self .hint_db .columns or "hint" not in self .hint_db .columns :
400+ print (
401+ f"Warning: Hint database missing expected columns. Found: { list (self .hint_db .columns )} "
402+ )
403+ self .hint_db = pd .DataFrame (columns = ["task_name" , "hint" ])
404+ else :
405+ print (f"Warning: Hint database not found at { hint_db_path } " )
359406 self .hint_db = pd .DataFrame (columns = ["task_name" , "hint" ])
360- else :
361- print (f"Warning: Hint database not found at { hint_db_path } " )
362- self .hint_db = pd .DataFrame (columns = ["task_name" , "hint" ])
363-
364- self .hints_source = HintsSource (
365- hint_db_path = hint_db_path .as_posix (),
366- hint_retrieval_mode = self .hint_retrieval_mode ,
367- skip_hints_for_current_task = self .skip_hints_for_current_task ,
368- )
407+
408+ self .hints_source = HintsSource (
409+ hint_db_path = hint_db_path .as_posix (),
410+ hint_retrieval_mode = self .hint_retrieval_mode ,
411+ skip_hints_for_current_task = self .skip_hints_for_current_task ,
412+ )
369413 except Exception as e :
370414 # Fallback to empty database on any error
371415 print (f"Warning: Could not load hint database: { e } " )
@@ -376,6 +420,32 @@ def get_hints_for_task(self, task_name: str) -> str:
376420 if not self .use_task_hint :
377421 return ""
378422
423+ if self .hint_type == "docs" :
424+ if not hasattr (self , "hint_index" ):
425+ self ._init ()
426+
427+ if self .hint_query_type == "goal" :
428+ query = self .goal
429+ elif self .hint_query_type == "llm" :
430+ query = self .llm .generate (self ._prompt + self ._abstract_ex + self ._concrete_ex )
431+ else :
432+ raise ValueError (f"Unknown hint query type: { self .hint_query_type } " )
433+
434+ if self .hint_index_type == "sparse" :
435+ query_tokens = bm25s .tokenize (query )
436+ docs = self .hint_index .search (query_tokens , k = self .hint_num_results )
437+ docs = docs ["text" ]
438+ elif self .hint_index_type == "dense" :
439+ query_embedding = self .hint_retriever .encode (query )
440+ _ , docs = self .hint_index .get_nearest_examples ("embeddings" , query_embedding , k = self .hint_num_results )
441+ docs = docs ["text" ]
442+
443+ hints_str = (
444+ "# Hints:\n Here are some hints for the task you are working on:\n "
445+ + "\n " .join (docs )
446+ )
447+ return hints_str
448+
379449 # Ensure hint_db is initialized
380450 if not hasattr (self , "hint_db" ):
381451 self ._init ()
0 commit comments