@@ -18,9 +18,9 @@ def has_overflow(grad_norm: torch.Tensor) -> bool:
 
 def normalize_gradient(x: torch.Tensor, use_channels: bool = False, epsilon: float = 1e-8) -> torch.Tensor:
     """normalize gradient with stddev
-    :param x: torch.Tensor. gradient.
-    :param use_channels: bool. channel-wise normalization.
-    :param epsilon: float. eps.
+    :param x: torch.Tensor. gradient
+    :param use_channels: bool. channel-wise normalization
+    :param epsilon: float. eps
     :return: torch.Tensor. normalized gradient.
     """
     size: int = x.dim()
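For context, here is a minimal sketch of the stddev-based normalization this docstring describes. The function name and body below are illustrative assumptions, since the diff only shows the first line of the real implementation; it follows the docstring: per-channel normalization when use_channels is set, otherwise a single global standard deviation.

# Hypothetical sketch of stddev-based gradient normalization (not the file's actual body).
import torch

def normalize_gradient_sketch(x: torch.Tensor, use_channels: bool = False, epsilon: float = 1e-8) -> torch.Tensor:
    size: int = x.dim()
    if use_channels and size > 1:
        # normalize each output channel (dim 0) by the stddev over its remaining dims
        x.div_(x.std(dim=tuple(range(1, size)), keepdim=True) + epsilon)
    elif torch.numel(x) > 2:
        # global normalization by the tensor's standard deviation
        x.div_(x.std() + epsilon)
    return x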
@@ -36,12 +36,12 @@ def normalize_gradient(x: torch.Tensor, use_channels: bool = False, epsilon: flo
 def clip_grad_norm(parameters: PARAMETERS, max_norm: float = 0, sync: bool = False) -> torch.Tensor:
     """Clips grad norms.
     During combination with FSDP, will also ensure that grad norms are aggregated
-        across all workers, since each worker only stores their shard of the gradients.
+        across all workers, since each worker only stores their shard of the gradients
     :param parameters: Parameters whose gradients we wish to clip
     :param max_norm: Maximum norm we wish the gradients to have. If non-positive, then
-        we will not perform clipping.
+        we will not perform clipping
     :param sync: Boolean indicating whether we should aggregate across the distributed
-        group. Used only in combination with FSDP.
+        group. Used only in combination with FSDP
     :returns: The gradient norm across all parameters, before clipping.
     """
     if isinstance(parameters, torch.Tensor):
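The FSDP note in this docstring amounts to a standard shard-norm reduction: each rank computes the L2 norm of its local gradient shard, and the squared norms are summed across the process group. A minimal sketch of that idea follows; the helper name is hypothetical, it assumes an already-initialized torch.distributed process group, and it is not the repository's actual sync path.

# Hypothetical sketch of the cross-worker norm aggregation described above.
import torch
import torch.distributed as dist

def aggregate_grad_norm_sketch(local_grad_norm: torch.Tensor) -> torch.Tensor:
    total_sq = local_grad_norm.pow(2)
    dist.all_reduce(total_sq, op=dist.ReduceOp.SUM)  # sum of squared shard norms over all ranks
    return total_sq.sqrt()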