Skip to content

Commit 44569cc

Browse files
authored
fix xlnet bug (#1937)
1 parent 61b51ea commit 44569cc

File tree

2 files changed

+20
-6
lines changed

2 files changed

+20
-6
lines changed

examples/language_model/xlnet/run_glue.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,8 @@ def create_data_loader(args, tokenizer):
201201
collate_fn=batchify_fn,
202202
num_workers=0,
203203
return_list=True)
204+
205+
return train_data_loader, dev_data_loader_matched, dev_data_loader_mismatched, train_ds, dev_ds_matched, dev_ds_mismatched
204206
else:
205207
dev_ds = load_dataset('glue', args.task_name, splits='dev')
206208
dev_ds = dev_ds.map(trans_func, lazy=True)
@@ -214,7 +216,7 @@ def create_data_loader(args, tokenizer):
214216
num_workers=0,
215217
return_list=True)
216218

217-
return train_data_loader, dev_data_loader, train_ds, dev_ds
219+
return train_data_loader, dev_data_loader, train_ds, dev_ds
218220

219221

220222
def do_train(args):
@@ -231,8 +233,12 @@ def do_train(args):
231233

232234
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
233235

234-
train_data_loader, dev_data_loader, train_ds, dev_ds = create_data_loader(
235-
args, tokenizer)
236+
if args.task_name == "mnli":
237+
train_data_loader, dev_data_loader_matched, dev_data_loader_mismatched, train_ds, dev_ds_matched, dev_ds_mismatched = create_data_loader(
238+
args, tokenizer)
239+
else:
240+
train_data_loader, dev_data_loader, train_ds, dev_ds = create_data_loader(
241+
args, tokenizer)
236242

237243
num_classes = 1 if train_ds.label_list is None else len(train_ds.label_list)
238244
model = XLNetForSequenceClassification.from_pretrained(

tests/test_tipc/benchmark/modules/xlnet.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,21 @@ def create_data_loader(self, args, **kwargs):
4444
args.task_name = args.task_name.lower()
4545
tokenizer = XLNetTokenizer.from_pretrained(args.model_name_or_path)
4646

47-
train_loader, dev_loader, train_ds, dev_ds = create_data_loader(
48-
args, tokenizer)
47+
if args.task_name == "mnli":
48+
train_data_loader, dev_data_loader_matched, dev_data_loader_mismatched, train_ds, _, _ = create_data_loader(
49+
args, tokenizer)
50+
else:
51+
train_loader, dev_loader, train_ds, _ = create_data_loader(
52+
args, tokenizer)
4953

5054
self.num_batch = len(train_loader)
5155
self.label_list = train_ds.label_list
5256

53-
return train_loader, dev_loader
57+
if args.task_name == "mnli":
58+
return train_data_loader, (dev_data_loader_matched,
59+
dev_data_loader_mismatched)
60+
else:
61+
return train_loader, dev_loader
5462

5563
def build_model(self, args, **kwargs):
5664
num_classes = 1 if self.label_list is None else len(self.label_list)

0 commit comments

Comments
 (0)