Skip to content

Commit c074e9d

Browse files
committed
Add OpenSearch retriever for the BM25 and neural benchmarks
Signed-off-by: Samuel Herman <sherman8915@gmail.com>
1 parent 1440b9d commit c074e9d

File tree

11 files changed

+975
-1
lines changed

11 files changed

+975
-1
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,3 +134,6 @@ dmypy.json
134134

135135
# Pyre type checker
136136
.pyre/
137+
138+
# IDE
139+
.idea

beir/retrieval/search/lexical/bm25_search.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def __init__(
2626
maxsize: int = 24,
2727
number_of_shards: int = "default",
2828
initialize: bool = True,
29-
sleep_for: int = 2,
29+
sleep_for: int = 2
3030
):
3131
self.results = {}
3232
self.batch_size = batch_size
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from __future__ import annotations
2+
3+
from .opensearch_search import OpenSearchEngine
4+
5+
__all__ = ["OpenSearchEngine"]
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from __future__ import annotations
2+
3+
from .neural_search import NeuralSearch
4+
5+
__all__ = ["NeuralSearch"]
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
from __future__ import annotations
2+
3+
import time
4+
5+
import tqdm
6+
import logging
7+
8+
from ...base import BaseSearch
9+
from ..opensearch_search import OpenSearchEngine
10+
logger = logging.getLogger("NeuralSearch")
11+
12+
def sleep(seconds):
13+
if seconds:
14+
time.sleep(seconds)
15+
16+
17+
class NeuralSearch(BaseSearch):
18+
def __init__(
19+
self,
20+
index_name: str,
21+
hostname: str = "localhost",
22+
keys: dict[str, str] = {"title": "title", "body": "txt", "embedding": "embedding"},
23+
language: str = "english",
24+
batch_size: int = 128,
25+
timeout: int = 100,
26+
retry_on_timeout: bool = True,
27+
maxsize: int = 24,
28+
number_of_shards: int = "default",
29+
initialize: bool = True,
30+
sleep_for: int = 2
31+
):
32+
self.model_id = None
33+
self.results = {}
34+
self.batch_size = batch_size
35+
self.initialize = initialize
36+
self.sleep_for = sleep_for
37+
self.config = {
38+
"hostname": hostname,
39+
"index_name": index_name,
40+
"keys": keys,
41+
"timeout": timeout,
42+
"retry_on_timeout": retry_on_timeout,
43+
"maxsize": maxsize,
44+
"number_of_shards": number_of_shards,
45+
"language": language,
46+
}
47+
# Initialize OpenSearch engine
48+
self.os_engine = OpenSearchEngine(self.config)
49+
if self.initialize:
50+
self.initialise()
51+
52+
def initialise(self):
53+
"""
54+
Initialise OpenSearch for neural search.
55+
"""
56+
# Setup ML infrastructure
57+
self.os_engine.configure_ml_settings()
58+
# Register model group and get ID
59+
model_group_response = self.os_engine.register_model_group()
60+
model_group_id = model_group_response["model_group_id"]
61+
# Register model using group ID
62+
model_register_response = self.os_engine.register_model(model_group_id=model_group_id)
63+
logger.info(f"Model registration response: {model_register_response}")
64+
self.model_id = self.os_engine.wait_for_model_deployment(task_id=model_register_response["task_id"]) # Use this ID in create_ingest_pipeline
65+
logger.info(f"Model ID: {self.model_id}")
66+
deploy_task_response = self.os_engine.deploy_model(self.model_id)
67+
logger.info(f"Model deployment response: {deploy_task_response}")
68+
self.os_engine.wait_for_model_deployment(task_id=deploy_task_response["task_id"])
69+
# Create pipeline and index
70+
self.os_engine.create_ingest_pipeline(model_id=self.model_id)
71+
self.os_engine.create_neural_search_index()
72+
73+
def search(
74+
self,
75+
corpus: dict[str, dict[str, str]],
76+
queries: dict[str, str],
77+
top_k: int,
78+
*args,
79+
**kwargs,
80+
) -> dict[str, dict[str, float]]:
81+
# Index the corpus within elastic-search
82+
# False, if the corpus has been already indexed
83+
if self.initialize:
84+
self.index(corpus)
85+
# Sleep for few seconds so that elastic-search indexes the docs properly
86+
sleep(self.sleep_for)
87+
88+
# retrieve neural search results from OpenSearch
89+
query_ids = list(queries.keys())
90+
queries = [queries[qid] for qid in query_ids]
91+
92+
for start_idx in tqdm.trange(0, len(queries), self.batch_size, desc="que"):
93+
query_ids_batch = query_ids[start_idx : start_idx + self.batch_size]
94+
results = self.os_engine.neural_multisearch(
95+
texts=queries[start_idx : start_idx + self.batch_size],
96+
model_id=self.model_id,
97+
top_hits=top_k + 1,
98+
) # Add 1 extra if query is present with documents
99+
100+
for query_id, hit in zip(query_ids_batch, results):
101+
scores = {}
102+
for corpus_id, score in hit["hits"]:
103+
if corpus_id != query_id: # query doesnt return in results
104+
scores[corpus_id] = score
105+
self.results[query_id] = scores
106+
107+
return self.results
108+
109+
def index(self, corpus: dict[str, dict[str, str]]):
110+
progress = tqdm.tqdm(unit="docs", total=len(corpus))
111+
# dictionary structure = {_id: {title_key: title, text_key: text}}
112+
dictionary = {
113+
idx: {
114+
self.config["keys"]["title"]: corpus[idx].get("title", None),
115+
self.config["keys"]["body"]: corpus[idx].get("text", None),
116+
}
117+
for idx in list(corpus.keys())
118+
}
119+
self.os_engine.bulk_add_to_index(
120+
generate_actions=self.os_engine.generate_actions(dictionary=dictionary, update=False),
121+
progress=progress,
122+
)
123+
124+
def cleanup(self):
125+
self.os_engine.delete_index()
126+
self.os_engine.delete_ingest_pipeline()
127+
self.os_engine.undeploy_model(self.model_id)
128+
self.os_engine.delete_model(self.model_id)
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from __future__ import annotations
2+
3+
from .bm25_search import BM25Search
4+
5+
__all__ = ["BM25Search"]
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
from __future__ import annotations
2+
3+
import time
4+
5+
import tqdm
6+
7+
from ...base import BaseSearch
8+
from ..opensearch_search import OpenSearchEngine
9+
10+
11+
def sleep(seconds):
12+
if seconds:
13+
time.sleep(seconds)
14+
15+
16+
class BM25Search(BaseSearch):
17+
def __init__(
18+
self,
19+
index_name: str,
20+
hostname: str = "localhost",
21+
keys: dict[str, str] = {"title": "title", "body": "txt"},
22+
language: str = "english",
23+
batch_size: int = 128,
24+
timeout: int = 100,
25+
retry_on_timeout: bool = True,
26+
maxsize: int = 24,
27+
number_of_shards: int = "default",
28+
initialize: bool = True,
29+
sleep_for: int = 2
30+
):
31+
self.results = {}
32+
self.batch_size = batch_size
33+
self.initialize = initialize
34+
self.sleep_for = sleep_for
35+
self.config = {
36+
"hostname": hostname,
37+
"index_name": index_name,
38+
"keys": keys,
39+
"timeout": timeout,
40+
"retry_on_timeout": retry_on_timeout,
41+
"maxsize": maxsize,
42+
"number_of_shards": number_of_shards,
43+
"language": language,
44+
}
45+
self.es = OpenSearchEngine(self.config)
46+
if self.initialize:
47+
self.initialise()
48+
49+
def initialise(self):
50+
self.es.delete_index()
51+
sleep(self.sleep_for)
52+
self.es.create_index()
53+
54+
def search(
55+
self,
56+
corpus: dict[str, dict[str, str]],
57+
queries: dict[str, str],
58+
top_k: int,
59+
*args,
60+
**kwargs,
61+
) -> dict[str, dict[str, float]]:
62+
# Index the corpus within elastic-search
63+
# False, if the corpus has been already indexed
64+
if self.initialize:
65+
self.index(corpus)
66+
# Sleep for few seconds so that elastic-search indexes the docs properly
67+
sleep(self.sleep_for)
68+
69+
# retrieve results from BM25
70+
query_ids = list(queries.keys())
71+
queries = [queries[qid] for qid in query_ids]
72+
73+
for start_idx in tqdm.trange(0, len(queries), self.batch_size, desc="que"):
74+
query_ids_batch = query_ids[start_idx : start_idx + self.batch_size]
75+
results = self.es.lexical_multisearch(
76+
texts=queries[start_idx : start_idx + self.batch_size],
77+
top_hits=top_k + 1,
78+
) # Add 1 extra if query is present with documents
79+
80+
for query_id, hit in zip(query_ids_batch, results):
81+
scores = {}
82+
for corpus_id, score in hit["hits"]:
83+
if corpus_id != query_id: # query doesnt return in results
84+
scores[corpus_id] = score
85+
self.results[query_id] = scores
86+
87+
return self.results
88+
89+
def index(self, corpus: dict[str, dict[str, str]]):
90+
progress = tqdm.tqdm(unit="docs", total=len(corpus))
91+
# dictionary structure = {_id: {title_key: title, text_key: text}}
92+
dictionary = {
93+
idx: {
94+
self.config["keys"]["title"]: corpus[idx].get("title", None),
95+
self.config["keys"]["body"]: corpus[idx].get("text", None),
96+
}
97+
for idx in list(corpus.keys())
98+
}
99+
self.es.bulk_add_to_index(
100+
generate_actions=self.es.generate_actions(dictionary=dictionary, update=False),
101+
progress=progress,
102+
)

0 commit comments

Comments
 (0)