Commit c1aad02

bug fix (#8238)
1 parent 70f4a6f commit c1aad02

File tree

1 file changed: +1 -1 lines changed

paddlenlp/trainer/trainer.py

Lines changed: 1 addition & 1 deletion
@@ -413,9 +413,9 @@ def _wrap_amp_model(self, args, model):
             self.scaler = paddle.amp.GradScaler(init_loss_scaling=self.args.scale_loss)
             if self.amp_dtype == "float16" or self.amp_dtype == "bfloat16":
                 if ShardingOption.SHARD_OP in self.args.sharding:
-                    self.scaler = fleet.distributed_scaler(self.scaler)
                     if self.args.amp_master_grad:
                         mix_precision_utils.MixPrecisionScaler(self.scaler)  # retun value has no use
+                    self.scaler = fleet.distributed_scaler(self.scaler)
                 else:
                     # scaler for stage2 and stage3
                     from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_utils import (
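Effect of the one-line move: when sharding stage 1 (ShardingOption.SHARD_OP) is active and amp_master_grad is set, the GradScaler is now patched by mix_precision_utils.MixPrecisionScaler before it is wrapped by fleet.distributed_scaler, rather than the other way around. Below is a minimal sketch of the resulting order, assuming an initialized Paddle fleet environment; wrap_amp_scaler is a hypothetical helper for illustration, not PaddleNLP or Paddle API.

# Minimal sketch of the ordering this commit establishes; not the actual
# Trainer._wrap_amp_model. Assumes fleet.init() has already been called and
# sharding stage 1 (SHARD_OP) is in use. `wrap_amp_scaler` is a hypothetical
# helper used only for illustration.
import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.utils import mix_precision_utils


def wrap_amp_scaler(scale_loss, amp_master_grad):
    scaler = paddle.amp.GradScaler(init_loss_scaling=scale_loss)
    if amp_master_grad:
        # Patch the scaler for master-grad AMP first; the return value is unused.
        mix_precision_utils.MixPrecisionScaler(scaler)
    # Wrap as a distributed scaler only afterwards (the line the commit moves down).
    scaler = fleet.distributed_scaler(scaler)
    return scaler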
