@@ -242,14 +242,16 @@ def _generate_context(self, question: str, text_chunk: str) -> t.List[str]:
242242 for qstn in question .split ("\n " )
243243 ]
244244
245- def _remove_nodes (self , available_indices : list , node_idx : list ) -> t .List :
245+ def _remove_nodes (
246+ self , available_indices : list [BaseNode ], node_idx : list
247+ ) -> t .List [BaseNode ]:
246248 for idx in node_idx :
247249 available_indices .remove (idx )
248250 return available_indices
249251
250252 def _generate_doc_nodes_map (
251253 self , documenet_nodes : t .List [BaseNode ]
252- ) -> t .Dict [str , BaseNode ]:
254+ ) -> t .Dict [str , t . List [ BaseNode ] ]:
253255 doc_nodes_map : t .Dict [str , t .List [BaseNode ]] = defaultdict (list [BaseNode ])
254256 for node in documenet_nodes :
255257 if node .ref_doc_id :
@@ -288,15 +290,15 @@ def generate(
288290
289291 if isinstance (documents [0 ], LangchainDocument ):
290292 # cast to LangchainDocument since its the only case here
291- documents = t .cast (list [LangchainDocument ], documents )
293+ documents = t .cast (t . List [LangchainDocument ], documents )
292294 documents = [
293295 LlamaindexDocument .from_langchain_format (doc ) for doc in documents
294296 ]
295297 # Convert documents into nodes
296298 node_parser = SimpleNodeParser .from_defaults (
297299 chunk_size = self .chunk_size , chunk_overlap = 0 , include_metadata = True
298300 )
299- documents = t .cast (list [LlamaindexDocument ], documents )
301+ documents = t .cast (t . List [LlamaindexDocument ], documents )
300302 document_nodes : t .List [BaseNode ] = node_parser .get_nodes_from_documents (
301303 documents = documents
302304 )
@@ -319,7 +321,7 @@ def generate(
319321 pbar = tqdm (total = test_size )
320322 while count < test_size and available_nodes != []:
321323 evolve_type = self ._get_evolve_type ()
322- curr_node = self .rng .choice (available_nodes , size = 1 )[0 ]
324+ curr_node = self .rng .choice (np . array ( available_nodes ) , size = 1 )[0 ]
323325 available_nodes = self ._remove_nodes (available_nodes , [curr_node ])
324326
325327 neighbor_nodes = doc_nodes_map [curr_node .source_node .node_id ]
@@ -353,6 +355,8 @@ def generate(
353355 similarity_cutoff = self .threshold / 10 ,
354356 )
355357 if indices :
358+ # type cast indices from list[Any] to list[int]
359+ indices = t .cast (t .List [int ], indices )
356360 best_neighbor = neighbor_nodes [indices [0 ]]
357361 question = self ._multicontext_question (
358362 question = seed_question ,
0 commit comments