Skip to content

Commit 67d93b1

Browse files
authored
Fix hyp dataset loading (#1900)
1 parent 6f0535a commit 67d93b1

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

examples/language_model/ernie-doc/run_classifier.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def do_train(args):
165165
if eval_name == test_name:
166166
test_ds = eval_ds
167167
else:
168-
test_ds = load_dataset(args.dataset, splits=["train", test_name])
168+
test_ds = load_dataset(args.dataset, splits=[test_name])
169169

170170
num_classes = len(train_ds.label_list)
171171

paddlenlp/transformers/ernie_doc/modeling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ def _cache_mem(self, curr_out, prev_mem):
280280
if self.mem_len is None or self.mem_len == 0:
281281
return None
282282
if prev_mem is None:
283-
new_mem = curr[:, -self.mem_len:, :]
283+
new_mem = curr_out[:, -self.mem_len:, :]
284284
else:
285285
new_mem = paddle.concat([prev_mem, curr_out],
286286
1)[:, -self.mem_len:, :]

0 commit comments

Comments
 (0)