@@ -201,6 +201,8 @@ def create_data_loader(args, tokenizer):
201
201
collate_fn = batchify_fn ,
202
202
num_workers = 0 ,
203
203
return_list = True )
204
+
205
+ return train_data_loader , dev_data_loader_matched , dev_data_loader_mismatched , train_ds , dev_ds_matched , dev_ds_mismatched
204
206
else :
205
207
dev_ds = load_dataset ('glue' , args .task_name , splits = 'dev' )
206
208
dev_ds = dev_ds .map (trans_func , lazy = True )
@@ -214,7 +216,7 @@ def create_data_loader(args, tokenizer):
214
216
num_workers = 0 ,
215
217
return_list = True )
216
218
217
- return train_data_loader , dev_data_loader , train_ds , dev_ds
219
+ return train_data_loader , dev_data_loader , train_ds , dev_ds
218
220
219
221
220
222
def do_train (args ):
@@ -231,8 +233,12 @@ def do_train(args):
231
233
232
234
tokenizer = tokenizer_class .from_pretrained (args .model_name_or_path )
233
235
234
- train_data_loader , dev_data_loader , train_ds , dev_ds = create_data_loader (
235
- args , tokenizer )
236
+ if args .task_name == "mnli" :
237
+ train_data_loader , dev_data_loader_matched , dev_data_loader_mismatched , train_ds , dev_ds_matched , dev_ds_mismatched = create_data_loader (
238
+ args , tokenizer )
239
+ else :
240
+ train_data_loader , dev_data_loader , train_ds , dev_ds = create_data_loader (
241
+ args , tokenizer )
236
242
237
243
num_classes = 1 if train_ds .label_list is None else len (train_ds .label_list )
238
244
model = XLNetForSequenceClassification .from_pretrained (
0 commit comments