import paddle
import paddle.nn as nn
import paddle.nn.functional as F
- from ..datasets import MapDataset
+ from ..datasets import load_dataset, MapDataset
from ..data import Stack, Pad, Tuple, Vocab, JiebaTokenizer
- from .utils import download_file
+ from .utils import download_file, add_docstrings
from .model import BoWModel, LSTMModel
from .task import Task

    ]
}
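+ # The usage examples below are attached to the task instance (see
+ # SentaTask.__init__), presumably so TaskFlow can surface them as help text.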
+ usage = r"""
+ from paddlenlp.taskflow import TaskFlow
+
+ task = TaskFlow("sentiment_analysis")
+ task("怀着十分激动的心情放映,可是看着看着发现,在放映完毕后,出现一集米老鼠的动画片")
+ '''
+ [{'text': '怀着十分激动的心情放映,可是看着看着发现,在放映完毕后,出现一集米老鼠的动画片', 'label': 'negative'}]
+ '''
+
+ task = TaskFlow("sentiment_analysis", network="lstm")
+ task("作为老的四星酒店,房间依然很整洁,相当不错。机场接机服务很好,可以在车上办理入住手续,节省时间。")
+ '''
+ [{'text': '作为老的四星酒店,房间依然很整洁,相当不错。机场接机服务很好,可以在车上办理入住手续,节省时间。', 'label': 'positive'}]
+ '''
+
+ task = TaskFlow("sentiment_analysis", lazy_load=True)
+ task("作为老的四星酒店,房间依然很整洁,相当不错。机场接机服务很好,可以在车上办理入住手续,节省时间。")
+ '''
+ [{'text': '作为老的四星酒店,房间依然很整洁,相当不错。机场接机服务很好,可以在车上办理入住手续,节省时间。', 'label': 'positive'}]
+ '''
+
+ task = TaskFlow("sentiment_analysis", batch_size=2)
+ task(["作为老的四星酒店,房间依然很整洁,相当不错。机场接机服务很好,可以在车上办理入住手续,节省时间。",
+       "怀着十分激动的心情放映,可是看着看着发现,在放映完毕后,出现一集米老鼠的动画片",
+       "这个宾馆比较陈旧了,特价的房间也很一般。总体来说一般",
+       "2001年来福州就住在这里,这次感觉房间就了点,温泉水还是有的.总的来说很满意.早餐简单了些."])
+ '''
+ [{'text': '作为老的四星酒店,房间依然很整洁,相当不错。机场接机服务很好,可以在车上办理入住手续,节省时间。', 'label': 'positive'}, {'text': '怀着十分激动的心情放映,可是看着看着发现,在放映完毕后,出现一集米老鼠的动画片', 'label': 'negative'}, {'text': '这个宾馆比较陈旧了,特价的房间也很一般。总体来说一般', 'label': 'negative'}, {'text': '2001年来福州就住在这里,这次感觉房间就了点,温泉水还是有的.总的来说很满意.早餐简单了些.', 'label': 'positive'}]
+ '''
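+
+ # Assumed, not shown in the original examples: num_workers is read from the
+ # task kwargs the same way batch_size is, enabling parallel data loading.
+ task = TaskFlow("sentiment_analysis", batch_size=2, num_workers=2)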
+ """
+

class SentaTask(Task):
-     """The one task of sentiment_analysis which use the RNN or Bow model to analysis the input text.
+     """
+     Sentiment analysis task that uses an RNN or BoW model to predict the
+     sentiment polarity of Chinese text.
+     Args:
+         task (string): The name of the task.
+         model (string): The name of the model within the task.
+         kwargs (dict, optional): Additional keyword arguments passed along to the specific task.
    """

    def __init__(self, task, model, **kwargs):
        super().__init__(task=task, model=model, **kwargs)
        self._tokenizer = self._construct_tokenizer(model)
        self._model_instance = self._construct_model(model)
        self._label_map = {0: 'negative', 1: 'positive'}
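+         # Expose the module-level usage examples on the instance; presumably
+         # consumed by the Task base class when TaskFlow prints help text.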
+         self._usage = usage

    def _construct_model(self, model):
-         """Construct the inference model for the predictor.
+         """
+         Construct the inference model for the predictor.
        """
        vocab_size = self.kwargs['vocab_size']
        pad_token_id = self.kwargs['pad_token_id']
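+         # Assumption: vocab_size and pad_token_id are filled into self.kwargs
+         # by _construct_tokenizer, which __init__ runs before this method.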
@@ -90,7 +128,8 @@ def _construct_model(self, model):
        return model

    def _construct_tokenizer(self, model):
-         """Construct the tokenizer for the predictor.
+         """
+         Construct the tokenizer for the predictor.
        """
        full_name = download_file(self.model, "senta_word_dict.txt",
                                  URLS['senta_vocab'][0],
@@ -119,21 +158,26 @@ def _preprocess(self, inputs, padding=True, add_special_tokens=True):
            raise TypeError(
                f"Invalid inputs, input text should be str or list of str, {type(inputs)} found!"
            )
+         # Get the config from the kwargs.
+         batch_size = self.kwargs['batch_size'] if 'batch_size' in self.kwargs else 1
+         num_workers = self.kwargs['num_workers'] if 'num_workers' in self.kwargs else 0
+         lazy_load = self.kwargs['lazy_load'] if 'lazy_load' in self.kwargs else False
-         infer_data = []
-         for i in range(0, len(inputs)):
-             ids = self._tokenizer.encode(inputs[i])
-             lens = len(ids)
-             infer_data.append([ids, lens])
-         infer_ds = MapDataset(infer_data)
+
+         def read(inputs):
+             # Tokenize one sample at a time so the dataset can be consumed lazily.
+             for input_data in inputs:
+                 ids = self._tokenizer.encode(input_data)
+                 lens = len(ids)
+                 yield ids, lens
+
+         infer_ds = load_dataset(read, inputs=inputs, lazy=lazy_load)
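+         # NOTE: with lazy=True, load_dataset wraps `read` in a streaming
+         # dataset that tokenizes on demand; with lazy=False it materializes
+         # all samples up front, matching the previous MapDataset behavior.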
        batchify_fn = lambda samples, fn=Tuple(
            Pad(axis=0, pad_val=self._tokenizer.vocab.token_to_idx.get('[PAD]', 0)),  # input_ids
            Stack(dtype='int64'),  # seq_len
        ): fn(samples)
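+         # batchify_fn above pads input_ids to the longest sequence in each
+         # mini-batch (using the vocab's [PAD] index) and stacks the sequence
+         # lengths into an int64 tensor.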
-
-         batch_size = self.kwargs[
-             'batch_size'] if 'batch_size' in self.kwargs else 1
-         num_workers = self.kwargs[
-             'num_workers'] if 'num_workers' in self.kwargs else 0
        infer_data_loader = paddle.io.DataLoader(
            infer_ds,
            collate_fn=batchify_fn,
@@ -147,7 +191,8 @@ def _preprocess(self, inputs, padding=True, add_special_tokens=True):
        return outputs

    def _run_model(self, inputs):
-         """Run the task model from the outputs of the `_tokenize` function.
+         """
+         Run the task model from the outputs of the `_tokenize` function.
        """
        results = []
        with paddle.no_grad():
@@ -163,7 +208,8 @@ def _run_model(self, inputs):
        return inputs

    def _postprocess(self, inputs):
-         """The model output is allways the logits and pros, this function will convert the model output to raw text.
+         """
+         The model output is always logits and probabilities; this function
+         converts it into the final text-and-label results.
        """
        final_results = []
        for text, label in zip(inputs['text'], inputs['result']):