@@ -57,14 +57,14 @@ class BatchedNearestExamplesResults(NamedTuple):
5757class BaseIndex :
5858 """Base class for indexing"""
5959
60- def search (self , query , k : int = 10 ) -> SearchResults :
60+ def search (self , query , k : int = 10 , ** kwargs ) -> SearchResults :
6161 """
6262 To implement.
6363 This method has to return the scores and the indices of the retrieved examples given a certain query.
6464 """
6565 raise NotImplementedError
6666
67- def search_batch (self , queries , k : int = 10 ) -> BatchedSearchResults :
67+ def search_batch (self , queries , k : int = 10 , ** kwargs ) -> BatchedSearchResults :
6868 """Find the nearest examples indices to the query.
6969
7070 Args:
@@ -176,7 +176,7 @@ def passage_generator():
176176 )
177177 logger .info (f"Indexed { successes :d} documents" )
178178
179- def search (self , query : str , k = 10 ) -> SearchResults :
179+ def search (self , query : str , k = 10 , ** kwargs ) -> SearchResults :
180180 """Find the nearest examples indices to the query.
181181
182182 Args:
@@ -190,16 +190,17 @@ def search(self, query: str, k=10) -> SearchResults:
190190 response = self .es_client .search (
191191 index = self .es_index_name ,
192192 body = {"query" : {"multi_match" : {"query" : query , "fields" : ["text" ], "type" : "cross_fields" }}, "size" : k },
193+ ** kwargs ,
193194 )
194195 hits = response ["hits" ]["hits" ]
195196 return SearchResults ([hit ["_score" ] for hit in hits ], [int (hit ["_id" ]) for hit in hits ])
196197
197- def search_batch (self , queries , k : int = 10 , max_workers = 10 ) -> BatchedSearchResults :
198+ def search_batch (self , queries , k : int = 10 , max_workers = 10 , ** kwargs ) -> BatchedSearchResults :
198199 import concurrent .futures
199200
200201 total_scores , total_indices = [None ] * len (queries ), [None ] * len (queries )
201202 with concurrent .futures .ThreadPoolExecutor (max_workers = max_workers ) as executor :
202- future_to_index = {executor .submit (self .search , query , k ): i for i , query in enumerate (queries )}
203+ future_to_index = {executor .submit (self .search , query , k , ** kwargs ): i for i , query in enumerate (queries )}
203204 for future in concurrent .futures .as_completed (future_to_index ):
204205 index = future_to_index [future ]
205206 results : SearchResults = future .result ()
@@ -337,7 +338,7 @@ def _faiss_index_to_device(index: "faiss.Index", device: Optional[Union[int, Lis
337338
338339 return index
339340
340- def search (self , query : np .array , k = 10 ) -> SearchResults :
341+ def search (self , query : np .array , k = 10 , ** kwargs ) -> SearchResults :
341342 """Find the nearest examples indices to the query.
342343
343344 Args:
@@ -354,10 +355,10 @@ def search(self, query: np.array, k=10) -> SearchResults:
354355 queries = query .reshape (1 , - 1 )
355356 if not queries .flags .c_contiguous :
356357 queries = np .asarray (queries , order = "C" )
357- scores , indices = self .faiss_index .search (queries , k )
358+ scores , indices = self .faiss_index .search (queries , k , ** kwargs )
358359 return SearchResults (scores [0 ], indices [0 ].astype (int ))
359360
360- def search_batch (self , queries : np .array , k = 10 ) -> BatchedSearchResults :
361+ def search_batch (self , queries : np .array , k = 10 , ** kwargs ) -> BatchedSearchResults :
361362 """Find the nearest examples indices to the queries.
362363
363364 Args:
@@ -372,7 +373,7 @@ def search_batch(self, queries: np.array, k=10) -> BatchedSearchResults:
372373 raise ValueError ("Shape of query must be 2D" )
373374 if not queries .flags .c_contiguous :
374375 queries = np .asarray (queries , order = "C" )
375- scores , indices = self .faiss_index .search (queries , k )
376+ scores , indices = self .faiss_index .search (queries , k , ** kwargs )
376377 return BatchedSearchResults (scores , indices .astype (int ))
377378
378379 def save (self , file : Union [str , PurePath ]):
@@ -667,7 +668,7 @@ def drop_index(self, index_name: str):
667668 """
668669 del self ._indexes [index_name ]
669670
670- def search (self , index_name : str , query : Union [str , np .array ], k : int = 10 ) -> SearchResults :
671+ def search (self , index_name : str , query : Union [str , np .array ], k : int = 10 , ** kwargs ) -> SearchResults :
671672 """Find the nearest examples indices in the dataset to the query.
672673
673674 Args:
@@ -683,9 +684,11 @@ def search(self, index_name: str, query: Union[str, np.array], k: int = 10) -> S
683684 - indices (`List[List[int]]`): The indices of the retrieved examples.
684685 """
685686 self ._check_index_is_initialized (index_name )
686- return self ._indexes [index_name ].search (query , k )
687+ return self ._indexes [index_name ].search (query , k , ** kwargs )
687688
688- def search_batch (self , index_name : str , queries : Union [List [str ], np .array ], k : int = 10 ) -> BatchedSearchResults :
689+ def search_batch (
690+ self , index_name : str , queries : Union [List [str ], np .array ], k : int = 10 , ** kwargs
691+ ) -> BatchedSearchResults :
689692 """Find the nearest examples indices in the dataset to the query.
690693
691694 Args:
@@ -701,10 +704,10 @@ def search_batch(self, index_name: str, queries: Union[List[str], np.array], k:
701704 - total_indices (`List[List[int]]`): The indices of the retrieved examples per query.
702705 """
703706 self ._check_index_is_initialized (index_name )
704- return self ._indexes [index_name ].search_batch (queries , k )
707+ return self ._indexes [index_name ].search_batch (queries , k , ** kwargs )
705708
706709 def get_nearest_examples (
707- self , index_name : str , query : Union [str , np .array ], k : int = 10
710+ self , index_name : str , query : Union [str , np .array ], k : int = 10 , ** kwargs
708711 ) -> NearestExamplesResults :
709712 """Find the nearest examples in the dataset to the query.
710713
@@ -721,12 +724,12 @@ def get_nearest_examples(
721724 - examples (`dict`): The retrieved examples.
722725 """
723726 self ._check_index_is_initialized (index_name )
724- scores , indices = self .search (index_name , query , k )
727+ scores , indices = self .search (index_name , query , k , ** kwargs )
725728 top_indices = [i for i in indices if i >= 0 ]
726729 return NearestExamplesResults (scores [: len (top_indices )], self [top_indices ])
727730
728731 def get_nearest_examples_batch (
729- self , index_name : str , queries : Union [List [str ], np .array ], k : int = 10
732+ self , index_name : str , queries : Union [List [str ], np .array ], k : int = 10 , ** kwargs
730733 ) -> BatchedNearestExamplesResults :
731734 """Find the nearest examples in the dataset to the query.
732735
@@ -743,7 +746,7 @@ def get_nearest_examples_batch(
743746 - total_examples (`List[dict]`): The retrieved examples per query.
744747 """
745748 self ._check_index_is_initialized (index_name )
746- total_scores , total_indices = self .search_batch (index_name , queries , k )
749+ total_scores , total_indices = self .search_batch (index_name , queries , k , ** kwargs )
747750 total_scores = [
748751 scores_i [: len ([i for i in indices_i if i >= 0 ])]
749752 for scores_i , indices_i in zip (total_scores , total_indices )
0 commit comments