Skip to content

Commit 273fe4a

Browse files
authored
[Distributed] fix release grad on moe model (#74972)
1 parent 46cdd05 commit 273fe4a

File tree

2 files changed

+13
-9
lines changed

2 files changed

+13
-9
lines changed

python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,13 @@ def reduce_gradients(self, parameter_list, hcg):
348348
with framework.no_grad():
349349
for param in parameter_list:
350350
g_var = self._get_param_grad(param)
351+
if g_var is None:
352+
if hasattr(param, "main_grad"):
353+
g_var = paddle.zeros_like(param, dtype=paddle.float32)
354+
param.main_grad = g_var
355+
else:
356+
g_var = paddle.zeros_like(param, dtype=param.dtype)
357+
param.grad = g_var
351358
if g_var is not None:
352359
reduce_op = ReduceOp.AVG
353360
if not self.use_reduce_avg:

python/paddle/distributed/fleet/utils/tensor_fusion_helper.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -619,22 +619,19 @@ def _copy_grad_to_buffer(self, param):
619619
)
620620

621621
grad_var = param.main_grad if self.use_main_grad else param.grad
622-
assert grad_var is not None, (
623-
f"The current parameter[{param.name}] has no gradient, its stop_grdient is {param.stop_gradient}"
624-
)
625-
grad_var.stop_gradient = True
626-
grad_var.flatten_()
627622

628-
tmp_var.add_(grad_var)
629-
tmp_var.get_tensor()._set_dims(param.shape)
623+
if grad_var is not None:
624+
grad_var.stop_gradient = True
625+
grad_var.flatten_()
626+
tmp_var.add_(grad_var)
627+
grad_var._clear()
630628

629+
tmp_var.get_tensor()._set_dims(param.shape)
631630
if self.use_main_grad:
632-
param.main_grad._clear()
633631
if not self._free_grads_in_comm:
634632
param.main_grad = tmp_var
635633
param.main_grad.name = "main_grad@" + param.name
636634
else:
637-
param.grad._clear()
638635
if not self._free_grads_in_comm:
639636
param._copy_gradient_from(tmp_var)
640637

0 commit comments

Comments
 (0)