
Commit 4f8d5a8

zty-kingmaxiaolong001 authored and committed
[AutoParallel] fix the grad_clip logic of auto_hybrid_pp (PaddlePaddle#74409)
* fix the grad clip performance
* add test
* empty commit to rerun CI
* modify the note
* simplify code logic
1 parent 5a7e35a commit 4f8d5a8
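
For context on why a partial norm matters: ClipGradByGlobalNorm rescales every gradient by clip_norm / max(global_norm, clip_norm), where global_norm is the L2 norm taken over all gradients together, so an understated global norm weakens or skips the clipping for every parameter. Below is a minimal single-process sketch of that rule; the helper name clip_by_global_norm is hypothetical and only standard paddle ops are used, so this is an illustration, not the library implementation.

    import paddle

    def clip_by_global_norm(grads, clip_norm=1.0):
        # Global norm over all gradients: sqrt(sum_i ||g_i||^2).
        global_norm = paddle.sqrt(
            paddle.add_n([paddle.sum(paddle.square(g)) for g in grads])
        )
        # Scale is 1.0 when the norm is within the bound,
        # clip_norm / global_norm otherwise.
        scale = clip_norm / paddle.maximum(
            global_norm, paddle.to_tensor(clip_norm, dtype=global_norm.dtype)
        )
        return [g * scale for g in grads]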

File tree

2 files changed, +106 -0 lines

python/paddle/nn/clip.py

Lines changed: 33 additions & 0 deletions
@@ -717,6 +717,7 @@ def _dygraph_clip(self, params_grads):
         sum_square_list = []
         sum_square_list_fp16 = []
         sum_square_list_fp32 = []
+        flag_auto_hybrid_pp = True  # Determine whether to use the new dynamic graph semi-automatic parallel pp framework
         if len(params_grads) > 0 and len(params_grads[0]) > 0:
             src_mesh = params_grads[0][0].process_mesh
         else:
@@ -742,6 +743,7 @@ def _dygraph_clip(self, params_grads):
             # if the gradient mesh is not equal to src mesh
             # do reshard to get the result of squared_l2 from other pp stage mesh
             if src_mesh is not None and g.process_mesh != src_mesh:
+                flag_auto_hybrid_pp = False
                 pp_mesh = get_complete_pp_mesh(g.process_mesh)
                 if set(g.process_mesh.process_ids) < set(pp_mesh.process_ids):
                     sum_square = dist.reshard(
@@ -791,6 +793,37 @@ def async_add_n(var_list):
 
         global_norm_var = async_add_n(global_norm_var)
 
+        # NOTE(zhengtianyu): Fix grad_clip in auto_hybrid_pp mode.
+        # Reason: In auto_hybrid_pp mode, each rank only keeps local parameters and gradient information,
+        # so global_norm_var is in a partial state, leading to incorrect calculation.
+        # Reference dynamic manual-parallel: Each rank computes local global_norm_var,
+        # then performs pp group communication reduce(sum) to get correct global_norm_var.
+        # For complete alignment with old dygraph semi-auto parallel PP logic,
+        # refer to NOTE: align ClipGradByGlobalNorm in auto_parallel_align_mode
+        if flag_auto_hybrid_pp and src_mesh is not None:
+            g_mesh = dist.get_mesh()
+            if (
+                g_mesh
+                and "pp" in g_mesh.dim_names
+                and g_mesh.get_dim_size("pp") > 1
+            ):
+                # Get the pipeline parallelism subgroup for communication
+                pp_group = g_mesh.get_submesh_with_dim("pp").get_group("pp")
+
+                # Perform all-reduce on the local tensor value across the PP group
+                global_norm_var_local = global_norm_var._local_value()
+                dist.all_reduce(
+                    global_norm_var_local,
+                    op=dist.ReduceOp.SUM,
+                    group=pp_group,
+                )
+
+                global_norm_var = dist.shard_tensor(
+                    global_norm_var_local,
+                    global_norm_var.process_mesh,
+                    global_norm_var.placements,
+                )
+
         if self.should_comm_on_shard_dim and hasattr(self, 'sharding_group'):
             paddle.distributed.all_reduce(
                 global_norm_var._local_value(), group=self.sharding_group
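
Stripped of the dist-tensor plumbing, the core of the fix above is a sum-reduce of each stage's partial squared norm across the pp group before the square root is taken. Here is a minimal sketch, assuming an already-initialized paddle.distributed environment; pp_global_grad_norm is a hypothetical helper and pp_group stands in for the group obtained via get_submesh_with_dim("pp").get_group("pp"):

    import paddle
    import paddle.distributed as dist

    def pp_global_grad_norm(local_grads, pp_group=None):
        # Each pp rank only sees its own gradients, so this sum is partial.
        local_sq_sum = paddle.add_n(
            [paddle.sum(paddle.square(g)) for g in local_grads]
        )
        if pp_group is not None:
            # Sum the partial values across pipeline stages so that every
            # rank clips against the same global norm.
            dist.all_reduce(local_sq_sum, op=dist.ReduceOp.SUM, group=pp_group)
        return paddle.sqrt(local_sq_sum)

In the actual patch the reduced local value is wrapped back with dist.shard_tensor onto the original process_mesh and placements, so the downstream clipping math keeps operating on a distributed tensor.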

test/auto_parallel/PP_Schedules_demo.py

Lines changed: 73 additions & 0 deletions
@@ -414,6 +414,67 @@ def test_dp_pp(self):
             opt.clear_grad()
         return losses_by_step, all_losses_in_one_step_md5sum
 
+    def test_pp_model_with_ClipGradByGlobalNorm(self):
+        """Test pipeline parallel model with ClipGradByGlobalNorm using PPMyModel as the baseline"""
+        fix_seeds()
+        pp_model = PPMyModel()
+        opt = paddle.optimizer.AdamW(
+            learning_rate=0.001,
+            parameters=pp_model.parameters(),
+            grad_clip=paddle.nn.ClipGradByGlobalNorm(1.0),
+        )
+        loss_fn = nn.MSELoss()
+        dataset = RandomDataset(image_size=8, output_size=8, num_samples=8)
+        loader = DataLoader(dataset, batch_size=1)
+        pp_losses_step = []
+        num_iterations = 20
+
+        for iter_idx in range(num_iterations):
+            pp_losses_micro_batch = []
+            for i, (data, label) in enumerate(loader):
+                output = pp_model(data)
+                loss = loss_fn(output, label)
+                pp_losses_micro_batch.append(loss.item())
+                loss.backward()
+            pp_losses_step.append(
+                np.array(pp_losses_micro_batch, dtype=np.float32).mean()
+            )
+            opt.step()
+            opt.clear_grad()
+        return pp_losses_step
+
+    def test_ScheduleFThenB_with_ClipGradByGlobalNorm(self):
+        fix_seeds()
+        self.model = PPMyModel_SingleStage()
+        self.micro_batches = 8
+        self.stage = PipelineStage(self.model, self.rank, 4, group=self.group)
+        self.stage.has_backward = True
+        loss_fn_ = nn.MSELoss()
+        schedule = ScheduleFThenB(
+            self.stage, self.micro_batches, loss_fn=loss_fn_
+        )
+        opt = paddle.optimizer.AdamW(
+            learning_rate=0.001,
+            parameters=self.model.parameters(),
+            grad_clip=paddle.nn.ClipGradByGlobalNorm(1.0),
+        )
+        dataset = RandomDataset(image_size=8, output_size=8, num_samples=8)
+        loader = DataLoader(dataset, batch_size=8)
+        losses_by_step = []
+        num_iterations = 20
+
+        for iter_idx in range(num_iterations):
+            losses_by_micro_batch = []
+            for i, (data, label) in enumerate(loader):
+                schedule.step(data, target=label, losses=losses_by_micro_batch)
+            if self.rank == 3:
+                losses_by_step.append(
+                    np.array(losses_by_micro_batch, dtype=np.float32).mean()
+                )
+            opt.step()
+            opt.clear_grad()
+        return losses_by_step
+
     def test_dp_pp_align_mode(self):
         fix_seeds()
         paddle.set_flags({'FLAGS_enable_auto_parallel_align_mode': True})
@@ -490,6 +551,12 @@ def run_test(self):
         scheduleFThenB_losses = self.test_ScheduleFThenB()
         schedule1f1b_losses = self.test_Schedule1F1B()
         schedulevpp_losses = self.test_ScheduleVPP()
+        pp_model_with_ClipGradByGlobalNorm_losses = (
+            self.test_pp_model_with_ClipGradByGlobalNorm()
+        )
+        scheduleFThenB_with_ClipGradByGlobalNorm_losses = (
+            self.test_ScheduleFThenB_with_ClipGradByGlobalNorm()
+        )
         dp_pp_losses, dp_pp_losses_md5sum = self.test_dp_pp()
         dp_pp_align_mode_losses, dp_pp_align_mode_losses_md5sum = (
             self.test_dp_pp_align_mode()
@@ -520,6 +587,12 @@ def run_test(self):
             rtol=1e-5,
         )
 
+        np.testing.assert_allclose(
+            pp_model_with_ClipGradByGlobalNorm_losses,
+            scheduleFThenB_with_ClipGradByGlobalNorm_losses,
+            rtol=1e-5,
+        )
+
         np.testing.assert_allclose(
             dp_pp_align_mode_losses,
             dp_pp_losses,