Skip to content
This repository was archived by the owner on Jan 15, 2024. It is now read-only.

Commit 9e268c0

Browse files
authored
Fix electra (#1291)
* update Dockerfile
* fix num_out_files
* fix run_electra
* Revert "update Dockerfile"

  This reverts commit 80593a2.
1 parent c33e62e commit 9e268c0

File tree

2 files changed, with 6 additions (+6) and 6 deletions (-6).

scripts/pretraining/data_preprocessing.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,7 @@ def main(args):
     random.shuffle(fnames)
     num_files = len(fnames)
     num_out_files = min(args.num_out_files, num_files)
-    file_volume = math.ceil(num_files / num_out_files)
-    splited_files = np.array_split(fnames, file_volume)
-    num_out_files = len(splited_files)
+    splited_files = np.array_split(fnames, num_out_files)
     output_files = [os.path.join(
         args.output, "owt-pretrain-record-{}.npz".format(str(i).zfill(4))) for i in range(num_out_files)]
     print("All preprocessed features will be saved in {} npz files".format(num_out_files))


scripts/pretraining/run_electra.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -472,9 +472,11 @@ def train(args):
         train_end_time - train_start_time))
     if writer is not None:
         writer.close()
-    model_name = args.model_name.replace('google', 'gluon')
-    save_dir = os.path.join(args.output_dir, model_name)
-    final_save(model, save_dir, tokenizer)
+
+    if local_rank == 0:
+        model_name = args.model_name.replace('google', 'gluon')
+        save_dir = os.path.join(args.output_dir, model_name)
+        final_save(model, save_dir, tokenizer)
 
 # TODO(zheyuye), Directly implement a metric for weighted accuracy
 
0 commit comments

Comments
 (0)