@@ -80,6 +80,7 @@ def __init__(
8080 actions : list [str ],
8181 memories : list [str ],
8282 thoughts : list [str ],
83+ hints : list [str ],
8384 previous_plan : str ,
8485 step : int ,
8586 flags : GenericPromptFlags ,
@@ -120,6 +121,7 @@ def time_for_caution():
120121 self .think = dp .Think (visible = lambda : flags .use_thinking )
121122 self .hints = dp .Hints (visible = lambda : flags .use_hints )
122123 goal_str : str = goal [0 ]["text" ]
124+ # TODO: This design is not very good, as we will instantiate the hint lookup at every step
123125 self .task_hint = TaskHint (
124126 use_task_hint = flags .use_task_hint ,
125127 hint_db_path = flags .hint_db_path ,
@@ -147,7 +149,8 @@ def _prompt(self) -> HumanMessage:
147149
148150 # Add task hints if enabled
149151 task_hints_text = ""
150- if self .flags .use_task_hint and hasattr (self , "task_name" ):
152+ # if self.flags.use_task_hint and hasattr(self, "task_name"):
153+ if self .flags .use_task_hint :
151154 task_hints_text = self .task_hint .get_hints_for_task (self .task_name )
152155
153156 prompt .add_text (
@@ -371,19 +374,14 @@ def _init(self):
371374 try :
372375 if self .hint_type == "docs" :
373376 if self .hint_index_type == "sparse" :
374- print ("Loading sparse hint index" )
375377 import bm25s
376378 self .hint_index = bm25s .BM25 .load (self .hint_index_path , load_corpus = True )
377- print ("Sparse hint index loaded successfully" )
378379 elif self .hint_index_type == "dense" :
379- print ("Loading dense hint index and retriever" )
380380 from datasets import load_from_disk
381381 from sentence_transformers import SentenceTransformer
382382 self .hint_index = load_from_disk (self .hint_index_path )
383383 self .hint_index .load_faiss_index ("embeddings" , self .hint_index_path .removesuffix ("/" ) + ".faiss" )
384- print ("Dense hint index loaded successfully" )
385384 self .hint_retriever = SentenceTransformer (self .hint_retriever_path )
386- print ("Hint retriever loaded successfully" )
387385 else :
388386 raise ValueError (f"Unknown hint index type: { self .hint_index_type } " )
389387 else :
@@ -422,8 +420,8 @@ def get_hints_for_task(self, task_name: str) -> str:
422420
423421 if self .hint_type == "docs" :
424422 if not hasattr (self , "hint_index" ):
423+ print ("Initializing hint index new time" )
425424 self ._init ()
426-
427425 if self .hint_query_type == "goal" :
428426 query = self .goal
429427 elif self .hint_query_type == "llm" :
@@ -432,9 +430,15 @@ def get_hints_for_task(self, task_name: str) -> str:
432430 raise ValueError (f"Unknown hint query type: { self .hint_query_type } " )
433431
434432 if self .hint_index_type == "sparse" :
433+ import bm25s
435434 query_tokens = bm25s .tokenize (query )
436- docs = self .hint_index .search (query_tokens , k = self .hint_num_results )
437- docs = docs ["text" ]
435+ docs , _ = self .hint_index .retrieve (query_tokens , k = self .hint_num_results )
436+ docs = [elem ["text" ] for elem in docs [0 ]]
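437+ # HACK: truncate to 20k characters (should cover >99% of the cases). FIXME: the loop below only rebinds the local `doc`, so the entries of `docs` are left untruncated; rebuild the list instead.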
438+ for doc in docs :
439+ if len (doc ) > 20000 :
440+ doc = doc [:20000 ]
441+ doc += " ...[truncated]"
438442 elif self .hint_index_type == "dense" :
439443 query_embedding = self .hint_retriever .encode (query )
440444 _ , docs = self .hint_index .get_nearest_examples ("embeddings" , query_embedding , k = self .hint_num_results )
0 commit comments