77from argparse import ArgumentParser
88from collections import OrderedDict
99
10+ import numpy
1011import torch
1112import torch .utils .data
1213from pytorch_lightning import Trainer
14+ from tqdm import tqdm
1315
1416import ltp
1517from ltp import (
2022from ltp .data import dataset as datasets
2123from ltp .data .utils import collate , MultiTaskDataloader
2224from ltp .transformer_multitask import TransformerMultiTask as Model
23- from ltp .utils import TaskInfo , common_train , tune_train
25+ from ltp .utils import TaskInfo , common_train , tune_train , map2device , convert2npy
2426from ltp .utils import deploy_model
2527
2628os .environ ['TOKENIZERS_PARALLELISM' ] = 'true'
@@ -190,6 +192,55 @@ def configure_optimizers(self: Model):
190192)
191193
192194
def build_ner_distill_dataset(args):
    """Dump teacher predictions for NER knowledge distillation.

    Loads the multi-task model from ``args.resume_from_checkpoint``, runs it
    (frozen, no grad) over the NER training split, and saves every batch
    together with the teacher logits — plus the CRF transition matrices when
    available — to ``<ner_data_dir>/<task_name>/output.npz``.

    Args:
        args: parsed CLI namespace; must provide ``resume_from_checkpoint``,
            ``ner_data_dir``, ``batch_size`` and ``num_workers``.
    """
    model = Model.load_from_checkpoint(
        args.resume_from_checkpoint, hparams=args
    )

    # Inference only: disable dropout/batch-norm updates and gradients.
    model.eval()
    model.freeze()

    # NOTE(review): build_dataset also returns a metric object; it is not
    # needed for dumping predictions.
    dataset, _metric = task_named_entity_recognition.build_dataset(
        model, args.ner_data_dir, task_info.task_name
    )
    train_dataloader = torch.utils.data.DataLoader(
        dataset[datasets.Split.TRAIN],
        batch_size=args.batch_size,
        collate_fn=collate,
        num_workers=args.num_workers
    )

    output = os.path.join(args.ner_data_dir, task_info.task_name, 'output.npz')

    if torch.cuda.is_available():
        model.cuda()

        # Move inputs to the GPU for the forward pass, results back to CPU
        # (map2device defaults to CPU when no device is given).
        def map2cpu(x):
            return map2device(x)

        def map2cuda(x):
            return map2device(x, model.device)
    else:
        # No GPU available: everything already lives on CPU, so both
        # mappings are no-ops.
        def map2cpu(x):
            return x

        def map2cuda(x):
            return x

    with torch.no_grad():
        batchs = []
        for batch in tqdm(train_dataloader):
            batch = map2cuda(batch)
            logits = model.forward(task='ner', **batch).logits
            batch.update(logits=logits)
            batchs.append(map2cpu(batch))

        try:
            numpy.savez(
                output,
                data=convert2npy(batchs),
                extra=convert2npy({
                    'transitions': model.ner_classifier.crf.transitions,
                    'start_transitions': model.ner_classifier.crf.start_transitions,
                    'end_transitions': model.ner_classifier.crf.end_transitions
                })
            )
        except Exception as e:
            # Best-effort fallback (e.g. the NER head has no CRF layer):
            # still save the data, but report the failure instead of
            # swallowing it silently.
            print(f"Saving CRF transitions failed ({e!r}); saving data only")
            numpy.savez(output, data=convert2npy(batchs))

    print("Done")
243+
193244def add_task_specific_args (parent_parser ):
194245 parser = ArgumentParser (parents = [parent_parser ], add_help = False )
195246 parser .add_argument ('--seed' , type = int , default = 19980524 )
@@ -210,6 +261,7 @@ def add_task_specific_args(parent_parser):
210261 parser .add_argument ('--dep_data_dir' , type = str , default = None )
211262 parser .add_argument ('--sdp_data_dir' , type = str , default = None )
212263 parser .add_argument ('--srl_data_dir' , type = str , default = None )
264+ parser .add_argument ('--build_ner_dataset' , action = 'store_true' )
213265 return parser
214266
215267
@@ -226,6 +278,8 @@ def main():
226278
227279 if args .ltp_model is not None and args .resume_from_checkpoint is not None :
228280 deploy_model (args , args .ltp_version )
281+ elif args .build_ner_dataset :
282+ build_ner_distill_dataset (args )
229283 elif args .tune :
230284 tune_train (args , model_class = Model , task_info = task_info , build_method = build_method )
231285 else :
0 commit comments