Commit 4d06421

simpify codes and fix dtype bug in win (#1469)
1 parent b4542fa commit 4d06421

File tree

1 file changed: +3 −46 lines

  • examples/sentiment_analysis/skep/deploy/python
examples/sentiment_analysis/skep/deploy/python/predict.py

Lines changed: 3 additions & 46 deletions
@@ -41,54 +41,11 @@ def convert_example(example,
                     label_list,
                     max_seq_length=512,
                     is_test=False):
-    """
-    Builds model inputs from a sequence or a pair of sequence for sequence classification tasks
-    by concatenating and adding special tokens. And creates a mask from the two sequences passed
-    to be used in a sequence-pair classification task.
-
-    A BERT sequence has the following format:
-
-    - single sequence: ``[CLS] X [SEP]``
-    - pair of sequences: ``[CLS] A [SEP] B [SEP]``
-
-    A BERT sequence pair mask has the following format:
-    ::
-        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
-        | first sequence    | second sequence |
-
-    If only one sequence, only returns the first portion of the mask (0's).
-
-    Args:
-        example(obj:`list[str]`): List of input data, containing text and label if it have label.
-        tokenizer(obj:`PretrainedTokenizer`): This tokenizer inherits from :class:`~paddlenlp.transformers.PretrainedTokenizer`
-            which contains most of the methods. Users should refer to the superclass for more information regarding methods.
-        label_list(obj:`list[str]`): All the labels that the data has.
-        max_seq_len(obj:`int`): The maximum total input sequence length after tokenization.
-            Sequences longer than this will be truncated, sequences shorter will be padded.
-        is_test(obj:`False`, defaults to `False`): Whether the example contains label or not.
-
-    Returns:
-        input_ids(obj:`list[int]`): The list of token ids.
-        segment_ids(obj: `list[int]`): List of sequence pair mask.
-        label(obj:`numpy.array`, data type of int64, optional): The input label if not is_test.
-    """
     text = example
     encoded_inputs = tokenizer(text=text, max_seq_len=max_seq_length)
-    input_ids = encoded_inputs["input_ids"]
-    segment_ids = encoded_inputs["token_type_ids"]
-
-    if not is_test:
-        # create label maps
-        label_map = {}
-        for (i, l) in enumerate(label_list):
-            label_map[l] = i
-
-        label = label_map[label]
-        label = np.array([label], dtype="int64")
-        return input_ids, segment_ids, label
-    else:
-        return input_ids, segment_ids
+    input_ids = np.array(tokenized_input['input_ids'], dtype="int64")
+    token_type_ids = np.array(tokenized_input['token_type_ids'], dtype="int64")
+    return input_ids, token_type_ids


 class Predictor(object):
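Why the added lines pin `dtype="int64"`: NumPy's default integer type follows the platform's C `long`, which is 32-bit on Windows but 64-bit on most Linux/macOS builds, so an implicit `np.array(...)` over token ids yields different dtypes on different systems. That is presumably the "dtype bug in win" the commit title refers to. A minimal sketch (the token id values below are made up for illustration):

```python
import numpy as np

# Hypothetical token ids as plain Python ints.
ids = [101, 2023, 102]

# Without an explicit dtype, NumPy uses the platform default integer type:
# int32 on Windows (C long is 32-bit there), int64 on most Linux/macOS builds.
implicit = np.array(ids)

# Pinning dtype="int64", as the commit does, makes the array dtype
# identical on every platform.
explicit = np.array(ids, dtype="int64")

print(explicit.dtype)
```

An inference engine that was exported expecting `int64` feeds will reject or misread `int32` input buffers, so forcing the dtype at conversion time is the portable choice.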
