Skip to content

Commit bc178ab

Browse files
authored
optimize the time cost for the wordtag (#863)
* optimize the time cost for the wordtag
* update the optimized code for the predictor
1 parent 73b4d38 commit bc178ab

File tree

2 files changed

+60
-19
lines changed

2 files changed

+60
-19
lines changed

examples/text_to_knowledge/wordtag/predictor.py

Lines changed: 30 additions & 9 deletions
Original file line number · Diff line number · Diff line change
@@ -217,7 +217,6 @@ def _load_term_tree_data(term_tree_name_or_path):
217217

218218
def _split_long_text2short_text_list(self, input_texts, max_text_len):
219219
short_input_texts = []
220-
short_input_texts_lens = []
221220
for text in input_texts:
222221
if len(text) <= max_text_len:
223222
short_input_texts.append(text)
@@ -235,13 +234,35 @@ def _split_long_text2short_text_list(self, input_texts, max_text_len):
235234
]
236235
short_input_texts.extend(temp_text_list)
237236
else:
238-
count = 0
239-
for temp_text in temp_text_list:
240-
if len(temp_text) + count < lens:
241-
temp_text = text[:len(temp_text) + count + 1]
242-
count += len(temp_text)
237+
list_len = len(temp_text_list)
238+
start = 0
239+
end = 0
240+
for i in range(0, list_len):
241+
if len(temp_text_list[i]) + 1 >= max_text_len:
242+
if start != end:
243+
short_input_texts.extend(
244+
self._split_long_text_input(
245+
[text[start:end]], max_text_len))
246+
short_input_texts.extend(
247+
self._split_long_text_input([
248+
text[end:end + len(temp_text_list[i]) + 1]
249+
], max_text_len))
250+
start = end + len(temp_text_list[i]) + 1
251+
end = start
252+
else:
253+
if start + len(temp_text_list[
254+
i]) + 1 > max_text_len:
255+
short_input_texts.extend(
256+
self._split_long_text_input(
257+
[text[start:end]], max_text_len))
258+
start = end
259+
end = end + len(temp_text_list[i]) + 1
260+
else:
261+
end = len(temp_text_list[i]) + 1
262+
if start != end:
243263
short_input_texts.extend(
244-
self._split_long_text2short_text_list([temp_text]))
264+
self._split_long_text_input([text[start:end]],
265+
max_text_len))
245266
return short_input_texts
246267

247268
def _convert_short_text2long_text_result(self, input_texts, results):
@@ -268,7 +289,7 @@ def _convert_short_text2long_text_result(self, input_texts, results):
268289
raise Exception("The len of text must same as raw text.")
269290
return concat_results
270291

271-
def _pre_process_text(self, input_texts, max_seq_len=128, batch_size=1):
292+
def _pre_process_text(self, input_texts, max_seq_len=512, batch_size=1):
272293
infer_data = []
273294
max_predict_len = max_seq_len - self.summary_num - 1
274295
short_input_texts = self._split_long_text2short_text_list(
@@ -341,7 +362,7 @@ def _decode(self, batch_texts, batch_pred_tags):
341362
@paddle.no_grad()
342363
def run(self,
343364
input_texts,
344-
max_seq_len=128,
365+
max_seq_len=512,
345366
batch_size=1,
346367
return_hidden_states=None):
347368
"""Predict a input text by wordtag.

paddlenlp/taskflow/text2knowledge.py

Lines changed: 30 additions & 10 deletions
Original file line number · Diff line number · Diff line change
@@ -261,7 +261,6 @@ def _split_long_text_input(self, input_texts, max_text_len):
261261
If the text length is greater than 512, this function splits the long text.
262262
"""
263263
short_input_texts = []
264-
short_input_texts_lens = []
265264
for text in input_texts:
266265
if len(text) <= max_text_len:
267266
short_input_texts.append(text)
@@ -279,13 +278,35 @@ def _split_long_text_input(self, input_texts, max_text_len):
279278
]
280279
short_input_texts.extend(temp_text_list)
281280
else:
282-
count = 0
283-
for temp_text in temp_text_list:
284-
if len(temp_text) + count < lens:
285-
temp_text = text[:len(temp_text) + count + 1]
286-
count += len(temp_text)
281+
list_len = len(temp_text_list)
282+
start = 0
283+
end = 0
284+
for i in range(0, list_len):
285+
if len(temp_text_list[i]) + 1 >= max_text_len:
286+
if start != end:
287+
short_input_texts.extend(
288+
self._split_long_text_input(
289+
[text[start:end]], max_text_len))
290+
short_input_texts.extend(
291+
self._split_long_text_input([
292+
text[end:end + len(temp_text_list[i]) + 1]
293+
], max_text_len))
294+
start = end + len(temp_text_list[i]) + 1
295+
end = start
296+
else:
297+
if start + len(temp_text_list[
298+
i]) + 1 > max_text_len:
299+
short_input_texts.extend(
300+
self._split_long_text_input(
301+
[text[start:end]], max_text_len))
302+
start = end
303+
end = end + len(temp_text_list[i]) + 1
304+
else:
305+
end = len(temp_text_list[i]) + 1
306+
if start != end:
287307
short_input_texts.extend(
288-
self._split_long_text2short_text_list([temp_text]))
308+
self._split_long_text_input([text[start:end]],
309+
max_text_len))
289310
return short_input_texts
290311

291312
def _concat_short_text_reuslts(self, input_texts, results):
@@ -318,7 +339,6 @@ def _concat_short_text_reuslts(self, input_texts, results):
318339
pred_words = result['items']
319340
pred_words = self._reset_offset(pred_words)
320341
result['items'] = pred_words
321-
322342
return concat_results
323343

324344
def _preprocess_text(self, input_texts):
@@ -333,7 +353,7 @@ def _preprocess_text(self, input_texts):
333353
lazy_load = self.kwargs[
334354
'lazy_load'] if 'lazy_load' in self.kwargs else False
335355

336-
max_seq_length = 128
356+
max_seq_length = 512
337357
if 'max_position_embedding' in self.kwargs:
338358
max_seq_length = self.kwargs['max_position_embedding']
339359
infer_data = []
@@ -533,7 +553,7 @@ def _postprocess(self, inputs):
533553
"""
534554
results = self._decode(inputs['short_input_texts'],
535555
inputs['all_pred_tags'])
536-
resulte = self._concat_short_text_reuslts(inputs['inputs'], results)
556+
results = self._concat_short_text_reuslts(inputs['inputs'], results)
537557
if self.linking is True:
538558
for res in results:
539559
self._term_linking(res)

0 commit comments

Comments
 (0)