@@ -60,6 +60,13 @@ class GenericPromptFlags(dp.Flags):
6060 add_missparsed_messages : bool = True
6161 max_trunc_itr : int = 20
6262 flag_group : str = None
63+ # hint flags
64+ hint_type : Literal ["human" , "llm" , "docs" ] = "human"
65+ hint_index_type : Literal ["sparse" , "dense" ] = "sparse"
66+ hint_query_type : Literal ["direct" , "llm" , "emb" ] = "direct"
67+ hint_index_path : str = None
68+ hint_retriever_path : str = None
69+ hint_num_results : int = 5
6370
6471
6572class MainPrompt (dp .Shrinkable ):
@@ -116,6 +123,13 @@ def time_for_caution():
116123 hint_retrieval_mode = flags .task_hint_retrieval_mode ,
117124 llm = llm ,
118125 skip_hints_for_current_task = flags .skip_hints_for_current_task ,
126+ # hint related
127+ hint_type = flags .hint_type ,
128+ hint_index_type = flags .hint_index_type ,
129+ hint_query_type = flags .hint_query_type ,
130+ hint_index_path = flags .hint_index_path ,
131+ hint_retriever_path = flags .hint_retriever_path ,
132+ hint_num_results = flags .hint_num_results ,
119133 )
120134 self .plan = Plan (previous_plan , step , lambda : flags .use_plan ) # TODO add previous plan
121135 self .criticise = Criticise (visible = lambda : flags .use_criticise )
@@ -301,12 +315,24 @@ def __init__(
301315 use_task_hint : bool ,
302316 hint_db_path : str ,
303317 goal : str ,
304- hint_retrieval_mode : Literal ["direct" , "llm" , "emb" ],
305- skip_hints_for_current_task : bool ,
306318 llm : ChatModel ,
319+ hint_type : Literal ["human" , "llm" , "docs" ] = "human" ,
320+ hint_index_type : Literal ["sparse" , "dense" ] = "sparse" ,
321+ hint_query_type : Literal ["direct" , "llm" , "emb" ] = "direct" ,
322+ hint_index_path : str = None ,
323+ hint_retriever_path : str = None ,
324+ hint_num_results : int = 5 ,
325+ skip_hints_for_current_task : bool = False ,
326+ hint_retrieval_mode : Literal ["direct" , "llm" , "emb" ] = "direct" ,
307327 ) -> None :
308328 super ().__init__ (visible = use_task_hint )
309329 self .use_task_hint = use_task_hint
330+ self .hint_type = hint_type
331+ self .hint_index_type = hint_index_type
332+ self .hint_query_type = hint_query_type
333+ self .hint_index_path = hint_index_path
334+ self .hint_retriever_path = hint_retriever_path
335+ self .hint_num_results = hint_num_results
310336 self .hint_db_rel_path = "hint_db.csv"
311337 self .hint_db_path = hint_db_path # Allow external path override
312338 self .hint_retrieval_mode : Literal ["direct" , "llm" , "emb" ] = hint_retrieval_mode
@@ -333,28 +359,46 @@ def __init__(
333359 def _init (self ):
334360 """Initialize the block."""
335361 try :
336- # Use external path if provided, otherwise fall back to relative path
337- if self .hint_db_path and Path (self .hint_db_path ).exists ():
338- hint_db_path = Path (self .hint_db_path )
362+ if self .hint_type == "docs" :
363+ if self .hint_index_type == "sparse" :
364+ print ("Loading sparse hint index" )
365+ import bm25s
366+ self .hint_index = bm25s .BM25 .load (self .hint_index_path , load_corpus = True )
367+ print ("Sparse hint index loaded successfully" )
368+ elif self .hint_index_type == "dense" :
369+ print ("Loading dense hint index and retriever" )
370+ from datasets import load_from_disk
371+ from sentence_transformers import SentenceTransformer
372+ self .hint_index = load_from_disk (self .hint_index_path )
373+ self .hint_index .load_faiss_index ("embeddings" , self .hint_index_path .removesuffix ("/" ) + ".faiss" )
374+ print ("Dense hint index loaded successfully" )
375+ self .hint_retriever = SentenceTransformer (self .hint_retriever_path )
376+ print ("Hint retriever loaded successfully" )
377+ else :
378+ raise ValueError (f"Unknown hint index type: { self .hint_index_type } " )
339379 else :
340- hint_db_path = Path (__file__ ).parent / self .hint_db_rel_path
341-
342- if hint_db_path .exists ():
343- self .hint_db = pd .read_csv (hint_db_path , header = 0 , index_col = None , dtype = str )
344- # Verify the expected columns exist
345- if "task_name" not in self .hint_db .columns or "hint" not in self .hint_db .columns :
346- print (
347- f"Warning: Hint database missing expected columns. Found: { list (self .hint_db .columns )} "
348- )
380+ # Use external path if provided, otherwise fall back to relative path
381+ if self .hint_db_path and Path (self .hint_db_path ).exists ():
382+ hint_db_path = Path (self .hint_db_path )
383+ else :
384+ hint_db_path = Path (__file__ ).parent / self .hint_db_rel_path
385+
386+ if hint_db_path .exists ():
387+ self .hint_db = pd .read_csv (hint_db_path , header = 0 , index_col = None , dtype = str )
388+ # Verify the expected columns exist
389+ if "task_name" not in self .hint_db .columns or "hint" not in self .hint_db .columns :
390+ print (
391+ f"Warning: Hint database missing expected columns. Found: { list (self .hint_db .columns )} "
392+ )
393+ self .hint_db = pd .DataFrame (columns = ["task_name" , "hint" ])
394+ else :
395+ print (f"Warning: Hint database not found at { hint_db_path } " )
349396 self .hint_db = pd .DataFrame (columns = ["task_name" , "hint" ])
350- else :
351- print (f"Warning: Hint database not found at { hint_db_path } " )
352- self .hint_db = pd .DataFrame (columns = ["task_name" , "hint" ])
353- self .hints_source = HintsSource (
354- hint_db_path = hint_db_path .as_posix (),
355- hint_retrieval_mode = self .hint_retrieval_mode ,
356- skip_hints_for_current_task = self .skip_hints_for_current_task ,
357- )
397+ self .hints_source = HintsSource (
398+ hint_db_path = hint_db_path .as_posix (),
399+ hint_retrieval_mode = self .hint_retrieval_mode ,
400+ skip_hints_for_current_task = self .skip_hints_for_current_task ,
401+ )
358402 except Exception as e :
359403 # Fallback to empty database on any error
360404 print (f"Warning: Could not load hint database: { e } " )
@@ -365,6 +409,32 @@ def get_hints_for_task(self, task_name: str) -> str:
365409 if not self .use_task_hint :
366410 return ""
367411
412+ if self .hint_type == "docs" :
413+ if not hasattr (self , "hint_index" ):
414+ self ._init ()
415+
416+ if self .hint_query_type == "goal" :
417+ query = self .goal
418+ elif self .hint_query_type == "llm" :
419+ query = self .llm .generate (self ._prompt + self ._abstract_ex + self ._concrete_ex )
420+ else :
421+ raise ValueError (f"Unknown hint query type: { self .hint_query_type } " )
422+
423+ if self .hint_index_type == "sparse" :
424+ query_tokens = bm25s .tokenize (query )
425+ docs = self .hint_index .search (query_tokens , k = self .hint_num_results )
426+ docs = docs ["text" ]
427+ elif self .hint_index_type == "dense" :
428+ query_embedding = self .hint_retriever .encode (query )
429+ _ , docs = self .hint_index .get_nearest_examples ("embeddings" , query_embedding , k = self .hint_num_results )
430+ docs = docs ["text" ]
431+
432+ hints_str = (
433+ "# Hints:\n Here are some hints for the task you are working on:\n "
434+ + "\n " .join (docs )
435+ )
436+ return hints_str
437+
368438 # Ensure hint_db is initialized
369439 if not hasattr (self , "hint_db" ):
370440 self ._init ()
0 commit comments