39
39
warnings .simplefilter (action = 'ignore' , category = Warning , lineno = 0 , append = False )
40
40
41
41
TASKS = {
42
+ 'dependency_parsing' : {
43
+ "models" : {
44
+ "ddparser" : {
45
+ "task_class" : DDParserTask ,
46
+ "task_flag" : 'dependency_parsing-biaffine' ,
47
+ },
48
+ "ddparser-ernie-1.0" : {
49
+ "task_class" : DDParserTask ,
50
+ "task_flag" : 'dependency_parsing-ernie-1.0' ,
51
+ },
52
+ "ddparser-ernie-gram-zh" : {
53
+ "task_class" : DDParserTask ,
54
+ "task_flag" : 'dependency_parsing-ernie-gram-zh' ,
55
+ },
56
+ },
57
+ "default" : {
58
+ "model" : "ddparser" ,
59
+ }
60
+ },
61
+ 'dialogue' : {
62
+ "models" : {
63
+ "plato-mini" : {
64
+ "task_class" : DialogueTask ,
65
+ "task_flag" : "dialogue-plato-mini"
66
+ },
67
+ },
68
+ "default" : {
69
+ "model" : "plato-mini" ,
70
+ }
71
+ },
42
72
"knowledge_mining" : {
43
73
"models" : {
44
74
"wordtag" : {
45
75
"task_class" : WordTagTask ,
46
76
"task_flag" : 'knowledge_mining-wordtag' ,
77
+ "task_priority_path" : "wordtag" ,
47
78
},
48
79
"nptag" : {
49
80
"task_class" : NPTagTask ,
50
81
"task_flag" : 'knowledge_mining-nptag' ,
51
82
},
52
83
},
53
84
"default" : {
54
- "model" : "wordtag"
85
+ "model" : "wordtag" ,
86
+ }
87
+ },
88
+ "lexical_analysis" : {
89
+ "models" : {
90
+ "lac" : {
91
+ "task_class" : LacTask ,
92
+ "hidden_size" : 128 ,
93
+ "emb_dim" : 128 ,
94
+ "task_flag" : 'lexical_analysis-gru_crf' ,
95
+ "task_priority_path" : "lac" ,
96
+ }
97
+ },
98
+ "default" : {
99
+ "model" : "lac"
55
100
}
56
101
},
57
102
"ner" : {
58
103
"modes" : {
59
104
"accurate" : {
60
105
"task_class" : NERWordTagTask ,
61
106
"task_flag" : "ner-wordtag" ,
107
+ "task_priority_path" : "wordtag" ,
62
108
"linking" : False ,
63
109
},
64
110
"fast" : {
65
111
"task_class" : NERLACTask ,
66
112
"hidden_size" : 128 ,
67
113
"emb_dim" : 128 ,
68
114
"task_flag" : "ner-lac" ,
115
+ "task_priority_path" : "lac" ,
69
116
}
70
117
},
71
118
"default" : {
77
124
"gpt-cpm-large-cn" : {
78
125
"task_class" : PoetryGenerationTask ,
79
126
"task_flag" : 'poetry_generation-gpt-cpm-large-cn' ,
127
+ "task_priority_path" : "gpt-cpm-large-cn" ,
80
128
},
81
129
},
82
130
"default" : {
83
131
"model" : "gpt-cpm-large-cn" ,
84
132
}
85
133
},
86
- "question_answering" : {
87
- "models" : {
88
- "gpt-cpm-large-cn" : {
89
- "task_class" : QuestionAnsweringTask ,
90
- "task_flag" : 'question_answering-gpt-cpm-large-cn' ,
91
- },
92
- },
93
- "default" : {
94
- "model" : "gpt-cpm-large-cn" ,
95
- }
96
- },
97
- "lexical_analysis" : {
134
+ "pos_tagging" : {
98
135
"models" : {
99
136
"lac" : {
100
- "task_class" : LacTask ,
137
+ "task_class" : POSTaggingTask ,
101
138
"hidden_size" : 128 ,
102
139
"emb_dim" : 128 ,
103
- "task_flag" : 'lexical_analysis-gru_crf' ,
140
+ "task_flag" : 'pos_tagging-gru_crf' ,
141
+ "task_priority_path" : "lac" ,
104
142
}
105
143
},
106
144
"default" : {
107
145
"model" : "lac"
108
146
}
109
147
},
110
- "word_segmentation" : {
111
- "modes" : {
112
- "fast" : {
113
- "task_class" : SegJiebaTask ,
114
- "task_flag" : "word_segmentation-jieba" ,
115
- },
116
- "base" : {
117
- "task_class" : SegLACTask ,
118
- "hidden_size" : 128 ,
119
- "emb_dim" : 128 ,
120
- "task_flag" : "word_segmentation-gru_crf" ,
121
- },
122
- "accurate" : {
123
- "task_class" : SegWordTagTask ,
124
- "task_flag" : "word_segmentation-wordtag" ,
125
- "linking" : False ,
126
- },
127
- },
128
- "default" : {
129
- "mode" : "base"
130
- }
131
- },
132
- "pos_tagging" : {
148
+ "question_answering" : {
133
149
"models" : {
134
- "lac" : {
135
- "task_class" : POSTaggingTask ,
136
- "hidden_size" : 128 ,
137
- "emb_dim" : 128 ,
138
- "task_flag" : 'pos_tagging-gru_crf' ,
139
- }
150
+ "gpt-cpm-large-cn" : {
151
+ "task_class" : QuestionAnsweringTask ,
152
+ "task_flag" : 'question_answering-gpt-cpm-large-cn' ,
153
+ "task_priority_path" : "gpt-cpm-large-cn" ,
154
+ },
140
155
},
141
156
"default" : {
142
- "model" : "lac"
157
+ "model" : "gpt-cpm-large-cn" ,
143
158
}
144
159
},
145
160
'sentiment_analysis' : {
157
172
"model" : "bilstm"
158
173
}
159
174
},
160
- 'dependency_parsing' : {
161
- "models" : {
162
- "ddparser" : {
163
- "task_class" : DDParserTask ,
164
- "task_flag" : 'dependency_parsing-biaffine' ,
165
- },
166
- "ddparser-ernie-1.0" : {
167
- "task_class" : DDParserTask ,
168
- "task_flag" : 'dependency_parsing-ernie-1.0' ,
169
- },
170
- "ddparser-ernie-gram-zh" : {
171
- "task_class" : DDParserTask ,
172
- "task_flag" : 'dependency_parsing-ernie-gram-zh' ,
173
- },
174
- },
175
- "default" : {
176
- "model" : "ddparser"
177
- }
178
- },
179
175
'text_correction' : {
180
176
"models" : {
181
- "csc-ernie-1.0" : {
177
+ "ernie-csc" : {
182
178
"task_class" : CSCTask ,
183
- "task_flag" : "text_correction-csc-ernie-1.0"
179
+ "task_flag" : "text_correction-ernie-csc"
184
180
},
185
181
},
186
182
"default" : {
187
- "model" : "csc-ernie-1.0"
183
+ "model" : "ernie-csc"
188
184
}
189
185
},
190
186
'text_similarity' : {
198
194
"model" : "simbert-base-chinese"
199
195
}
200
196
},
201
- 'dialogue' : {
202
- "models" : {
203
- "plato-mini" : {
204
- "task_class" : DialogueTask ,
205
- "task_flag" : "dialogue-plato-mini"
197
+ "word_segmentation" : {
198
+ "modes" : {
199
+ "fast" : {
200
+ "task_class" : SegJiebaTask ,
201
+ "task_flag" : "word_segmentation-jieba" ,
202
+ },
203
+ "base" : {
204
+ "task_class" : SegLACTask ,
205
+ "hidden_size" : 128 ,
206
+ "emb_dim" : 128 ,
207
+ "task_flag" : "word_segmentation-gru_crf" ,
208
+ "task_priority_path" : "lac" ,
209
+ },
210
+ "accurate" : {
211
+ "task_class" : SegWordTagTask ,
212
+ "task_flag" : "word_segmentation-wordtag" ,
213
+ "task_priority_path" : "wordtag" ,
214
+ "linking" : False ,
206
215
},
207
216
},
208
217
"default" : {
209
- "model" : "plato-mini"
218
+ "mode" : "base"
210
219
}
211
220
},
212
221
}
@@ -247,6 +256,13 @@ def __init__(self, task, model=None, mode=None, device_id=0, **kwargs):
247
256
)), "The {} name:{} is not in task:[{}]" .format (tag , model , task )
248
257
else :
249
258
self .model = TASKS [task ]['default' ][ind_tag ]
259
+
260
+ if "task_priority_path" in TASKS [self .task ][tag ][self .model ]:
261
+ self .priority_path = TASKS [self .task ][tag ][self .model ][
262
+ "task_priority_path" ]
263
+ else :
264
+ self .priority_path = None
265
+
250
266
# Set the device for the task
251
267
device = get_env_device ()
252
268
if device == 'cpu' or device_id == - 1 :
@@ -261,7 +277,10 @@ def __init__(self, task, model=None, mode=None, device_id=0, **kwargs):
261
277
self .kwargs = kwargs
262
278
task_class = TASKS [self .task ][tag ][self .model ]['task_class' ]
263
279
self .task_instance = task_class (
264
- model = self .model , task = self .task , ** self .kwargs )
280
+ model = self .model ,
281
+ task = self .task ,
282
+ priority_path = self .priority_path ,
283
+ ** self .kwargs )
265
284
task_list = TASKS .keys ()
266
285
Taskflow .task_list = task_list
267
286
@@ -297,7 +316,7 @@ def from_segments(self, *inputs):
297
316
return results
298
317
299
318
def interactive_mode (self , max_turn ):
300
- with self .task_instance .interactive_mode (max_turn = 3 ):
319
+ with self .task_instance .interactive_mode (max_turn ):
301
320
while True :
302
321
human = input ("[Human]:" ).strip ()
303
322
if human .lower () == "exit" :
0 commit comments