Commit 8c4c418

fix fp16 for paddle version (#4283)
1 parent 5064b11 commit 8c4c418

File tree

1 file changed: paddlenlp/trainer/trainer.py (+42 additions, -6 deletions)

@@ -111,6 +111,28 @@ def is_datasets_available():
 if is_datasets_available():
     import datasets
 
+
+@contextlib.contextmanager
+def device_guard(device="cpu", dev_id=0):
+    origin_device = paddle.device.get_device()
+    if device == "cpu":
+        paddle.set_device(device)
+    elif device in ["gpu", "xpu", "npu"]:
+        paddle.set_device("{}:{}".format(device, dev_id))
+    try:
+        yield
+    finally:
+        paddle.set_device(origin_device)
+
+
+def paddlenlp_load(path, return_numpy=False):
+    if return_numpy:
+        with device_guard():
+            return paddle.load(path)
+    else:
+        return paddle.load(path, return_numpy=return_numpy)
+
+
 __all__ = ["Trainer"]
 
 
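The two helpers above are plain module-level functions, so with a PaddleNLP build that includes this commit they can be imported directly from paddlenlp.trainer.trainer. A minimal sketch of how they behave; the checkpoint file name below is a throwaway created only for this illustration:

import paddle
from paddlenlp.trainer.trainer import device_guard, paddlenlp_load

# Throwaway state dict written only for this illustration.
paddle.save({"step": paddle.to_tensor(10)}, "tmp_opt.pdopt")

# The guard switches the active device to CPU so the load happens there; the
# previous device (e.g. "gpu:0") is restored as soon as the block exits.
with device_guard("cpu"):
    cpu_state = paddle.load("tmp_opt.pdopt")

# paddlenlp_load routes return_numpy=True requests through the same CPU guard.
state = paddlenlp_load("tmp_opt.pdopt", return_numpy=True)
print(paddle.device.get_device())  # unchanged by either call

Note that with return_numpy=True the helper does not actually return numpy arrays: per the branch above it loads plain tensors under the CPU guard instead of going through paddle.load's return_numpy path, which the commit title ties to an fp16 problem on some paddle versions.
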
@@ -267,7 +289,11 @@ def __init__(
             self.amp_dtype = "float16" if args.fp16 else "bfloat16"
             # fix for load saved fp16 or bf16 ckpt, decorate model first.
             if self.args.fp16_opt_level == "O2":
-                paddle.amp.decorate(models=model, level=self.args.fp16_opt_level, dtype=self.amp_dtype)
+                if self.amp_dtype == "bfloat16":
+                    # fix for paddlepaddle < 2.4.1, not support for bf16
+                    paddle.amp.decorate(models=model, level=self.args.fp16_opt_level, dtype=self.amp_dtype)
+                else:
+                    paddle.amp.decorate(models=model, level=self.args.fp16_opt_level)
 
             if self.sharding is not None:
                 self.scaler = paddle.amp.GradScaler(init_loss_scaling=self.args.scale_loss)
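
The branch above is the compatibility fix named in the commit title: dtype is only forwarded to paddle.amp.decorate when bf16 is requested, which the in-diff comment ties to paddlepaddle versions below 2.4.1, and the fp16 case relies on the argument's default instead. A sketch of the same guard-the-newer-keyword pattern as a standalone helper; amp_decorate_compat is hypothetical and not part of this commit:

import paddle


def amp_decorate_compat(model, optimizer=None, level="O2", amp_dtype="float16"):
    # Only forward dtype= when bf16 is requested; older paddlepaddle releases
    # (< 2.4.1 per the diff comment) do not accept it, and "float16" is the
    # default anyway.
    kwargs = dict(models=model, optimizers=optimizer, level=level)
    if amp_dtype == "bfloat16":
        kwargs["dtype"] = amp_dtype
    return paddle.amp.decorate(**kwargs)


# e.g. model, optimizer = amp_decorate_compat(model, optimizer, amp_dtype="bfloat16")

The same two-way branch is applied again below for the optimizer-decorating path in _wrap_model.
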
@@ -1130,9 +1156,16 @@ def _wrap_model(self, model, training=True):
         # Mixed precision training
         if training and self.do_grad_scaling:  # self.args.fp16_opt_level=="O2":
             # model, self.optimizer
-            decorated = paddle.amp.decorate(
-                models=model, optimizers=self.optimizer, level=self.args.fp16_opt_level, dtype=self.amp_dtype
-            )
+            if self.amp_dtype == "bfloat16":
+                # fix for paddlepaddle < 2.4.1, not support for bf16
+                decorated = paddle.amp.decorate(
+                    models=model, optimizers=self.optimizer, level=self.args.fp16_opt_level, dtype=self.amp_dtype
+                )
+            else:
+                decorated = paddle.amp.decorate(
+                    models=model, optimizers=self.optimizer, level=self.args.fp16_opt_level
+                )
+
             if self.optimizer is None:
                 model = decorated
             else:
@@ -1459,15 +1492,18 @@ def _load_optimizer_and_scheduler(self, checkpoint):
         # Load in optimizer and scheduler states
         if self.sharding is not None:
             self.optimizer.set_state_dict(
-                paddle.load(
+                paddlenlp_load(
                     os.path.join(checkpoint, OPTIMIZER_NAME + f"_shard{self.sharding_group.rank}"),
                     return_numpy=True,
                 )
             )
             empty_dict = paddle.load(os.path.join(checkpoint, OPTIMIZER_NAME), return_numpy=True)
             assert len(empty_dict) == 0, "Optimizer file of sharding, should be empty!"
         else:
-            self.optimizer.set_state_dict(paddle.load(os.path.join(checkpoint, OPTIMIZER_NAME), return_numpy=True))
+            self.optimizer.set_state_dict(
+                paddlenlp_load(os.path.join(checkpoint, OPTIMIZER_NAME), return_numpy=True)
+            )
+
         self.lr_scheduler.set_state_dict(paddle.load(os.path.join(checkpoint, SCHEDULER_NAME)))
         if self.do_grad_scaling and os.path.isfile(os.path.join(checkpoint, SCALER_NAME)):
             self.scaler.load_state_dict(paddle.load(os.path.join(checkpoint, SCALER_NAME), return_numpy=True))
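
This last hunk swaps paddle.load for paddlenlp_load when restoring optimizer state, in both the sharding branch and the plain branch, so the return_numpy=True request is served by loading ordinary tensors under the CPU device guard. A minimal sketch of the same restore pattern outside of Trainer; the model, optimizer, and file name below are stand-ins (the real trainer builds the path from its OPTIMIZER_NAME constant):

import os

import paddle
from paddlenlp.trainer.trainer import paddlenlp_load

# Stand-in model and optimizer, with one step taken so the optimizer has state.
model = paddle.nn.Linear(4, 4)
optimizer = paddle.optimizer.AdamW(parameters=model.parameters())
model(paddle.randn([2, 4])).mean().backward()
optimizer.step()

checkpoint_dir = "./tmp_checkpoint"
os.makedirs(checkpoint_dir, exist_ok=True)
paddle.save(optimizer.state_dict(), os.path.join(checkpoint_dir, "optimizer.pdopt"))

# Mirror of the new else-branch: deserialize the state through the CPU guard,
# then hand the dict back to the optimizer.
opt_state = paddlenlp_load(os.path.join(checkpoint_dir, "optimizer.pdopt"), return_numpy=True)
optimizer.set_state_dict(opt_state)
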
