# See the License for the specific language governing permissions and
# limitations under the License.

- from typing import List, Optional, Union
+ from typing import List, Optional, Union, Tuple, Iterator
import logging
from pathlib import Path
+ from tqdm import tqdm


import paddle
from paddlenlp.transformers import ErnieCrossEncoder, AutoTokenizer
@@ -44,6 +45,9 @@ def __init__(
        model_name_or_path: Union[str, Path],
        top_k: int = 10,
        use_gpu: bool = True,
+       max_seq_len: int = 256,
+       progress_bar: bool = True,
+       batch_size: int = 1000,
    ):
        """
        :param model_name_or_path: Directory of a saved model or the name of a public model e.g.
@@ -66,26 +70,13 @@ def __init__(
        self.transformer_model = ErnieCrossEncoder(model_name_or_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.transformer_model.eval()
+       self.progress_bar = progress_bar
+       self.batch_size = batch_size
+       self.max_seq_len = max_seq_len

        if len(self.devices) > 1:
            self.model = paddle.DataParallel(self.transformer_model)

-    def predict_batch(self,
-                      query_doc_list: List[dict],
-                      top_k: int = None,
-                      batch_size: int = None):
-        """
-        Use loaded Ranker model to, for a list of queries, rank each query's supplied list of Document.
-
-        Returns list of dictionary of query and list of document sorted by (desc.) similarity with query
-
-        :param query_doc_list: List of dictionaries containing queries with their retrieved documents
-        :param top_k: The maximum number of answers to return for each query
-        :param batch_size: Number of samples the model receives in one batch for inference
-        :return: List of dictionaries containing query and ranked list of Document
-        """
-        raise NotImplementedError
-
    def predict(self,
                query: str,
                documents: List[Document],
@@ -105,7 +96,7 @@ def predict(self,

        features = self.tokenizer([query for doc in documents],
                                  [doc.content for doc in documents],
-                                 max_seq_len=256,
+                                 max_seq_len=self.max_seq_len,
                                  pad_to_max_seq_len=True,
                                  truncation_strategy="longest_first")
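A minimal usage sketch of the single-query `predict` path above, which now honors the `max_seq_len` passed to `__init__`. The class name `ErnieRanker`, the `Document` import path, and the model name are illustrative assumptions, not taken from this diff:

    # Hedged sketch: re-rank a few Documents for one query via predict().
    # ErnieRanker, the import path, and the model name are assumptions.
    from pipelines.schema import Document  # assumed import path

    ranker = ErnieRanker(model_name_or_path="rocketqa-zh-dureader-cross-encoder",
                         top_k=3, max_seq_len=256)
    docs = [Document(content="ERNIE is a pretrained language model."),
            Document(content="Paddle is a deep learning framework.")]
    best = ranker.predict(query="What is ERNIE?", documents=docs)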
@@ -125,6 +116,146 @@ def predict(self,
            reverse=True,
        )

-       # rank documents according to scores
+       # Rank documents according to scores
        sorted_documents = [doc for _, doc in sorted_scores_and_documents]
        return sorted_documents[:top_k]
+
+    def predict_batch(
+        self,
+        queries: List[str],
+        documents: Union[List[Document], List[List[Document]]],
+        top_k: Optional[int] = None,
+        batch_size: Optional[int] = None,
+    ) -> Union[List[Document], List[List[Document]]]:
+        """
+        Use the loaded ranker model to re-rank the supplied lists of Documents.
+
+        Returns lists of Documents sorted by (desc.) similarity with the corresponding queries.
+
+        :param queries: Single query string or list of queries.
+        :param documents: Single list of Documents or list of lists of Documents to be re-ranked.
+        :param top_k: The maximum number of documents to return per Document list.
+        :param batch_size: Number of Documents to process at a time.
+        """
+        if top_k is None:
+            top_k = self.top_k
+
+        if batch_size is None:
+            batch_size = self.batch_size
+
+        number_of_docs, all_queries, all_docs, single_list_of_docs = self._preprocess_batch_queries_and_docs(
+            queries=queries, documents=documents)
+        batches = self._get_batches(all_queries=all_queries,
+                                    all_docs=all_docs,
+                                    batch_size=batch_size)
+        pb = tqdm(total=len(all_docs),
+                  disable=not self.progress_bar,
+                  desc="Ranking")
+
+        preds = []
+        for cur_queries, cur_docs in batches:
+            features = self.tokenizer(cur_queries,
+                                      [doc.content for doc in cur_docs],
+                                      max_seq_len=self.max_seq_len,
+                                      pad_to_max_seq_len=True,
+                                      truncation_strategy="longest_first")
+
+            tensors = {k: paddle.to_tensor(v) for (k, v) in features.items()}
+
+            with paddle.no_grad():
+                similarity_scores = self.transformer_model.matching(
+                    **tensors).numpy()
+                preds.extend(similarity_scores)
+
+            for doc, rank_score in zip(cur_docs, similarity_scores):
+                doc.rank_score = rank_score
+                doc.score = rank_score
+            pb.update(len(cur_docs))
+        pb.close()
+        if single_list_of_docs:
+            sorted_scores_and_documents = sorted(
+                zip(preds, documents),
+                key=lambda similarity_document_tuple: similarity_document_tuple[0],
+                reverse=True,
+            )
+            sorted_documents = [doc for _, doc in sorted_scores_and_documents]
+            return sorted_documents[:top_k]
+        else:
+            # Group the flat prediction list back into one score list per query
+            grouped_predictions = []
+            left_idx = 0
+            right_idx = 0
+            for number in number_of_docs:
+                right_idx = left_idx + number
+                grouped_predictions.append(preds[left_idx:right_idx])
+                left_idx = right_idx
+            result = []
+            for pred_group, doc_group in zip(grouped_predictions, documents):
+                sorted_scores_and_documents = sorted(
+                    zip(pred_group, doc_group),
+                    key=lambda similarity_document_tuple: similarity_document_tuple[0],
+                    reverse=True,
+                )
+                sorted_documents = [doc for _, doc in sorted_scores_and_documents]
+                result.append(sorted_documents[:top_k])
+
+            return result
+
+    def _preprocess_batch_queries_and_docs(
+        self, queries: List[str],
+        documents: Union[List[Document], List[List[Document]]]
+    ) -> Tuple[List[int], List[str], List[Document], bool]:
+        number_of_docs = []
+        all_queries = []
+        all_docs: List[Document] = []
+        single_list_of_docs = False
+
+        # Docs case 1: single list of Documents -> rerank single list of Documents based on single query
+        if len(documents) > 0 and isinstance(documents[0], Document):
+            if len(queries) != 1:
+                raise Exception(
+                    "Number of queries must be 1 if a single list of Documents is provided.")
+            query = queries[0]
+            number_of_docs = [len(documents)]
+            all_queries = [query] * len(documents)
+            all_docs = documents  # type: ignore
+            single_list_of_docs = True
+
+        # Docs case 2: list of lists of Documents -> rerank each list of Documents based on corresponding query
+        # If queries contains a single query, apply it to each list of Documents
+        if len(documents) > 0 and isinstance(documents[0], list):
+            if len(queries) == 1:
+                queries = queries * len(documents)
+            if len(queries) != len(documents):
+                raise Exception(
+                    "Number of queries must be equal to number of provided Document lists.")
+            for query, cur_docs in zip(queries, documents):
+                if not isinstance(cur_docs, list):
+                    raise Exception(
+                        f"cur_docs was of type {type(cur_docs)}, but expected a list of Documents.")
+                number_of_docs.append(len(cur_docs))
+                all_queries.extend([query] * len(cur_docs))
+                all_docs.extend(cur_docs)
+
+        return number_of_docs, all_queries, all_docs, single_list_of_docs
+
+    @staticmethod
+    def _get_batches(
+        all_queries: List[str], all_docs: List[Document],
+        batch_size: Optional[int]
+    ) -> Iterator[Tuple[List[str], List[Document]]]:
+        if batch_size is None:
+            yield all_queries, all_docs
+            return
+        else:
+            for index in range(0, len(all_queries), batch_size):
+                yield (all_queries[index:index + batch_size],
+                       all_docs[index:index + batch_size])
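`_get_batches` is a plain generator over two parallel lists; here is the slicing pattern it relies on, shown standalone with illustrative stand-in values:

    # Standalone sketch of the slicing pattern behind _get_batches: walk two
    # parallel lists in aligned chunks of batch_size (here batch_size=2).
    queries = ["q"] * 5
    docs = list(range(5))  # integers as stand-ins for Document objects
    for i in range(0, len(queries), 2):
        print(queries[i:i + 2], docs[i:i + 2])
    # prints: ['q', 'q'] [0, 1]  then  ['q', 'q'] [2, 3]  then  ['q'] [4]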
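Finally, a minimal end-to-end sketch of the new `predict_batch` API, reusing the `ranker` and `Document` from the sketch above (same assumptions). The two calls mirror the "Docs case 1 / Docs case 2" branches of `_preprocess_batch_queries_and_docs`:

    # Hedged sketch of both predict_batch calling conventions.
    docs_a = [Document(content="Cross-encoders score query-document pairs."),
              Document(content="Bi-encoders embed queries and documents separately.")]
    docs_b = [Document(content="ERNIE is a pretrained language model.")]

    # Case 1: one query + a flat list of Documents -> one flat re-ranked list
    flat = ranker.predict_batch(queries=["What is a cross-encoder?"],
                                documents=docs_a)

    # Case 2: one query per Document list -> one re-ranked list per query
    nested = ranker.predict_batch(
        queries=["What is a cross-encoder?", "What is ERNIE?"],
        documents=[docs_a, docs_b],
        batch_size=32,
    )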