Skip to content

Commit 2d4c3c6

Browse files
authored
Merge pull request #35 from kozistr/feature/wd-utils
[Feature] weight decay ban list
2 parents 49d3937 + b8e71b1 commit 2d4c3c6

File tree

3 files changed

+19
-3
lines changed

3 files changed

+19
-3
lines changed

lint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def get_configuration() -> Namespace:
1414
parser.add_argument(
1515
'-t',
1616
'--threshold',
17-
default=9.75,
17+
default=9.9,
1818
type=float,
1919
)
2020

pytorch_optimizer/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,6 @@
1616
from pytorch_optimizer.ranger21 import Ranger21
1717
from pytorch_optimizer.sam import SAM
1818
from pytorch_optimizer.sgdp import SGDP
19+
from pytorch_optimizer.utils import get_optimizer_parameters, normalize_gradient, unit_norm
1920

20-
__VERSION__ = '0.1.0'
21+
__VERSION__ = '0.1.1'

pytorch_optimizer/utils.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
from typing import Optional, Tuple, Union
1+
from typing import List, Optional, Tuple, Union
22

33
import torch
4+
from torch import nn
45

56
from pytorch_optimizer.types import PARAMETERS
67

@@ -41,3 +42,17 @@ def unit_norm(x: torch.Tensor, norm: float = 2.0) -> torch.Tensor:
4142
dim = tuple(range(1, x_len))
4243

4344
return x.norm(dim=dim, keepdim=keep_dim, p=norm)
45+
46+
47+
def get_optimizer_parameters(
    model: nn.Module,
    weight_decay: float,
    wd_ban_list: Tuple[str, ...] = ('bias', 'LayerNorm.bias', 'LayerNorm.weight'),
) -> 'PARAMETERS':
    """Split a model's parameters into two optimizer parameter groups.

    Parameters whose dotted name contains any substring from ``wd_ban_list``
    are placed in a group with ``weight_decay=0.0``; every other parameter
    goes into a group with the supplied ``weight_decay``. The returned list
    can be passed directly to a torch optimizer constructor.

    :param model: module whose named parameters are grouped.
    :param weight_decay: decay factor applied to the non-banned group.
    :param wd_ban_list: name substrings that exempt a parameter from weight
        decay. Annotated as a tuple to match the (immutable, call-safe)
        tuple default; any iterable of strings works at runtime.
    :return: two param-group dicts — decayed first, then non-decayed.
    """
    named_params: List[Tuple[str, nn.Parameter]] = list(model.named_parameters())

    def _banned(name: str) -> bool:
        # A parameter is exempt from decay if any ban-list entry occurs in its name.
        return any(pattern in name for pattern in wd_ban_list)

    return [
        {
            'params': [p for n, p in named_params if not _banned(n)],
            'weight_decay': weight_decay,
        },
        {'params': [p for n, p in named_params if _banned(n)], 'weight_decay': 0.0},
    ]

0 commit comments

Comments (0)