Skip to content

Commit d862821

Browse files
authored
add kwargs to index search (#5628)
* add kwargs to search * format
1 parent ce1d107 commit d862821

File tree

2 files changed

+36
-17
lines changed

2 files changed

+36
-17
lines changed

src/datasets/search.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -57,14 +57,14 @@ class BatchedNearestExamplesResults(NamedTuple):
5757
class BaseIndex:
5858
"""Base class for indexing"""
5959

60-
def search(self, query, k: int = 10) -> SearchResults:
60+
def search(self, query, k: int = 10, **kwargs) -> SearchResults:
6161
"""
6262
To implement.
6363
This method has to return the scores and the indices of the retrieved examples given a certain query.
6464
"""
6565
raise NotImplementedError
6666

67-
def search_batch(self, queries, k: int = 10) -> BatchedSearchResults:
67+
def search_batch(self, queries, k: int = 10, **kwargs) -> BatchedSearchResults:
6868
"""Find the nearest examples indices to the query.
6969
7070
Args:
@@ -176,7 +176,7 @@ def passage_generator():
176176
)
177177
logger.info(f"Indexed {successes:d} documents")
178178

179-
def search(self, query: str, k=10) -> SearchResults:
179+
def search(self, query: str, k=10, **kwargs) -> SearchResults:
180180
"""Find the nearest examples indices to the query.
181181
182182
Args:
@@ -190,16 +190,17 @@ def search(self, query: str, k=10) -> SearchResults:
190190
response = self.es_client.search(
191191
index=self.es_index_name,
192192
body={"query": {"multi_match": {"query": query, "fields": ["text"], "type": "cross_fields"}}, "size": k},
193+
**kwargs,
193194
)
194195
hits = response["hits"]["hits"]
195196
return SearchResults([hit["_score"] for hit in hits], [int(hit["_id"]) for hit in hits])
196197

197-
def search_batch(self, queries, k: int = 10, max_workers=10) -> BatchedSearchResults:
198+
def search_batch(self, queries, k: int = 10, max_workers=10, **kwargs) -> BatchedSearchResults:
198199
import concurrent.futures
199200

200201
total_scores, total_indices = [None] * len(queries), [None] * len(queries)
201202
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
202-
future_to_index = {executor.submit(self.search, query, k): i for i, query in enumerate(queries)}
203+
future_to_index = {executor.submit(self.search, query, k, **kwargs): i for i, query in enumerate(queries)}
203204
for future in concurrent.futures.as_completed(future_to_index):
204205
index = future_to_index[future]
205206
results: SearchResults = future.result()
@@ -337,7 +338,7 @@ def _faiss_index_to_device(index: "faiss.Index", device: Optional[Union[int, Lis
337338

338339
return index
339340

340-
def search(self, query: np.array, k=10) -> SearchResults:
341+
def search(self, query: np.array, k=10, **kwargs) -> SearchResults:
341342
"""Find the nearest examples indices to the query.
342343
343344
Args:
@@ -354,10 +355,10 @@ def search(self, query: np.array, k=10) -> SearchResults:
354355
queries = query.reshape(1, -1)
355356
if not queries.flags.c_contiguous:
356357
queries = np.asarray(queries, order="C")
357-
scores, indices = self.faiss_index.search(queries, k)
358+
scores, indices = self.faiss_index.search(queries, k, **kwargs)
358359
return SearchResults(scores[0], indices[0].astype(int))
359360

360-
def search_batch(self, queries: np.array, k=10) -> BatchedSearchResults:
361+
def search_batch(self, queries: np.array, k=10, **kwargs) -> BatchedSearchResults:
361362
"""Find the nearest examples indices to the queries.
362363
363364
Args:
@@ -372,7 +373,7 @@ def search_batch(self, queries: np.array, k=10) -> BatchedSearchResults:
372373
raise ValueError("Shape of query must be 2D")
373374
if not queries.flags.c_contiguous:
374375
queries = np.asarray(queries, order="C")
375-
scores, indices = self.faiss_index.search(queries, k)
376+
scores, indices = self.faiss_index.search(queries, k, **kwargs)
376377
return BatchedSearchResults(scores, indices.astype(int))
377378

378379
def save(self, file: Union[str, PurePath]):
@@ -667,7 +668,7 @@ def drop_index(self, index_name: str):
667668
"""
668669
del self._indexes[index_name]
669670

670-
def search(self, index_name: str, query: Union[str, np.array], k: int = 10) -> SearchResults:
671+
def search(self, index_name: str, query: Union[str, np.array], k: int = 10, **kwargs) -> SearchResults:
671672
"""Find the nearest examples indices in the dataset to the query.
672673
673674
Args:
@@ -683,9 +684,11 @@ def search(self, index_name: str, query: Union[str, np.array], k: int = 10) -> S
683684
- indices (`List[List[int]]`): The indices of the retrieved examples.
684685
"""
685686
self._check_index_is_initialized(index_name)
686-
return self._indexes[index_name].search(query, k)
687+
return self._indexes[index_name].search(query, k, **kwargs)
687688

688-
def search_batch(self, index_name: str, queries: Union[List[str], np.array], k: int = 10) -> BatchedSearchResults:
689+
def search_batch(
690+
self, index_name: str, queries: Union[List[str], np.array], k: int = 10, **kwargs
691+
) -> BatchedSearchResults:
689692
"""Find the nearest examples indices in the dataset to the query.
690693
691694
Args:
@@ -701,10 +704,10 @@ def search_batch(self, index_name: str, queries: Union[List[str], np.array], k:
701704
- total_indices (`List[List[int]]`): The indices of the retrieved examples per query.
702705
"""
703706
self._check_index_is_initialized(index_name)
704-
return self._indexes[index_name].search_batch(queries, k)
707+
return self._indexes[index_name].search_batch(queries, k, **kwargs)
705708

706709
def get_nearest_examples(
707-
self, index_name: str, query: Union[str, np.array], k: int = 10
710+
self, index_name: str, query: Union[str, np.array], k: int = 10, **kwargs
708711
) -> NearestExamplesResults:
709712
"""Find the nearest examples in the dataset to the query.
710713
@@ -721,12 +724,12 @@ def get_nearest_examples(
721724
- examples (`dict`): The retrieved examples.
722725
"""
723726
self._check_index_is_initialized(index_name)
724-
scores, indices = self.search(index_name, query, k)
727+
scores, indices = self.search(index_name, query, k, **kwargs)
725728
top_indices = [i for i in indices if i >= 0]
726729
return NearestExamplesResults(scores[: len(top_indices)], self[top_indices])
727730

728731
def get_nearest_examples_batch(
729-
self, index_name: str, queries: Union[List[str], np.array], k: int = 10
732+
self, index_name: str, queries: Union[List[str], np.array], k: int = 10, **kwargs
730733
) -> BatchedNearestExamplesResults:
731734
"""Find the nearest examples in the dataset to the query.
732735
@@ -743,7 +746,7 @@ def get_nearest_examples_batch(
743746
- total_examples (`List[dict]`): The retrieved examples per query.
744747
"""
745748
self._check_index_is_initialized(index_name)
746-
total_scores, total_indices = self.search_batch(index_name, queries, k)
749+
total_scores, total_indices = self.search_batch(index_name, queries, k, **kwargs)
747750
total_scores = [
748751
scores_i[: len([i for i in indices_i if i >= 0])]
749752
for scores_i, indices_i in zip(total_scores, total_indices)

tests/test_search.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,13 @@ def test_elasticsearch(self):
188188
self.assertEqual(scores[0], 1)
189189
self.assertEqual(indices[0], 0)
190190

191+
# single query with timeout
192+
query = "foo"
193+
mocked_search.return_value = {"hits": {"hits": [{"_score": 1, "_id": 0}]}}
194+
scores, indices = index.search(query, request_timeout=30)
195+
self.assertEqual(scores[0], 1)
196+
self.assertEqual(indices[0], 0)
197+
191198
# batched queries
192199
queries = ["foo", "bar", "foobar"]
193200
mocked_search.return_value = {"hits": {"hits": [{"_score": 1, "_id": 1}]}}
@@ -196,3 +203,12 @@ def test_elasticsearch(self):
196203
best_indices = [indices[0] for indices in total_indices]
197204
self.assertGreater(np.min(best_scores), 0)
198205
self.assertListEqual([1, 1, 1], best_indices)
206+
207+
# batched queries with timeout
208+
queries = ["foo", "bar", "foobar"]
209+
mocked_search.return_value = {"hits": {"hits": [{"_score": 1, "_id": 1}]}}
210+
total_scores, total_indices = index.search_batch(queries, request_timeout=30)
211+
best_scores = [scores[0] for scores in total_scores]
212+
best_indices = [indices[0] for indices in total_indices]
213+
self.assertGreater(np.min(best_scores), 0)
214+
self.assertListEqual([1, 1, 1], best_indices)

0 commit comments

Comments
 (0)