
Commit f61a294

update some functions for the taskflow (#726)
* update the function for the taskflow
* update the examples for the taskflow
* update the Text2Knowledge to WordTag
* fix the usage in taskflow
* update the usage for the taskflow
1 parent 7a53214 commit f61a294

File tree

6 files changed (+178, -48 lines)

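Taken together, the changes below add user-facing knobs and helpers to the taskflow API. The following is a minimal sketch of how they are meant to be used, pieced together from the new usage string and methods in this diff; it is an illustration, not the project's official documentation, and the predicted labels depend on the downloaded model.

from paddlenlp.taskflow import TaskFlow

# batch_size, lazy_load and network are the knobs shown in the new usage examples.
senta = TaskFlow("sentiment_analysis", batch_size=2)
results = senta(["作为老的四星酒店,房间依然很整洁,相当不错。机场接机服务很好,可以在车上办理入住手续,节省时间。",
                 "这个宾馆比较陈旧了,特价的房间也很一般。总体来说一般"])
# Each element is a dict of the form {'text': ..., 'label': 'positive' or 'negative'}.
print(results)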

paddlenlp/taskflow/sentiment_analysis.py

Lines changed: 63 additions & 17 deletions
@@ -23,9 +23,9 @@
 import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
-from ..datasets import MapDataset
+from ..datasets import load_dataset, MapDataset
 from ..data import Stack, Pad, Tuple, Vocab, JiebaTokenizer
-from .utils import download_file
+from .utils import download_file, add_docstrings
 from .model import BoWModel, LSTMModel
 from .task import Task

@@ -42,19 +42,57 @@
     ]
 }

+usage = r"""
+from paddlenlp.taskflow import TaskFlow
+
+task = TaskFlow("sentiment_analysis")
+task("怀着十分激动的心情放映,可是看着看着发现,在放映完毕后,出现一集米老鼠的动画片")
+'''
+[{'text': '怀着十分激动的心情放映,可是看着看着发现,在放映完毕后,出现一集米老鼠的动画片', 'label': 'positive'}]
+'''
+
+task = TaskFlow("sentiment_analysis", network="lstm")
+task("作为老的四星酒店,房间依然很整洁,相当不错。机场接机服务很好,可以在车上办理入住手续,节省时间。")
+'''
+[{'text': '作为老的四星酒店,房间依然很整洁,相当不错。机场接机服务很好,可以在车上办理入住手续,节省时间。', 'label': 'positive'}]
+'''
+
+task = TaskFlow("sentiment_analysis", lazy_load="True")
+task("作为老的四星酒店,房间依然很整洁,相当不错。机场接机服务很好,可以在车上办理入住手续,节省时间。")
+'''
+[{'text': '作为老的四星酒店,房间依然很整洁,相当不错。机场接机服务很好,可以在车上办理入住手续,节省时间。', 'label': 'positive'}]
+'''
+
+task = TaskFlow("sentiment_analysis", batch_size=2)
+task(["作为老的四星酒店,房间依然很整洁,相当不错。机场接机服务很好,可以在车上办理入住手续,节省时间。",
+      "怀着十分激动的心情放映,可是看着看着发现,在放映完毕后,出现一集米老鼠的动画片",
+      "这个宾馆比较陈旧了,特价的房间也很一般。总体来说一般",
+      "2001年来福州就住在这里,这次感觉房间就了点,温泉水还是有的.总的来说很满意.早餐简单了些."])
+'''
+[{'text': '作为老的四星酒店,房间依然很整洁,相当不错。机场接机服务很好,可以在车上办理入住手续,节省时间。', 'label': 'positive'}, {'text': '怀着十分激动的心情放映,可是看着看着发现,在放映完毕后,出现一集米老鼠的动画片', 'label': 'negative'}, {'text': '这个宾馆比较陈旧了,特价的房间也很一般。总体来说一般', 'label': 'negative'}, {'text': '2001年来福州就住在这里,这次感觉房间就了点,温泉水还是有的.总的来说很满意.早餐简单了些.', 'label': 'positive'}]
+'''
+"""
+

 class SentaTask(Task):
-    """The one task of sentiment_analysis which use the RNN or Bow model to analysis the input text.
+    """
+    Sentiment analysis task using RNN or BOW model to predict sentiment opinion on Chinese text.
+    Args:
+        task(string): The name of task.
+        model(string): The model name in the task.
+        kwargs (dict, optional): Additional keyword arguments passed along to the specific task.
     """

     def __init__(self, task, model, **kwargs):
         super().__init__(task=task, model=model, **kwargs)
         self._tokenizer = self._construct_tokenizer(model)
         self._model_instance = self._construct_model(model)
         self._label_map = {0: 'negative', 1: 'positive'}
+        self._usage = usage

     def _construct_model(self, model):
-        """Construct the inference model for the predictor.
+        """
+        Construct the inference model for the predictor.
         """
         vocab_size = self.kwargs['vocab_size']
         pad_token_id = self.kwargs['pad_token_id']
@@ -90,7 +128,8 @@ def _construct_model(self, model):
         return model

     def _construct_tokenizer(self, model):
-        """Construct the tokenizer for the predictor.
+        """
+        Construct the tokenizer for the predictor.
         """
         full_name = download_file(self.model, "senta_word_dict.txt",
                                   URLS['senta_vocab'][0],
@@ -119,21 +158,26 @@ def _preprocess(self, inputs, padding=True, add_special_tokens=True):
             raise TypeError(
                 "Invalid inputs, input text should be str or list of str, {type(inputs)} found!"
             )
+        # Get the config from the kwargs
+        batch_size = self.kwargs[
+            'batch_size'] if 'batch_size' in self.kwargs else 1
+        num_workers = self.kwargs[
+            'num_workers'] if 'num_workers' in self.kwargs else 0
+        lazy_load = self.kwargs[
+            'lazy_load'] if 'lazy_load' in self.kwargs else False
         infer_data = []
-        for i in range(0, len(inputs)):
-            ids = self._tokenizer.encode(inputs[i])
-            lens = len(ids)
-            infer_data.append([ids, lens])
-        infer_ds = MapDataset(infer_data)
+
+        def read(inputs):
+            for input_data in inputs:
+                ids = self._tokenizer.encode(input_data)
+                lens = len(ids)
+                yield ids, lens
+
+        infer_ds = load_dataset(read, inputs=inputs, lazy=lazy_load)
         batchify_fn = lambda samples, fn=Tuple(
             Pad(axis=0, pad_val=self._tokenizer.vocab.token_to_idx.get('[PAD]', 0)),  # input_ids
             Stack(dtype='int64'),  # seq_len
         ): fn(samples)
-
-        batch_size = self.kwargs[
-            'batch_size'] if 'batch_size' in self.kwargs else 1
-        num_workers = self.kwargs[
-            'num_workers'] if 'num_workers' in self.kwargs else 0
         infer_data_loader = paddle.io.DataLoader(
             infer_ds,
             collate_fn=batchify_fn,
@@ -147,7 +191,8 @@ def _preprocess(self, inputs, padding=True, add_special_tokens=True):
         return outputs

     def _run_model(self, inputs):
-        """Run the task model from the outputs of the `_tokenize` function.
+        """
+        Run the task model from the outputs of the `_tokenize` function.
         """
         results = []
         with paddle.no_grad():
@@ -163,7 +208,8 @@ def _run_model(self, inputs):
         return inputs

     def _postprocess(self, inputs):
-        """The model output is allways the logits and pros, this function will convert the model output to raw text.
+        """
+        The model output is allways the logits and pros, this function will convert the model output to raw text.
         """
         final_results = []
         for text, label in zip(inputs['text'], inputs['result']):
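The main behavioural change in this file is that _preprocess now builds the inference dataset from a generator via load_dataset(..., lazy=lazy_load) instead of materializing a MapDataset up front. Below is a self-contained sketch of that pattern; the toy encode() function is a stand-in for self._tokenizer.encode and is an assumption for illustration only, not part of the diff.

from paddlenlp.datasets import load_dataset

def encode(text):
    # Toy stand-in for self._tokenizer.encode: map each character to an integer id.
    return [ord(ch) for ch in text]

def read(inputs):
    # Yield one (ids, seq_len) example at a time instead of building a full list in memory.
    for input_data in inputs:
        ids = encode(input_data)
        yield ids, len(ids)

texts = ["这个宾馆比较陈旧了", "房间依然很整洁"]
# lazy=True returns an iterable dataset that is consumed on the fly;
# lazy=False keeps the eager, MapDataset-like behaviour of the old code path.
infer_ds = load_dataset(read, inputs=texts, lazy=True)
# As in the diff, infer_ds is then handed to paddle.io.DataLoader together with a batchify_fn.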

paddlenlp/taskflow/task.py

Lines changed: 9 additions & 1 deletion
@@ -18,14 +18,16 @@
 

 class Task(metaclass=abc.ABCMeta):
-    """ The meta classs of task in TaskFlow. The meta class has the five abstract function,
+    """
+    The meta classs of task in TaskFlow. The meta class has the five abstract function,
     the subclass need to inherit from the meta class.
     """

     def __init__(self, model, task, **kwargs):
         self.model = model
         self.task = task
         self.kwargs = kwargs
+        self._usage = ""

     @abstractmethod
     def _construct_model(self, model):
@@ -59,6 +61,12 @@ def _postprocess(self, inputs):
         The model output is allways the logits and pros, this function will convert the model output to raw text.
         """

+    def help(self):
+        """
+        Return the usage message of the current task.
+        """
+        print("Examples:\n{}".format(self._usage))
+
     def __call__(self, *args):
         inputs = self._preprocess(*args)
         outputs = self._run_model(inputs)
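The new help() relies on a simple contract: the base class initializes self._usage to an empty string and each subclass overwrites it with its own examples (as SentaTask now does with the module-level usage string). A tiny stand-alone mock of that contract is sketched below; FakeTask and FakeSentaTask are illustrative names, not the real paddlenlp classes, and only the help-related part is mirrored.

class FakeTask:
    # Mirrors only the help-related part of paddlenlp.taskflow.task.Task.
    def __init__(self):
        self._usage = ""

    def help(self):
        print("Examples:\n{}".format(self._usage))


class FakeSentaTask(FakeTask):
    def __init__(self):
        super().__init__()
        # A concrete task overrides _usage with its own examples.
        self._usage = 'task = TaskFlow("sentiment_analysis")\ntask("...")'


FakeSentaTask().help()   # prints the examples of the concrete task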

paddlenlp/taskflow/taskflow.py

Lines changed: 20 additions & 4 deletions
@@ -17,7 +17,7 @@
 import paddle
 from ..utils.tools import get_env_device
 from ..transformers import ErnieCtmWordtagModel, ErnieCtmTokenizer
-from .text2knowledge import Text2KnowledgeTask
+from .text2knowledge import WordTagTask
 from .sentiment_analysis import SentaTask

 warnings.simplefilter(action='ignore', category=Warning, lineno=0, append=False)
@@ -26,7 +26,7 @@
     "text2knowledge": {
         "models": {
             "wordtag": {
-                "task_class": Text2KnowledgeTask,
+                "task_class": WordTagTask,
             }
         },
         "default": {
@@ -51,7 +51,7 @@ class TaskFlow(object):
     The TaskFlow is the end2end inferface that could convert the raw text to model result, and decode the model result to task result. The main functions as follows:
     1) Convert the raw text to task result.
     2) Convert the model to the inference model.
-    3) Offer the usesage and help message.
+    3) Offer the usage and help message.
     Args:
         task (str): The task name for the TaskFlow, and get the task class from the name.
         model (str, optional): The model name in the task, if set None, will use the default model.
@@ -84,10 +84,26 @@ def __init__(self, task, model=None, device_id=0, **kwargs):
         task_class = TASKS[self.task]['models'][self.model]['task_class']
         self.task_instance = task_class(
             model=self.model, task=self.task, **self.kwargs)
+        task_list = TASKS.keys()
+        TaskFlow.task_list = task_list

     def __call__(self, *inputs):
+        """
+        The main work function in the taskflow.
+        """
         results = self.task_instance(inputs)
         return results

     def help(self):
-        pass
+        """
+        Return the task usage message.
+        """
+        return self.task_instance.help()
+
+    @staticmethod
+    def tasks():
+        """
+        Return the available task list.
+        """
+        task_list = list(TASKS.keys())
+        return task_list
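At the TaskFlow level the additions surface as a static task listing and a help() that simply delegates to the wrapped task instance. A short usage sketch follows; the printed task names depend on what is registered in TASKS, so the list shown in the comment is only an example.

from paddlenlp.taskflow import TaskFlow

print(TaskFlow.tasks())      # list(TASKS.keys()), e.g. ['text2knowledge', 'sentiment_analysis']

task = TaskFlow("sentiment_analysis")
task.help()                  # delegates to the underlying task's help(), which prints the usage examples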
