
Commit 64e54d5

wconstab authored and pytorchmergebot committed
[Pipelining] Relax scale_grads assert (pytorch#145010)
The assert felt morally valid: if no gradients are scaled, then something is definitely wrong with the setup. In one instance, PP + optimizer-in-backward (in torchtitan) resulted in grad=None after running .backward() and before scaling grads.

On the other hand, the existing assert is too restrictive. It's possible that a model used with pipelining has some parameters that do not receive gradients, and we shouldn't hard-error in these cases (e.g. if a parameter is literally not used, or is frozen). In the extreme case, the whole stage could be frozen, so we do not complain even if no grads are scaled.

Pull Request resolved: pytorch#145010
Approved by: https://github.com/mori360, https://github.com/tianyu-l
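A minimal standalone sketch (not from this PR; stage_module is an illustrative name) of how a frozen parameter ends up with grad=None after .backward(), which is exactly the case the old assert rejected:

import torch
import torch.nn as nn

# Toy "stage" whose second layer is frozen.
stage_module = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
for p in stage_module[1].parameters():
    p.requires_grad_(False)

stage_module(torch.randn(2, 4)).sum().backward()

# Frozen parameters never receive a gradient, so the old
# `assert p.grad is not None` would have fired on them.
for name, p in stage_module.named_parameters():
    print(name, p.grad is None)  # 0.* params: False, 1.* params: True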
1 parent 07e2365 commit 64e54d5

File tree

1 file changed: +3 −3 lines changed


torch/distributed/pipelining/stage.py

Lines changed: 3 additions & 3 deletions

@@ -586,9 +586,9 @@ def scale_grads(self, grad_scale_factor: int) -> None:
         # PP scales only for its own contribution (microbatches), but relies on DP to scale further
         # for DP degree.
         if grad_scale_factor != 1:
-            for name, p in self.submod.named_parameters():
-                assert p.grad is not None, name
-                p.grad.div_(grad_scale_factor)
+            for p in self.submod.parameters():
+                if p.grad is not None:
+                    p.grad.div_(grad_scale_factor)

     def backward_maybe_with_nosync(
         self,
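For intuition about why scale_grads divides at all: each microbatch's backward accumulates into p.grad, so the stage averages by its own microbatch count afterward (DP scaling happens elsewhere, per the comment in the diff). A hedged standalone sketch (stage_module and n_microbatches are illustrative names, not the pipelining API) mirroring the relaxed loop:

import torch
import torch.nn as nn

stage_module = nn.Linear(4, 1)
n_microbatches = 4

# Gradients from each microbatch's backward accumulate into p.grad.
for _ in range(n_microbatches):
    stage_module(torch.randn(2, 4)).sum().backward()

# Same shape as the new scale_grads body: scale only grads that exist.
grad_scale_factor = n_microbatches
if grad_scale_factor != 1:
    for p in stage_module.parameters():
        if p.grad is not None:
            p.grad.div_(grad_scale_factor)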
