Skip to content

Commit 82b1cc4

Browse files
authored
[BugFix] [Pre-training] Fix mlm label of dataset_utils (#1633)
1 parent 2455ec2 commit 82b1cc4

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

examples/language_model/data_tools/dataset_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -388,11 +388,12 @@ def create_masked_lm_predictions(tokens,
388388
output_tokens[index] = masked_token
389389
masked_lms.append(
390390
MaskedLmInstance(
391-
index=index, label=tokens[index]))
391+
index=index, label=backup_output_tokens[index]))
392392

393393
masked_spans.append(
394394
MaskedLmInstance(
395-
index=index_set, label=[tokens[index] for index in index_set]))
395+
index=index_set,
396+
label=[backup_output_tokens[index] for index in index_set]))
396397

397398
assert len(masked_lms) <= num_to_predict
398399
np_rng.shuffle(ngram_indexes)

0 commit comments

Comments
 (0)