Skip to content

Commit 7599d81

Browse files
committed
refactor: unit_norm
1 parent 433939f commit 7599d81

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

pytorch_optimizer/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def normalize_gradient(x: torch.Tensor, use_channels: bool = False, epsilon: flo
2626
return x
2727

2828

29-
def unit_norm(x: torch.Tensor) -> torch.Tensor:
29+
def unit_norm(x: torch.Tensor, norm: float = 2.0) -> torch.Tensor:
3030
keep_dim: bool = True
3131
dim: Optional[Union[int, Tuple[int, ...]]] = None
3232

@@ -40,4 +40,4 @@ def unit_norm(x: torch.Tensor) -> torch.Tensor:
4040
else:
4141
dim = tuple(range(1, x_len))
4242

43-
return x.norm(dim=dim, keepdim=keep_dim, p=2.0)
43+
return x.norm(dim=dim, keepdim=keep_dim, p=norm)

0 commit comments

Comments
 (0)