
Commit 277fdb4

support amp in pir dy2st mode. (#8485)
1 parent e0d2809 commit 277fdb4


paddlenlp/trainer/auto_trainer.py

Lines changed: 5 additions & 5 deletions
@@ -130,12 +130,7 @@ def _wrap_for_auto(self, model, train_dataloader):
 
     def _wrap_amp_model(self, args, model):
         logger.info("Using half precision")
-        if args.to_static:
-            return
-        self.enable_autocast_context_manager = True
-        self.do_grad_scaling = True if self.args.fp16 else False
         self.amp_dtype = "float16" if self.args.fp16 else "bfloat16"
-        self.scaler = dist.shard_scaler(paddle.amp.GradScaler(init_loss_scaling=self.args.scale_loss))
         if self.args.fp16_opt_level == "O2":
             paddle.amp.decorate(
                 models=model,
@@ -144,6 +139,11 @@ def _wrap_amp_model(self, args, model):
                 master_grad=self.args.amp_master_grad,
                 excluded_layers=QuantizationLinear,
             )
+        if args.to_static:
+            return
+        self.enable_autocast_context_manager = True
+        self.do_grad_scaling = True if self.args.fp16 else False
+        self.scaler = dist.shard_scaler(paddle.amp.GradScaler(init_loss_scaling=self.args.scale_loss))
 
     def _get_item_from_loss(self, loss):
         if isinstance(loss, paddle.Tensor):
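
In effect, the commit makes the half-precision dtype selection and the O2 paddle.amp.decorate call run regardless of args.to_static, and moves the early return so that only the eager-mode setup (the autocast context-manager flag and the sharded GradScaler) is skipped when converting to static graph, which is what enables AMP under PIR dynamic-to-static (dy2st) training. Below is a minimal sketch of how _wrap_amp_model reads after the change, reconstructed only from the two hunks above; it is an excerpt of a method on the auto trainer class rather than a standalone script, so the enclosing class, its imports (paddle, paddle.distributed as dist, the module logger, QuantizationLinear), and the decorate arguments elided between the hunks are assumed from the rest of paddlenlp/trainer/auto_trainer.py.

# Sketch: _wrap_amp_model after this commit, reconstructed from the diff above.
# This is a method of the auto trainer class; surrounding context is assumed.
def _wrap_amp_model(self, args, model):
    logger.info("Using half precision")
    # dtype selection and O2 decoration now also run in to_static (PIR dy2st) mode.
    self.amp_dtype = "float16" if self.args.fp16 else "bfloat16"
    if self.args.fp16_opt_level == "O2":
        paddle.amp.decorate(
            models=model,
            # ... arguments elided between the two hunks are unchanged ...
            master_grad=self.args.amp_master_grad,
            excluded_layers=QuantizationLinear,
        )
    # Only the eager-mode setup is skipped for static-graph training.
    if args.to_static:
        return
    self.enable_autocast_context_manager = True
    self.do_grad_scaling = True if self.args.fp16 else False
    self.scaler = dist.shard_scaler(paddle.amp.GradScaler(init_loss_scaling=self.args.scale_loss))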
