
Commit fa6d4b1

update: test_get_optimizer_parameters
1 parent: a0950f3

1 file changed

tests/test_utils.py

Lines changed: 33 additions & 1 deletion
@@ -1,7 +1,26 @@
+from typing import List
+
 import numpy as np
 import torch
+from torch import nn
+
+from pytorch_optimizer.utils import (
+    clip_grad_norm,
+    get_optimizer_parameters,
+    has_overflow,
+    normalize_gradient,
+    unit_norm,
+)
+
+
+class Example(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.fc1 = nn.Linear(1, 1)
+        self.norm1 = nn.LayerNorm(1)
 
-from pytorch_optimizer.utils import clip_grad_norm, has_overflow, normalize_gradient, unit_norm
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.norm1(self.fc1(x))
 
 
 def test_has_overflow():
@@ -43,3 +62,16 @@ def test_unit_norm():
     np.testing.assert_approx_equal(unit_norm(x.view(1, 10)).numpy(), 16.8819, significant=4)
     np.testing.assert_approx_equal(unit_norm(x.view(1, 10, 1, 1)).numpy(), 16.8819, significant=4)
     np.testing.assert_approx_equal(unit_norm(x.view(1, 10, 1, 1, 1, 1)).numpy(), 16.8819, significant=4)
+
+
+def test_get_optimizer_parameters():
+    model: nn.Module = Example()
+    wd_ban_list: List[str] = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
+
+    before_parameters = list(model.named_parameters())
+    after_parameters = get_optimizer_parameters(model, weight_decay=1e-3, wd_ban_list=wd_ban_list)
+
+    for before, after in zip(before_parameters, after_parameters):
+        layer_name: str = before[0]
+        if layer_name.find('bias') != -1 or layer_name in wd_ban_list:
+            assert after['weight_decay'] == 0.0
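
For context, a minimal usage sketch (not part of this commit) of how the grouped parameters would typically be consumed. It assumes get_optimizer_parameters returns torch-style parameter groups, i.e. dicts carrying a 'weight_decay' key, which is what the new test's assertions rely on; the 'params' key and the AdamW hand-off below are illustrative assumptions, not confirmed by the diff.

import torch
from torch import nn

from pytorch_optimizer.utils import get_optimizer_parameters

# Same toy model shape as the Example module in the test above.
model = nn.Sequential(nn.Linear(1, 1), nn.LayerNorm(1))

# Group the parameters: entries matching the ban list should come back
# with weight_decay=0.0, everything else keeps the configured 1e-3.
parameters = get_optimizer_parameters(
    model,
    weight_decay=1e-3,
    wd_ban_list=['bias', 'LayerNorm.bias', 'LayerNorm.weight'],
)

# torch.optim optimizers accept per-group settings, so the groups can be
# passed straight to the constructor (assumes each group has a 'params' key).
optimizer = torch.optim.AdamW(parameters)

Excluding biases and normalization weights from weight decay is a common convention (BERT-style training uses the same ban list), which is why the test checks that banned entries end up with weight_decay == 0.0.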
