Skip to content

Commit 769e5fb

Browse files
authored
Merge pull request #282 from Vectorrent/fix-weight-decay-banning
[Fix] Implement better `wd_ban_list` handling
2 parents 20ed84f + 546531c commit 769e5fb

File tree

2 files changed

+26
-5
lines changed

2 files changed

+26
-5
lines changed

pytorch_optimizer/optimizer/utils.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -198,23 +198,44 @@ def get_optimizer_parameters(
198198
weight_decay: float,
199199
wd_ban_list: List[str] = ('bias', 'LayerNorm.bias', 'LayerNorm.weight'),
200200
) -> PARAMETERS:
201-
r"""Get optimizer parameters while filtering specified modules.
202-
201+
r"""
202+
Get optimizer parameters while filtering specified modules.
203203
:param model_or_parameter: Union[nn.Module, List]. model or parameters.
204204
:param weight_decay: float. weight_decay.
205205
:param wd_ban_list: List[str]. list of name substrings; parameters whose names match are excluded from weight decay.
206206
:returns: PARAMETERS. new parameter list.
207207
"""
208+
209+
210+
fully_qualified_names = []
211+
for module_name, module in model_or_parameter.named_modules():
212+
for param_name, _param in module.named_parameters(recurse=False):
213+
# The fully-qualified parameter name joins the module path and the parameter name with a dot
214+
full_param_name = f'{module_name}.{param_name}' if module_name else param_name
215+
# Check if any ban-list substring appears in the parameter name, the module name, or the module class name
216+
if (
217+
any(banned in param_name for banned in wd_ban_list)
218+
or any(banned in module_name for banned in wd_ban_list)
219+
or any(banned in module._get_name() for banned in wd_ban_list)
220+
):
221+
fully_qualified_names.append(full_param_name)
222+
208223
if isinstance(model_or_parameter, nn.Module):
209224
model_or_parameter = list(model_or_parameter.named_parameters())
210225

211226
return [
212227
{
213-
'params': [p for n, p in model_or_parameter if p.requires_grad and not any(nd in n for nd in wd_ban_list)],
228+
'params': [
229+
p
230+
for n, p in model_or_parameter
231+
if p.requires_grad and not any(nd in n for nd in fully_qualified_names)
232+
],
214233
'weight_decay': weight_decay,
215234
},
216235
{
217-
'params': [p for n, p in model_or_parameter if p.requires_grad and any(nd in n for nd in wd_ban_list)],
236+
'params': [
237+
p for n, p in model_or_parameter if p.requires_grad and any(nd in n for nd in fully_qualified_names)
238+
],
218239
'weight_decay': 0.0,
219240
},
220241
]

tests/test_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def test_neuron_mean_norm():
9898

9999
def test_get_optimizer_parameters():
100100
model: nn.Module = Example()
101-
wd_ban_list: List[str] = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
101+
wd_ban_list: List[str] = ['bias', 'LayerNorm.bias', 'LayerNorm.weight', 'LayerNorm']
102102

103103
before_parameters = list(model.named_parameters())
104104
after_parameters = get_optimizer_parameters(model, weight_decay=1e-3, wd_ban_list=wd_ban_list)

0 commit comments

Comments
 (0)