Skip to content
This repository was archived by the owner on Dec 15, 2024. It is now read-only.

Commit c3b531a

Browse files
committed
1 parent e02721d commit c3b531a

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

run_qg.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -178,14 +178,16 @@ def main(args_file=None):
178178
using_tpu=training_args.tpu_num_cores is not None
179179
)
180180

181+
# Prediction Loss
182+
training_args.prediction_loss_only=True
183+
181184
# Initialize our Trainer
182185
trainer = Trainer(
183186
model=model,
184187
args=training_args,
185188
train_dataset=train_dataset,
186189
eval_dataset=valid_dataset,
187190
data_collator=data_collator,
188-
prediction_loss_only=True,
189191
label_smoothing=model_args.label_smoothing
190192
)
191193

@@ -200,8 +202,8 @@ def main(args_file=None):
200202
trainer.save_model()
201203
# For convenience, we also re-save the tokenizer to the same directory,
202204
# so that you can share your model easily on huggingface.co/models =)
203-
if trainer.is_world_master():
204-
tokenizer.save_pretrained(training_args.output_dir)
205+
if trainer.is_world_process_zero():
206+
tokenizer.save_pretrained(training_args.output_dir)
205207

206208
# Evaluation
207209
results = {}
@@ -233,4 +235,4 @@ def run_qg(args_dict):
233235
main(args_file="args.json")
234236

235237
if __name__ == "__main__":
236-
main()
238+
main()

0 commit comments

Comments
 (0)