1 parent e13b7e0 commit 6eb1d6d
pytorch_optimizer/utils.py
@@ -34,7 +34,7 @@ def normalize_gradient(x: torch.Tensor, use_channels: bool = False, epsilon: flo
     return x


-def clip_grad_norm(parameters: PARAMETERS, max_norm: float = 0, sync: bool = False) -> torch.Tensor:
+def clip_grad_norm(parameters: PARAMETERS, max_norm: float = 0, sync: bool = False) -> Union[torch.Tensor, float]:
     """Clips grad norms.

     During combination with FSDP, will also ensure that grad norms are aggregated
     across all workers, since each worker only stores their shard of the gradients
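
The commit widens the return annotation to Union[torch.Tensor, float] because a gradient-clipping helper of this shape does not always produce a tensor. The sketch below is an assumption-laden illustration of why such a union arises, not the library's actual implementation: it short-circuits with a plain float when no parameters carry gradients, and otherwise returns the total-norm tensor. The PARAMETERS alias, the early-return value, and the omitted FSDP synchronization are hypothetical stand-ins.

# Hedged sketch, not the implementation committed here: shows one code path
# that returns a float and one that returns a tensor, motivating the
# Union[torch.Tensor, float] annotation.
from typing import Iterable, Union

import torch

# Assumption: PARAMETERS is an iterable of parameters, mirroring the alias in the diff.
PARAMETERS = Iterable[torch.nn.Parameter]


def clip_grad_norm(parameters: PARAMETERS, max_norm: float = 0, sync: bool = False) -> Union[torch.Tensor, float]:
    # Collect only the gradients that actually exist.
    grads = [p.grad for p in parameters if p.grad is not None]
    if len(grads) == 0:
        # Nothing to clip: returning a plain float is why the annotation is a Union.
        return 0.0

    # Total L2 norm across all gradients.
    total_norm = torch.norm(torch.stack([torch.norm(g.detach()) for g in grads]))

    # With sync=True the real function would aggregate this norm across FSDP
    # workers (e.g. an all_reduce); omitted in this sketch.

    if max_norm > 0:
        # Scale gradients in place so the total norm does not exceed max_norm.
        clip_coefficient = torch.clamp(max_norm / (total_norm + 1e-6), max=1.0)
        for g in grads:
            g.detach().mul_(clip_coefficient)

    return total_norm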