@@ -414,6 +414,67 @@ def test_dp_pp(self):
             opt.clear_grad()
         return losses_by_step, all_losses_in_one_step_md5sum

+    def test_pp_model_with_ClipGradByGlobalNorm(self):
+        """Test pipeline parallel model with ClipGradByGlobalNorm using PPMyModel as the baseline"""
+        fix_seeds()
+        pp_model = PPMyModel()
+        opt = paddle.optimizer.AdamW(
+            learning_rate=0.001,
+            parameters=pp_model.parameters(),
+            grad_clip=paddle.nn.ClipGradByGlobalNorm(1.0),
+        )
+        loss_fn = nn.MSELoss()
+        dataset = RandomDataset(image_size=8, output_size=8, num_samples=8)
+        loader = DataLoader(dataset, batch_size=1)
+        pp_losses_step = []
+        num_iterations = 20
+
+        for iter_idx in range(num_iterations):
+            pp_losses_micro_batch = []
+            for i, (data, label) in enumerate(loader):
+                output = pp_model(data)
+                loss = loss_fn(output, label)
+                pp_losses_micro_batch.append(loss.item())
+                loss.backward()
+            pp_losses_step.append(
+                np.array(pp_losses_micro_batch, dtype=np.float32).mean()
+            )
+            opt.step()
+            opt.clear_grad()
+        return pp_losses_step
+
+    def test_ScheduleFThenB_with_ClipGradByGlobalNorm(self):
+        fix_seeds()
+        self.model = PPMyModel_SingleStage()
+        self.micro_batches = 8
+        self.stage = PipelineStage(self.model, self.rank, 4, group=self.group)
+        self.stage.has_backward = True
+        loss_fn_ = nn.MSELoss()
+        schedule = ScheduleFThenB(
+            self.stage, self.micro_batches, loss_fn=loss_fn_
+        )
+        opt = paddle.optimizer.AdamW(
+            learning_rate=0.001,
+            parameters=self.model.parameters(),
+            grad_clip=paddle.nn.ClipGradByGlobalNorm(1.0),
+        )
+        dataset = RandomDataset(image_size=8, output_size=8, num_samples=8)
+        loader = DataLoader(dataset, batch_size=8)
+        losses_by_step = []
+        num_iterations = 20
+
+        for iter_idx in range(num_iterations):
+            losses_by_micro_batch = []
+            for i, (data, label) in enumerate(loader):
+                schedule.step(data, target=label, losses=losses_by_micro_batch)
+            if self.rank == 3:
+                losses_by_step.append(
+                    np.array(losses_by_micro_batch, dtype=np.float32).mean()
+                )
+            opt.step()
+            opt.clear_grad()
+        return losses_by_step
+
     def test_dp_pp_align_mode(self):
         fix_seeds()
         paddle.set_flags({'FLAGS_enable_auto_parallel_align_mode': True})
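
For context (not part of the diff): both new tests exercise the same pattern, an AdamW optimizer configured with ClipGradByGlobalNorm, gradients accumulated over micro-batches, then one clipped update per step. Below is a minimal sketch of that pattern using only public Paddle APIs; the harness pieces from this file (fix_seeds, PPMyModel, PPMyModel_SingleStage, RandomDataset, PipelineStage, ScheduleFThenB) are replaced with a plain nn.Linear and random tensors, so this is an illustration of the mechanism rather than the test itself.

# Illustrative sketch only (not from the PR): the grad-clipping pattern the new
# tests exercise, with the pipeline-parallel harness replaced by a plain Linear.
import numpy as np
import paddle
import paddle.nn as nn

paddle.seed(2024)  # stand-in for the harness's fix_seeds()

model = nn.Linear(8, 8)
opt = paddle.optimizer.AdamW(
    learning_rate=0.001,
    parameters=model.parameters(),
    grad_clip=paddle.nn.ClipGradByGlobalNorm(1.0),  # clip the global grad norm to 1.0
)
loss_fn = nn.MSELoss()

losses_by_step = []
for step in range(20):
    micro_batch_losses = []
    for _ in range(8):  # gradients accumulate across the 8 "micro-batches"
        data = paddle.randn([1, 8])
        label = paddle.randn([1, 8])
        loss = loss_fn(model(data), label)
        micro_batch_losses.append(loss.item())
        loss.backward()
    losses_by_step.append(np.array(micro_batch_losses, dtype=np.float32).mean())
    opt.step()        # clipping is applied to the accumulated gradients here
    opt.clear_grad()
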
@@ -490,6 +551,12 @@ def run_test(self):
         scheduleFThenB_losses = self.test_ScheduleFThenB()
         schedule1f1b_losses = self.test_Schedule1F1B()
         schedulevpp_losses = self.test_ScheduleVPP()
+        pp_model_with_ClipGradByGlobalNorm_losses = (
+            self.test_pp_model_with_ClipGradByGlobalNorm()
+        )
+        scheduleFThenB_with_ClipGradByGlobalNorm_losses = (
+            self.test_ScheduleFThenB_with_ClipGradByGlobalNorm()
+        )
         dp_pp_losses, dp_pp_losses_md5sum = self.test_dp_pp()
         dp_pp_align_mode_losses, dp_pp_align_mode_losses_md5sum = (
             self.test_dp_pp_align_mode()
@@ -520,6 +587,12 @@ def run_test(self):
             rtol=1e-5,
         )

+        np.testing.assert_allclose(
+            pp_model_with_ClipGradByGlobalNorm_losses,
+            scheduleFThenB_with_ClipGradByGlobalNorm_losses,
+            rtol=1e-5,
+        )
+
         np.testing.assert_allclose(
             dp_pp_align_mode_losses,
             dp_pp_losses,
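
As an aside (not part of the diff): the new assertion compares the baseline and ScheduleFThenB per-step loss curves with a relative tolerance of 1e-5 and the default absolute tolerance of 0. A small standalone illustration of what that tolerance accepts and rejects:

# Illustration only: how rtol=1e-5 behaves in np.testing.assert_allclose.
import numpy as np

baseline = np.array([0.981234, 0.874561], dtype=np.float64)

# A relative deviation of 5e-6 is within rtol=1e-5, so this passes silently.
np.testing.assert_allclose(baseline * (1 + 5e-6), baseline, rtol=1e-5)

# A relative deviation of 5e-4 exceeds rtol=1e-5, so this raises AssertionError.
try:
    np.testing.assert_allclose(baseline * (1 + 5e-4), baseline, rtol=1e-5)
except AssertionError:
    print("mismatch beyond rtol=1e-5 detected")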