Commit 16aeb2c

Merge pull request #49 from kozistr/test/utils
[Test] Add test cases for utils
2 parents 2162e68 + fa6d4b1 · commit 16aeb2c

3 files changed, +81 −3 lines changed

pytorch_optimizer/utils.py

Lines changed: 3 additions & 2 deletions
@@ -1,3 +1,4 @@
+import math
 from typing import List, Optional, Tuple, Union
 
 import torch
@@ -33,7 +34,7 @@ def normalize_gradient(x: torch.Tensor, use_channels: bool = False, epsilon: flo
     return x
 
 
-def clip_grad_norm(parameters: PARAMETERS, max_norm: float = 0, sync: bool = False) -> torch.Tensor:
+def clip_grad_norm(parameters: PARAMETERS, max_norm: float = 0, sync: bool = False) -> Union[torch.Tensor, float]:
     """Clips grad norms.
     During combination with FSDP, will also ensure that grad norms are aggregated
     across all workers, since each worker only stores their shard of the gradients
@@ -59,7 +60,7 @@ def clip_grad_norm(parameters: PARAMETERS, max_norm: float = 0, sync: bool = Fal
         # also need to get the norms from all the other sharded works in FSDP
         all_reduce(norm_sq)
 
-    grad_norm = norm_sq.sqrt()
+    grad_norm = math.sqrt(norm_sq)
     if max_norm > 0:
         clip_coef = max_norm / (grad_norm + 1e-6)
         for p in parameters:

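A note on the clip_grad_norm change above: norm_sq.sqrt() returns a 0-dim tensor, while math.sqrt(norm_sq) coerces that tensor to a Python float, which is why the return annotation widens to Union[torch.Tensor, float]. A minimal illustrative sketch of the difference (not taken from the repository):

    import math

    import torch

    # a 0-dim tensor holding a squared norm, e.g. the 3-4-5 triangle
    norm_sq = torch.tensor([3.0, 4.0]).pow(2).sum()

    print(type(norm_sq.sqrt()))      # <class 'torch.Tensor'>, value tensor(5.)
    print(type(math.sqrt(norm_sq)))  # <class 'float'>, value 5.0 (math.sqrt calls Tensor.__float__)
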
pytorch_optimizer/version.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__VERSION__ = '0.3.5'
+__VERSION__ = '0.3.6'

tests/test_utils.py

Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
+from typing import List
+
+import numpy as np
+import torch
+from torch import nn
+
+from pytorch_optimizer.utils import (
+    clip_grad_norm,
+    get_optimizer_parameters,
+    has_overflow,
+    normalize_gradient,
+    unit_norm,
+)
+
+
+class Example(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.fc1 = nn.Linear(1, 1)
+        self.norm1 = nn.LayerNorm(1)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.norm1(self.fc1(x))
+
+
+def test_has_overflow():
+    assert has_overflow(np.inf)
+    assert has_overflow(np.nan)
+    assert not has_overflow(torch.Tensor([1]))
+
+
+def test_normalized_gradient():
+    x = torch.arange(0, 10, dtype=torch.float32)
+
+    np.testing.assert_allclose(
+        normalize_gradient(x).numpy(),
+        np.asarray([0.0000, 0.3303, 0.6606, 0.9909, 1.3212, 1.6514, 1.9817, 2.3120, 2.6423, 2.9726]),
+        rtol=1e-4,
+        atol=1e-4,
+    )
+
+    np.testing.assert_allclose(
+        normalize_gradient(x.view(1, 10), use_channels=True).numpy(),
+        np.asarray([[0.0000, 0.3303, 0.6606, 0.9909, 1.3212, 1.6514, 1.9817, 2.3120, 2.6423, 2.9726]]),
+        rtol=1e-4,
+        atol=1e-4,
+    )
+
+
+def test_clip_grad_norm():
+    x = torch.arange(0, 10, dtype=torch.float32, requires_grad=True)
+    x.grad = torch.arange(0, 10, dtype=torch.float32)
+
+    np.testing.assert_approx_equal(clip_grad_norm(x), 16.881943016134134, significant=4)
+    np.testing.assert_approx_equal(clip_grad_norm(x, max_norm=2), 16.881943016134134, significant=4)
+
+
+def test_unit_norm():
+    x = torch.arange(0, 10, dtype=torch.float32)
+
+    np.testing.assert_approx_equal(unit_norm(x).numpy(), 16.8819, significant=4)
+    np.testing.assert_approx_equal(unit_norm(x.view(1, 10)).numpy(), 16.8819, significant=4)
+    np.testing.assert_approx_equal(unit_norm(x.view(1, 10, 1, 1)).numpy(), 16.8819, significant=4)
+    np.testing.assert_approx_equal(unit_norm(x.view(1, 10, 1, 1, 1, 1)).numpy(), 16.8819, significant=4)
+
+
+def test_get_optimizer_parameters():
+    model: nn.Module = Example()
+    wd_ban_list: List[str] = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
+
+    before_parameters = list(model.named_parameters())
+    after_parameters = get_optimizer_parameters(model, weight_decay=1e-3, wd_ban_list=wd_ban_list)
+
+    for before, after in zip(before_parameters, after_parameters):
+        layer_name: str = before[0]
+        if layer_name.find('bias') != -1 or layer_name in wd_ban_list:
+            assert after['weight_decay'] == 0.0
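Two side notes on the expected values and usage in these tests (illustrative sketches, not part of the commit).

The constant 16.8819… asserted in test_clip_grad_norm and test_unit_norm is simply the L2 norm of the vector [0, 1, …, 9]: the sum of squares is 0² + 1² + … + 9² = 285, and sqrt(285) ≈ 16.881943. A quick check:

    import math

    import torch

    g = torch.arange(0, 10, dtype=torch.float32)
    assert float(g.pow(2).sum()) == 285.0   # 0^2 + 1^2 + ... + 9^2
    print(math.sqrt(285.0))                 # 16.881943016134134
    print(g.norm(p=2).item())               # same value via the L2 norm

test_get_optimizer_parameters also suggests that get_optimizer_parameters returns standard PyTorch param groups (a list of dicts carrying a weight_decay entry, presumably alongside the params themselves). Assuming that format, the groups could be passed straight to an optimizer; a hypothetical usage sketch:

    import torch
    from torch import nn

    from pytorch_optimizer.utils import get_optimizer_parameters

    model = nn.Sequential(nn.Linear(1, 1), nn.LayerNorm(1))
    groups = get_optimizer_parameters(model, weight_decay=1e-3, wd_ban_list=['bias', 'LayerNorm.bias', 'LayerNorm.weight'])
    optimizer = torch.optim.SGD(groups, lr=1e-2)  # weight decay is zeroed for the banned parameters
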
