
Commit 6eb1d6d

refactor: clip_grad_norm
1 parent e13b7e0 commit 6eb1d6d

File tree

1 file changed: +1 -1 lines changed


pytorch_optimizer/utils.py

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@ def normalize_gradient(x: torch.Tensor, use_channels: bool = False, epsilon: flo
     return x
 
 
-def clip_grad_norm(parameters: PARAMETERS, max_norm: float = 0, sync: bool = False) -> torch.Tensor:
+def clip_grad_norm(parameters: PARAMETERS, max_norm: float = 0, sync: bool = False) -> Union[torch.Tensor, float]:
     """Clips grad norms.
 
     During combination with FSDP, will also ensure that grad norms are aggregated
     across all workers, since each worker only stores their shard of the gradients
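
The widened return annotation allows the function to hand back either a tensor norm or a plain float. Below is a minimal sketch of how a clip_grad_norm with this signature could look, assuming the float path covers the case where no parameter carries a gradient and that sync triggers the cross-worker aggregation described in the docstring; the actual implementation in pytorch_optimizer/utils.py may differ, and the library's PARAMETERS alias is replaced here by a plain iterable.

from typing import Iterable, Union

import torch
import torch.distributed as dist


def clip_grad_norm(
    parameters: Iterable[torch.nn.Parameter],
    max_norm: float = 0,
    sync: bool = False,
) -> Union[torch.Tensor, float]:
    # Collect the gradients that actually exist on this worker.
    grads = [p.grad for p in parameters if p.grad is not None]
    if len(grads) == 0:
        # Nothing to clip: return a plain float, hence Union[torch.Tensor, float].
        return 0.0

    # Local (per-worker) squared L2 norm over all gradient tensors.
    norm_sq = torch.stack([g.norm(p=2) ** 2 for g in grads]).sum()

    if sync and dist.is_available() and dist.is_initialized():
        # Sharded setup (e.g. FSDP): each worker only holds its shard of the
        # gradients, so sum the squared norms across workers for the global norm.
        dist.all_reduce(norm_sq, op=dist.ReduceOp.SUM)

    total_norm = norm_sq.sqrt()

    if max_norm > 0:
        # Scale gradients in place when the global norm exceeds max_norm.
        clip_coefficient = torch.clamp(max_norm / (total_norm + 1e-6), max=1.0)
        for g in grads:
            g.mul_(clip_coefficient)

    return total_norm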
