
Commit ddb59bf

Add batch prediction for pipelines (#3432)
* Add batch prediction for pipelines
* Fix some hardcoded values & update comments
1 parent ee696c3 commit ddb59bf

9 files changed: +548 -37 lines changed


pipelines/examples/semantic-search/semantic_search_example.py

Lines changed: 16 additions & 0 deletions
```diff
@@ -209,6 +209,22 @@ def semantic_search_tutorial():
                           })

     print_documents(prediction)
+    # Batch prediction
+    predictions = pipe.run_batch(queries=["亚马逊河流的介绍", "期货交易手续费指的是什么?"],
+                                 params={
+                                     "Retriever": {
+                                         "top_k": 50
+                                     },
+                                     "Ranker": {
+                                         "top_k": 5
+                                     }
+                                 })
+    for i in range(len(predictions['queries'])):
+        result = {
+            'documents': predictions['documents'][i],
+            'query': predictions['queries'][i]
+        }
+        print_documents(result)


 if __name__ == "__main__":
```
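Note that `run_batch` returns a single dict whose `queries` and `documents` entries are parallel lists, which is why the loop above pairs them by index. A slightly more idiomatic way to consume that output, assuming the same layout (a sketch, not part of the commit):

```python
# Sketch: iterate the batch output pairwise instead of by index.
# Assumes predictions["queries"][i] corresponds to predictions["documents"][i],
# as in the example above.
for query, docs in zip(predictions["queries"], predictions["documents"]):
    print_documents({"query": query, "documents": docs})
```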

pipelines/pipelines/nodes/base.py

Lines changed: 24 additions & 7 deletions
```diff
@@ -127,16 +127,33 @@ def _dispatch_run(self, **kwargs) -> Tuple[Dict, str]:
         - collate `_debug` information if present
         - merge component output with the preceding output and pass it on to the subsequent Component in the Pipeline
         """
+        return self._dispatch_run_general(self.run, **kwargs)
+
+    def _dispatch_run_batch(self, **kwargs):
+        """
+        The Pipeline calls this method when run_batch() is executed. It in turn
+        calls _dispatch_run_general() with the correct run method.
+        """
+        return self._dispatch_run_general(self.run_batch, **kwargs)
+
+    def _dispatch_run_general(self, run_method: Callable, **kwargs):
+        """
+        This method takes care of the following:
+          - inspect run_method's signature to validate if all necessary arguments are available
+          - pop `debug` and set it on the instance to control debug output
+          - call run_method with the corresponding arguments and gather output
+          - collate `_debug` information if present
+          - merge component output with the preceding output and pass it on to the subsequent Component in the Pipeline
+        """
         arguments = deepcopy(kwargs)
         params = arguments.get("params") or {}

-        run_signature_args = inspect.signature(self.run).parameters.keys()
+        run_signature_args = inspect.signature(run_method).parameters.keys()

         run_params: Dict[str, Any] = {}
         for key, value in params.items():
             if key == self.name:  # targeted params for this node
                 if isinstance(value, dict):
-
                     # Extract debug attributes
                     if "debug" in value.keys():
                         self.debug = value.pop("debug")
@@ -156,19 +173,19 @@ def _dispatch_run(self, **kwargs) -> Tuple[Dict, str]:
             if key in run_signature_args:
                 run_inputs[key] = value

-        output, stream = self.run(**run_inputs, **run_params)
+        output, stream = run_method(**run_inputs, **run_params)

         # Collect debug information
         debug_info = {}
         if getattr(self, "debug", None):
             # Include input
             debug_info["input"] = {**run_inputs, **run_params}
             debug_info["input"]["debug"] = self.debug
-            # Include output
+            # Include output, exclude _debug to avoid recursion
             filtered_output = {
                 key: value
                 for key, value in output.items() if key != "_debug"
-            }  # Exclude _debug to avoid recursion
+            }
             debug_info["output"] = filtered_output
             # Include custom debug info
             custom_debug = output.get("_debug", {})
@@ -182,9 +199,9 @@ def _dispatch_run(self, **kwargs) -> Tuple[Dict, str]:
         if all_debug:
             output["_debug"] = all_debug

-        # add "extra" args that were not used by the node
+        # add "extra" args that were not used by the node, but not the 'inputs' value
         for k, v in arguments.items():
-            if k not in output.keys():
+            if k not in output.keys() and k != "inputs":
                 output[k] = v

         output["params"] = params
```

pipelines/pipelines/nodes/ranker/base.py

Lines changed: 24 additions & 2 deletions
```diff
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import List, Optional
+from typing import List, Optional, Union

 import logging
 from abc import abstractmethod
@@ -48,7 +48,7 @@ def predict_batch(self,
     def run(self,
             query: str,
             documents: List[Document],
-            top_k: Optional[int] = None):  # type: ignore
+            top_k: Optional[int] = None):
         self.query_count += 1
         if documents:
             predict = self.timing(self.predict, "query_time")
@@ -62,6 +62,28 @@ def run(self,

         return output, "output_1"

+    def run_batch(
+        self,
+        queries: List[str],
+        documents: Union[List[Document], List[List[Document]]],
+        top_k: Optional[int] = None,
+        batch_size: Optional[int] = None,
+    ):
+        self.query_count += len(queries)
+        predict_batch = self.timing(self.predict_batch, "query_time")
+        results = predict_batch(queries=queries,
+                                documents=documents,
+                                top_k=top_k,
+                                batch_size=batch_size)
+
+        for doc_list in results:
+            document_ids = [doc.id for doc in doc_list]
+            logger.debug("Ranked documents with IDs: %s", document_ids)
+
+        output = {"documents": results}
+
+        return output, "output_1"
+
     def timing(self, fn, attr_name):
         """Wrapper method used to time functions."""
```
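`run_batch()` follows the same contract as `run()`: it wraps `predict_batch()` in `self.timing(...)` so query latency accumulates on the node, then returns an `(output, stream)` tuple for the pipeline to route. A minimal stand-in for that timing wrapper, assuming it simply accumulates wall-clock time under the given attribute name (a sketch, not the library's implementation):

```python
import time

def timing(fn, obj, attr_name):
    # Sketch of BaseRanker.timing(): run fn and add its wall-clock
    # duration to obj.<attr_name> (e.g. "query_time").
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        try:
            return fn(*args, **kwargs)
        finally:
            elapsed = time.perf_counter() - start
            setattr(obj, attr_name, getattr(obj, attr_name, 0.0) + elapsed)
    return wrapper
```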

pipelines/pipelines/nodes/ranker/ernie_ranker.py

Lines changed: 150 additions & 19 deletions
```diff
@@ -12,9 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import List, Optional, Union
+from typing import List, Optional, Union, Tuple, Iterator
 import logging
 from pathlib import Path
+from tqdm import tqdm

 import paddle
 from paddlenlp.transformers import ErnieCrossEncoder, AutoTokenizer
@@ -44,6 +45,9 @@ def __init__(
         model_name_or_path: Union[str, Path],
         top_k: int = 10,
         use_gpu: bool = True,
+        max_seq_len: int = 256,
+        progress_bar: bool = True,
+        batch_size: int = 1000,
     ):
         """
         :param model_name_or_path: Directory of a saved model or the name of a public model e.g.
@@ -66,26 +70,13 @@ def __init__(
         self.transformer_model = ErnieCrossEncoder(model_name_or_path)
         self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
         self.transformer_model.eval()
+        self.progress_bar = progress_bar
+        self.batch_size = batch_size
+        self.max_seq_len = max_seq_len

         if len(self.devices) > 1:
             self.model = paddle.DataParallel(self.transformer_model)

-    def predict_batch(self,
-                      query_doc_list: List[dict],
-                      top_k: int = None,
-                      batch_size: int = None):
-        """
-        Use loaded Ranker model to, for a list of queries, rank each query's supplied list of Document.
-
-        Returns list of dictionary of query and list of document sorted by (desc.) similarity with query
-
-        :param query_doc_list: List of dictionaries containing queries with their retrieved documents
-        :param top_k: The maximum number of answers to return for each query
-        :param batch_size: Number of samples the model receives in one batch for inference
-        :return: List of dictionaries containing query and ranked list of Document
-        """
-        raise NotImplementedError
-
     def predict(self,
                 query: str,
                 documents: List[Document],
@@ -105,7 +96,7 @@ def predict(self,

         features = self.tokenizer([query for doc in documents],
                                   [doc.content for doc in documents],
-                                  max_seq_len=256,
+                                  max_seq_len=self.max_seq_len,
                                   pad_to_max_seq_len=True,
                                   truncation_strategy="longest_first")

@@ -125,6 +116,146 @@ def predict(self,
             reverse=True,
         )

-        # rank documents according to scores
+        # Rank documents according to scores
         sorted_documents = [doc for _, doc in sorted_scores_and_documents]
         return sorted_documents[:top_k]
+
+    def predict_batch(
+        self,
+        queries: List[str],
+        documents: Union[List[Document], List[List[Document]]],
+        top_k: Optional[int] = None,
+        batch_size: Optional[int] = None,
+    ) -> Union[List[Document], List[List[Document]]]:
+        """
+        Use loaded ranker model to re-rank the supplied lists of Documents.
+
+        Returns lists of Documents sorted by (desc.) similarity with the corresponding queries.
+
+        :param queries: List of query strings.
+        :param documents: Single list of Documents or list of lists of Documents to be reranked.
+        :param top_k: The maximum number of documents to return per Document list.
+        :param batch_size: Number of Documents to process at a time.
+        """
+        if top_k is None:
+            top_k = self.top_k
+
+        if batch_size is None:
+            batch_size = self.batch_size
+
+        number_of_docs, all_queries, all_docs, single_list_of_docs = self._preprocess_batch_queries_and_docs(
+            queries=queries, documents=documents)
+        batches = self._get_batches(all_queries=all_queries,
+                                    all_docs=all_docs,
+                                    batch_size=batch_size)
+        pb = tqdm(total=len(all_docs),
+                  disable=not self.progress_bar,
+                  desc="Ranking")
+
+        preds = []
+        for cur_queries, cur_docs in batches:
+            features = self.tokenizer(cur_queries,
+                                      [doc.content for doc in cur_docs],
+                                      max_seq_len=self.max_seq_len,
+                                      pad_to_max_seq_len=True,
+                                      truncation_strategy="longest_first")
+
+            tensors = {k: paddle.to_tensor(v) for (k, v) in features.items()}
+
+            with paddle.no_grad():
+                similarity_scores = self.transformer_model.matching(
+                    **tensors).numpy()
+                preds.extend(similarity_scores)
+
+            for doc, rank_score in zip(cur_docs, similarity_scores):
+                doc.rank_score = rank_score
+                doc.score = rank_score
+            pb.update(len(cur_docs))
+        pb.close()
+        if single_list_of_docs:
+            sorted_scores_and_documents = sorted(
+                zip(preds, documents),
+                key=lambda similarity_document_tuple: similarity_document_tuple[0],
+                reverse=True,
+            )
+            sorted_documents = [doc for _, doc in sorted_scores_and_documents]
+            return sorted_documents[:top_k]
+        else:
+            grouped_predictions = []
+            left_idx = 0
+            right_idx = 0
+            for number in number_of_docs:
+                right_idx = left_idx + number
+                grouped_predictions.append(preds[left_idx:right_idx])
+                left_idx = right_idx
+            result = []
+            for pred_group, doc_group in zip(grouped_predictions, documents):
+                sorted_scores_and_documents = sorted(
+                    zip(pred_group, doc_group),
+                    key=lambda similarity_document_tuple: similarity_document_tuple[0],
+                    reverse=True,
+                )
+                sorted_documents = [
+                    doc for _, doc in sorted_scores_and_documents
+                ]
+                result.append(sorted_documents[:top_k])
+
+            return result
+
+    def _preprocess_batch_queries_and_docs(
+        self, queries: List[str],
+        documents: Union[List[Document], List[List[Document]]]
+    ) -> Tuple[List[int], List[str], List[Document], bool]:
+        number_of_docs = []
+        all_queries = []
+        all_docs: List[Document] = []
+        single_list_of_docs = False
+
+        # Docs case 1: single list of Documents -> rerank single list of Documents based on single query
+        if len(documents) > 0 and isinstance(documents[0], Document):
+            if len(queries) != 1:
+                raise Exception(
+                    "Number of queries must be 1 if a single list of Documents is provided."
+                )
+            query = queries[0]
+            number_of_docs = [len(documents)]
+            all_queries = [query] * len(documents)
+            all_docs = documents  # type: ignore
+            single_list_of_docs = True
+
+        # Docs case 2: list of lists of Documents -> rerank each list of Documents based on corresponding query
+        # If queries contains a single query, apply it to each list of Documents
+        if len(documents) > 0 and isinstance(documents[0], list):
+            if len(queries) == 1:
+                queries = queries * len(documents)
+            if len(queries) != len(documents):
+                raise Exception(
+                    "Number of queries must be equal to number of provided Document lists."
+                )
+            for query, cur_docs in zip(queries, documents):
+                if not isinstance(cur_docs, list):
+                    raise Exception(
+                        f"cur_docs was of type {type(cur_docs)}, but expected a list of Documents."
+                    )
+                number_of_docs.append(len(cur_docs))
+                all_queries.extend([query] * len(cur_docs))
+                all_docs.extend(cur_docs)
+
+        return number_of_docs, all_queries, all_docs, single_list_of_docs
+
+    @staticmethod
+    def _get_batches(
+        all_queries: List[str], all_docs: List[Document],
+        batch_size: Optional[int]
+    ) -> Iterator[Tuple[List[str], List[Document]]]:
+        if batch_size is None:
+            yield all_queries, all_docs
+            return
+        for index in range(0, len(all_queries), batch_size):
+            yield (all_queries[index:index + batch_size],
+                   all_docs[index:index + batch_size])
```
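`predict_batch` accepts either a single list of Documents (with exactly one query) or one Document list per query; `_preprocess_batch_queries_and_docs` flattens both shapes into parallel `all_queries`/`all_docs` lists plus per-query counts, and `_get_batches` slices those lists into fixed-size windows before the scores are regrouped via `number_of_docs`. A self-contained sketch of that slicing logic, using plain strings in place of `Document` objects (assumed to match the private helper above):

```python
from typing import Iterator, List, Optional, Tuple

def get_batches(all_queries: List[str], all_docs: List[str],
                batch_size: Optional[int]) -> Iterator[Tuple[List[str], List[str]]]:
    # Same idea as ErnieRanker._get_batches: yield parallel slices of the
    # flattened query/document lists, batch_size pairs at a time.
    if batch_size is None:
        yield all_queries, all_docs
        return
    for i in range(0, len(all_queries), batch_size):
        yield all_queries[i:i + batch_size], all_docs[i:i + batch_size]

# Two queries with three documents each, already flattened the way
# _preprocess_batch_queries_and_docs would emit them.
queries = ["q1"] * 3 + ["q2"] * 3
docs = [f"d{i}" for i in range(6)]
for qs, ds in get_batches(queries, docs, batch_size=4):
    print(qs, ds)  # at most 4 query/document pairs per batch
```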
