Skip to content

Commit ed1d3e1

Browse files
authored
Merge pull request #283 from kozistr/refactor/get-model-parameters
[Fix] when `model_or_parameter` is not an `nn.Module` instance.
2 parents 769e5fb + 23adc86 commit ed1d3e1

File tree

6 files changed

+117
-67
lines changed

6 files changed

+117
-67
lines changed

docs/changelogs/v3.2.0.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,12 @@
88
* `bnb_ademamix8bit`, `bnb_ademamix32bit`, `bnb_paged_ademamix8bit`, `bnb_paged_ademamix32bit`
99
* Support 8/4bit, fp8 optimizers. (#208, #281)
1010
* `torchao_adamw8bit`, `torchao_adamw4bit`, `torchao_adamwfp8`.
11+
* Support a module-name-level (e.g. `LayerNorm`) weight decay exclusion for `get_optimizer_parameters`. (#282, #283)
1112

1213
### Bug
1314

1415
* Fix `should_grokfast` condition when initialization. (#279, #280)
16+
17+
### Contributions
18+
19+
Thanks to @Vectorrent

poetry.lock

Lines changed: 75 additions & 38 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pytorch_optimizer/optimizer/utils.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import math
22
import warnings
33
from importlib.util import find_spec
4-
from typing import Callable, Dict, List, Optional, Tuple, Union
4+
from typing import Callable, Dict, List, Optional, Set, Tuple, Union
55

66
import numpy as np
77
import torch
@@ -198,43 +198,45 @@ def get_optimizer_parameters(
198198
weight_decay: float,
199199
wd_ban_list: List[str] = ('bias', 'LayerNorm.bias', 'LayerNorm.weight'),
200200
) -> PARAMETERS:
201-
r"""
202-
Get optimizer parameters while filtering specified modules.
201+
r"""Get optimizer parameters while filtering specified modules.
202+
203+
Notice that you can also ban at the module-name level (e.g. `LayerNorm`) if you pass an `nn.Module` instance. You only
204+
need to pass `LayerNorm` to exclude weight decay from the layer-norm layer(s).
205+
203206
:param model_or_parameter: Union[nn.Module, List]. model or parameters.
204207
:param weight_decay: float. weight_decay.
205208
:param wd_ban_list: List[str]. ban list not to set weight decay.
206209
:returns: PARAMETERS. new parameter list.
207210
"""
208-
209-
210-
fully_qualified_names = []
211-
for module_name, module in model_or_parameter.named_modules():
212-
for param_name, _param in module.named_parameters(recurse=False):
213-
# Full parameter name includes module and parameter names
214-
full_param_name = f'{module_name}.{param_name}' if module_name else param_name
215-
# Check if any ban list substring is in the parameter name or module name
216-
if (
217-
any(banned in param_name for banned in wd_ban_list)
218-
or any(banned in module_name for banned in wd_ban_list)
219-
or any(banned in module._get_name() for banned in wd_ban_list)
220-
):
221-
fully_qualified_names.append(full_param_name)
211+
banned_parameter_patterns: Set[str] = set()
222212

223213
if isinstance(model_or_parameter, nn.Module):
214+
for module_name, module in model_or_parameter.named_modules():
215+
for param_name, _ in module.named_parameters(recurse=False):
216+
full_param_name: str = f'{module_name}.{param_name}' if module_name else param_name
217+
if any(
218+
banned in pattern for banned in wd_ban_list for pattern in (full_param_name, module._get_name())
219+
):
220+
banned_parameter_patterns.add(full_param_name)
221+
224222
model_or_parameter = list(model_or_parameter.named_parameters())
223+
else:
224+
banned_parameter_patterns.update(wd_ban_list)
225225

226226
return [
227227
{
228228
'params': [
229229
p
230230
for n, p in model_or_parameter
231-
if p.requires_grad and not any(nd in n for nd in fully_qualified_names)
231+
if p.requires_grad and not any(nd in n for nd in banned_parameter_patterns)
232232
],
233233
'weight_decay': weight_decay,
234234
},
235235
{
236236
'params': [
237-
p for n, p in model_or_parameter if p.requires_grad and any(nd in n for nd in fully_qualified_names)
237+
p
238+
for n, p in model_or_parameter
239+
if p.requires_grad and any(nd in n for nd in banned_parameter_patterns)
238240
],
239241
'weight_decay': 0.0,
240242
},

requirements-dev.txt

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@ colorama==0.4.6 ; python_version >= "3.8" and (sys_platform == "win32" or platfo
66
coverage[toml]==7.6.1 ; python_version >= "3.8"
77
exceptiongroup==1.2.2 ; python_version < "3.11" and python_version >= "3.8"
88
filelock==3.16.1 ; python_version >= "3.8"
9-
fsspec==2024.9.0 ; python_version >= "3.8"
9+
fsspec==2024.10.0 ; python_version >= "3.8"
1010
iniconfig==2.0.0 ; python_version >= "3.8"
1111
isort==5.13.2 ; python_version >= "3.8"
1212
jinja2==3.1.4 ; python_version >= "3.8"
1313
markupsafe==2.1.5 ; python_version >= "3.8"
14-
mpmath==1.3.0 ; python_version >= "3.8"
14+
mpmath==1.3.0 ; python_version >= "3.9" or python_version == "3.8"
1515
mypy-extensions==1.0.0 ; python_version >= "3.8"
1616
networkx==3.1 ; python_version >= "3.8"
1717
numpy==1.24.4 ; python_version < "3.9" and python_version >= "3.8"
@@ -22,8 +22,10 @@ platformdirs==4.3.6 ; python_version >= "3.8"
2222
pluggy==1.5.0 ; python_version >= "3.8"
2323
pytest-cov==5.0.0 ; python_version >= "3.8"
2424
pytest==8.3.3 ; python_version >= "3.8"
25-
ruff==0.6.9 ; python_version >= "3.8"
26-
sympy==1.13.3 ; python_version >= "3.8"
25+
ruff==0.7.0 ; python_version >= "3.8"
26+
setuptools==75.2.0 ; python_version >= "3.12"
27+
sympy==1.12.1 ; python_version == "3.8"
28+
sympy==1.13.1 ; python_version >= "3.9"
2729
tomli==2.0.2 ; python_full_version <= "3.11.0a6" and python_version >= "3.8"
28-
torch==2.4.1+cpu ; python_version >= "3.8"
30+
torch==2.5.0+cpu ; python_version >= "3.8"
2931
typing-extensions==4.12.2 ; python_version >= "3.8"

requirements.txt

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
--extra-index-url https://download.pytorch.org/whl/cpu
22

33
filelock==3.16.1 ; python_version >= "3.8"
4-
fsspec==2024.9.0 ; python_version >= "3.8"
4+
fsspec==2024.10.0 ; python_version >= "3.8"
55
jinja2==3.1.4 ; python_version >= "3.8"
66
markupsafe==2.1.5 ; python_version >= "3.8"
7-
mpmath==1.3.0 ; python_version >= "3.8"
7+
mpmath==1.3.0 ; python_version >= "3.9" or python_version == "3.8"
88
networkx==3.1 ; python_version >= "3.8"
99
numpy==1.24.4 ; python_version < "3.9" and python_version >= "3.8"
1010
numpy==2.0.2 ; python_version >= "3.9"
11-
sympy==1.13.3 ; python_version >= "3.8"
12-
torch==2.4.1+cpu ; python_version >= "3.8"
11+
setuptools==75.2.0 ; python_version >= "3.12"
12+
sympy==1.12.1 ; python_version == "3.8"
13+
sympy==1.13.1 ; python_version >= "3.9"
14+
torch==2.5.0+cpu ; python_version >= "3.8"
1315
typing-extensions==4.12.2 ; python_version >= "3.8"

tests/test_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,11 +101,13 @@ def test_get_optimizer_parameters():
101101
wd_ban_list: List[str] = ['bias', 'LayerNorm.bias', 'LayerNorm.weight', 'LayerNorm']
102102

103103
before_parameters = list(model.named_parameters())
104+
105+
_ = get_optimizer_parameters(before_parameters, weight_decay=1e-3, wd_ban_list=wd_ban_list)
104106
after_parameters = get_optimizer_parameters(model, weight_decay=1e-3, wd_ban_list=wd_ban_list)
105107

106108
for before, after in zip(before_parameters, after_parameters):
107109
layer_name: str = before[0]
108-
if layer_name.find('bias') != -1 or layer_name in wd_ban_list:
110+
if layer_name.find('bias') != -1 or layer_name.find('LayerNorm') != -1:
109111
assert after['weight_decay'] == 0.0
110112

111113

0 commit comments

Comments
 (0)