@@ -330,7 +330,8 @@ def __init__(
                 "Passing `optimizers` is not allowed if sharding is enabled."
                 "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
             )
-        if self.args.pipeline_parallel_degree > 1:
+
+        if self.args.pipeline_parallel_degree > 1 and self.args.use_hybrid_parallel:
             from paddle.distributed.fleet.meta_parallel import PipelineLayer
 
             assert (isinstance(model, LoRAModel) and isinstance(model.model, PipelineLayer)) or isinstance(
@@ -352,52 +353,9 @@ def __init__(
 
         self.do_grad_scaling = False
         self.enable_autocast_context_manager = False
-
         if args.fp16 or args.bf16:
-            logger.info("Using half precision")
-            self.enable_autocast_context_manager = True
-            self.do_grad_scaling = True if args.fp16 else False
-            self.amp_dtype = "float16" if args.fp16 else "bfloat16"
-            # fix for load saved fp16 or bf16 ckpt, decorate model first.
-            if self.args.fp16_opt_level == "O2":
-                paddle.amp.decorate(
-                    models=model,
-                    level=self.args.fp16_opt_level,
-                    dtype=self.amp_dtype,
-                    excluded_layers=QuantizationLinear,
-                )
-            # for pipeline mode and pure tensor parallel
-            if self.args.pipeline_parallel_degree > 1 or (
-                self.args.tensor_parallel_degree > 1 and self.sharding is None
-            ):
-                self.scaler = paddle.amp.GradScaler(init_loss_scaling=self.args.scale_loss)
-                if self.args.amp_master_grad:
-                    mix_precision_utils.MixPrecisionScaler(self.scaler)  # retun value has no use
-                self.scaler = fleet.distributed_scaler(self.scaler)
-            elif self.sharding is not None:
-                self.scaler = paddle.amp.GradScaler(init_loss_scaling=self.args.scale_loss)
-                if self.amp_dtype == "float16" or self.amp_dtype == "bfloat16":
-                    if ShardingOption.SHARD_OP in self.args.sharding:
-                        self.scaler = fleet.distributed_scaler(self.scaler)
-                        if self.args.amp_master_grad:
-                            mix_precision_utils.MixPrecisionScaler(self.scaler)  # retun value has no use
-                    else:
-                        # scaler for stage2 and stage3
-                        from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_utils import (
-                            GroupShardedScaler,
-                        )
-
-                        if self.args.amp_master_grad:
-                            mix_precision_utils.MixPrecisionScaler(self.scaler)  # return value has no use
-
-                        self.scaler = GroupShardedScaler(self.scaler)
-                else:
-                    self.do_grad_scaling = False
-                    self.use_cuda_amp = False
-                    self.amp_dtype = None
-
-            else:
-                self.scaler = paddle.amp.GradScaler(init_loss_scaling=self.args.scale_loss)
+            # set do_grad_scaling, enable_autocast_context_manager
+            self._wrap_amp_model(args, model)
 
         if args.recompute:
 
@@ -422,6 +380,50 @@ def fn(layer):
         # very last
         self._memory_tracker.stop_and_update_metrics()
 
+    def _wrap_amp_model(self, args, model):
+        logger.info("Using half precision")
+        self.enable_autocast_context_manager = True
+        self.do_grad_scaling = True if args.fp16 else False
+        self.amp_dtype = "float16" if args.fp16 else "bfloat16"
+        # fix for load saved fp16 or bf16 ckpt, decorate model first.
+        if self.args.fp16_opt_level == "O2":
+            paddle.amp.decorate(
+                models=model,
+                level=self.args.fp16_opt_level,
+                dtype=self.amp_dtype,
+                excluded_layers=QuantizationLinear,
+            )
+        # for pipeline mode and pure tensor parallel
+        if self.args.pipeline_parallel_degree > 1 or (self.args.tensor_parallel_degree > 1 and self.sharding is None):
+            self.scaler = paddle.amp.GradScaler(init_loss_scaling=self.args.scale_loss)
+            if self.args.amp_master_grad:
+                mix_precision_utils.MixPrecisionScaler(self.scaler)  # retun value has no use
+            self.scaler = fleet.distributed_scaler(self.scaler)
+        elif self.sharding is not None:
+            self.scaler = paddle.amp.GradScaler(init_loss_scaling=self.args.scale_loss)
+            if self.amp_dtype == "float16" or self.amp_dtype == "bfloat16":
+                if ShardingOption.SHARD_OP in self.args.sharding:
+                    self.scaler = fleet.distributed_scaler(self.scaler)
+                    if self.args.amp_master_grad:
+                        mix_precision_utils.MixPrecisionScaler(self.scaler)  # retun value has no use
+                else:
+                    # scaler for stage2 and stage3
+                    from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_utils import (
+                        GroupShardedScaler,
+                    )
+
+                    if self.args.amp_master_grad:
+                        mix_precision_utils.MixPrecisionScaler(self.scaler)  # return value has no use
+
+                    self.scaler = GroupShardedScaler(self.scaler)
+            else:
+                self.do_grad_scaling = False
+                self.use_cuda_amp = False
+                self.amp_dtype = None
+
+        else:
+            self.scaler = paddle.amp.GradScaler(init_loss_scaling=self.args.scale_loss)
+
     def add_callback(self, callback):
         """
         Add a callback to the current list of [`~TrainerCallback`].
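For orientation, the sketch below shows the single-card AMP pattern that `_wrap_amp_model` configures for the Trainer: O2 decoration, an autocast context, and loss scaling. It is a minimal illustration, not the Trainer's actual loop; the toy model, optimizer, shapes, and scaling value are assumptions, it presumes a GPU since float16 is used, and it omits the `fleet.distributed_scaler` / `GroupShardedScaler` wrappers the method adds for hybrid-parallel and sharded runs.

# Minimal single-card sketch of the AMP setup mirrored by _wrap_amp_model.
# The toy model, optimizer, shapes, and 2.0**15 scaling factor are illustrative only.
import paddle

model = paddle.nn.Linear(8, 8)
optimizer = paddle.optimizer.AdamW(parameters=model.parameters())

# O2 decoration corresponds to the fp16_opt_level == "O2" branch in the diff.
model, optimizer = paddle.amp.decorate(models=model, optimizers=optimizer, level="O2", dtype="float16")
scaler = paddle.amp.GradScaler(init_loss_scaling=2.0**15)  # plays the role of args.scale_loss

x = paddle.randn([4, 8], dtype="float32")
with paddle.amp.auto_cast(level="O2", dtype="float16"):
    loss = model(x).mean()

scaler.scale(loss).backward()  # scale the loss so float16 gradients do not underflow
scaler.step(optimizer)         # unscales gradients and skips the step on inf/nan
scaler.update()                # adapts the loss-scaling factor for the next step
optimizer.clear_grad()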
@@ -747,6 +749,33 @@ def train(
         # so, the trainable numel is a little bigger than real.
         logger.info(f"  Number of trainable parameters = {trainable_numel:,} (all devices, roughly)")
 
+        return self._inner_training_loop(
+            args,
+            model,
+            train_dataloader,
+            len_dataloader,
+            max_steps,
+            num_train_epochs,
+            num_update_steps_per_epoch,
+            num_train_samples,
+            resume_from_checkpoint,
+            ignore_keys_for_eval,
+        )
+
+    def _inner_training_loop(
+        self,
+        args,
+        model,
+        train_dataloader,
+        len_dataloader,
+        max_steps,
+        num_train_epochs,
+        num_update_steps_per_epoch,
+        num_train_samples,
+        resume_from_checkpoint,
+        ignore_keys_for_eval,
+    ):
+
         start_time = time.time()
         self._globalstep_last_start_time = time.time()
         self.state.epoch = 0
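A hypothetical use of this refactor (not part of the diff): because `train` now delegates to `_inner_training_loop` with an explicit argument list, a subclass can wrap or replace just the loop while reusing `train`'s setup and bookkeeping. The `LoggingTrainer` name and the print statement below are invented for illustration.

# Hypothetical subclass showing how the factored-out loop could be intercepted.
from paddlenlp.trainer import Trainer


class LoggingTrainer(Trainer):
    def _inner_training_loop(
        self,
        args,
        model,
        train_dataloader,
        len_dataloader,
        max_steps,
        num_train_epochs,
        num_update_steps_per_epoch,
        num_train_samples,
        resume_from_checkpoint,
        ignore_keys_for_eval,
    ):
        # Custom pre-loop hook; then fall back to the stock loop.
        print(f"entering inner loop: max_steps={max_steps}, epochs={num_train_epochs}")
        return super()._inner_training_loop(
            args,
            model,
            train_dataloader,
            len_dataloader,
            max_steps,
            num_train_epochs,
            num_update_steps_per_epoch,
            num_train_samples,
            resume_from_checkpoint,
            ignore_keys_for_eval,
        )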
@@ -2274,6 +2303,7 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_
                 logger.warning("Trainer.model is not a `PretrainedModel`, not suppor for merge_tensor_parallel.")
             if state_dict is None:
                 state_dict = self.model.state_dict()
+
             paddle.save(
                 state_dict,
                 os.path.join(output_dir, _add_variant(PADDLE_WEIGHTS_NAME, self.args.weight_name_suffix)),