@@ -111,6 +111,28 @@ def is_datasets_available():
 if is_datasets_available():
     import datasets

+
+@contextlib.contextmanager
+def device_guard(device="cpu", dev_id=0):
+    origin_device = paddle.device.get_device()
+    if device == "cpu":
+        paddle.set_device(device)
+    elif device in ["gpu", "xpu", "npu"]:
+        paddle.set_device("{}:{}".format(device, dev_id))
+    try:
+        yield
+    finally:
+        paddle.set_device(origin_device)
+
+
+def paddlenlp_load(path, return_numpy=False):
+    if return_numpy:
+        with device_guard():
+            return paddle.load(path)
+    else:
+        return paddle.load(path, return_numpy=return_numpy)
+
+
 __all__ = ["Trainer"]


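Note (not part of the diff): a minimal usage sketch of the two helpers added above, with a hypothetical checkpoint path. With return_numpy=True, paddlenlp_load does not actually return numpy arrays; it loads the file under device_guard(), whose default is "cpu", so the restored tensors are created on CPU rather than on the currently selected accelerator. Otherwise it simply forwards return_numpy to paddle.load.

    # Usage sketch; the path is hypothetical and the helpers above are assumed importable.
    opt_state = paddlenlp_load("output/checkpoint-500/optimizer.pdopt", return_numpy=True)

    # Equivalent to loading the file while CPU is temporarily the current device:
    with device_guard("cpu"):
        opt_state = paddle.load("output/checkpoint-500/optimizer.pdopt")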
@@ -267,7 +289,11 @@ def __init__(
             self.amp_dtype = "float16" if args.fp16 else "bfloat16"
             # fix for load saved fp16 or bf16 ckpt, decorate model first.
             if self.args.fp16_opt_level == "O2":
-                paddle.amp.decorate(models=model, level=self.args.fp16_opt_level, dtype=self.amp_dtype)
+                if self.amp_dtype == "bfloat16":
+                    # fix for paddlepaddle < 2.4.1, which does not support bf16
+                    paddle.amp.decorate(models=model, level=self.args.fp16_opt_level, dtype=self.amp_dtype)
+                else:
+                    paddle.amp.decorate(models=model, level=self.args.fp16_opt_level)

         if self.sharding is not None:
             self.scaler = paddle.amp.GradScaler(init_loss_scaling=self.args.scale_loss)
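Note (not part of the diff): the dtype-based branch above keeps O2 decoration working on paddlepaddle < 2.4.1, which the inline comment says lacks bf16 support (and hence the dtype keyword), by passing dtype only on the bf16 path. A version-based guard would be a more explicit alternative; the sketch below assumes the dtype parameter of paddle.amp.decorate is available from paddlepaddle 2.4.1 onward, and decorate_for_o2 is a hypothetical helper, not part of the PR.

    # Sketch only: version-based alternative to the dtype-based branch above.
    import paddle
    from packaging import version  # third-party "packaging" package

    def decorate_for_o2(model, amp_dtype):
        # Pass dtype only when the installed paddlepaddle understands the keyword.
        if version.parse(paddle.__version__) >= version.parse("2.4.1"):
            return paddle.amp.decorate(models=model, level="O2", dtype=amp_dtype)
        return paddle.amp.decorate(models=model, level="O2")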
@@ -1130,9 +1156,16 @@ def _wrap_model(self, model, training=True):
         # Mixed precision training
         if training and self.do_grad_scaling:  # self.args.fp16_opt_level=="O2":
             # model, self.optimizer
-            decorated = paddle.amp.decorate(
-                models=model, optimizers=self.optimizer, level=self.args.fp16_opt_level, dtype=self.amp_dtype
-            )
+            if self.amp_dtype == "bfloat16":
+                # fix for paddlepaddle < 2.4.1, which does not support bf16
+                decorated = paddle.amp.decorate(
+                    models=model, optimizers=self.optimizer, level=self.args.fp16_opt_level, dtype=self.amp_dtype
+                )
+            else:
+                decorated = paddle.amp.decorate(
+                    models=model, optimizers=self.optimizer, level=self.args.fp16_opt_level
+                )
+
             if self.optimizer is None:
                 model = decorated
             else:
@@ -1459,15 +1492,18 @@ def _load_optimizer_and_scheduler(self, checkpoint):
         # Load in optimizer and scheduler states
         if self.sharding is not None:
             self.optimizer.set_state_dict(
-                paddle.load(
+                paddlenlp_load(
                     os.path.join(checkpoint, OPTIMIZER_NAME + f"_shard{self.sharding_group.rank}"),
                     return_numpy=True,
                 )
             )
             empty_dict = paddle.load(os.path.join(checkpoint, OPTIMIZER_NAME), return_numpy=True)
             assert len(empty_dict) == 0, "Optimizer file of sharding, should be empty!"
         else:
-            self.optimizer.set_state_dict(paddle.load(os.path.join(checkpoint, OPTIMIZER_NAME), return_numpy=True))
+            self.optimizer.set_state_dict(
+                paddlenlp_load(os.path.join(checkpoint, OPTIMIZER_NAME), return_numpy=True)
+            )
+
         self.lr_scheduler.set_state_dict(paddle.load(os.path.join(checkpoint, SCHEDULER_NAME)))
         if self.do_grad_scaling and os.path.isfile(os.path.join(checkpoint, SCALER_NAME)):
             self.scaler.load_state_dict(paddle.load(os.path.join(checkpoint, SCALER_NAME), return_numpy=True))