Skip to content

Commit 905fca7

Browse files
committed
implement better logic for detecting weights/modules
1 parent 20ed84f commit 905fca7

File tree

2 files changed

+29
-4
lines changed

2 files changed

+29
-4
lines changed

pytorch_optimizer/optimizer/utils.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -205,16 +205,41 @@ def get_optimizer_parameters(
205205
:param wd_ban_list: List[str]. ban list not to set weight decay.
206206
:returns: PARAMETERS. new parameter list.
207207
"""
208+
209+
def find_fully_qualified_names(
    model: nn.Module,
    wd_ban_list: List[str] = ('bias', 'LayerNorm.weight', 'LayerNorm.bias'),
) -> List[str]:
    """Collect fully-qualified names of parameters to exclude from weight decay.

    A parameter is excluded when any ban-list entry is a substring of the
    parameter's own name, of its owning module's qualified name, or of the
    owning module's class name (so an entry like ``'LayerNorm'`` matches any
    LayerNorm instance regardless of the attribute it is assigned to).

    :param model: nn.Module. model whose parameters are scanned.
    :param wd_ban_list: List[str]. substrings identifying banned parameters or modules.
    :returns: List[str]. fully-qualified parameter names (e.g. ``'encoder.norm.bias'``).
    """
    names_without_wd: List[str] = []
    for module_name, module in model.named_modules():
        # public equivalent of the private nn.Module._get_name(); hoisted out
        # of the inner loop since it is invariant per module.
        module_cls_name: str = module.__class__.__name__
        for param_name, _ in module.named_parameters(recurse=False):
            # prefix with the module path so same-named parameters in
            # different sub-modules stay distinct.
            full_param_name = f'{module_name}.{param_name}' if module_name else param_name
            # single pass over the ban list instead of three separate any() scans.
            if any(
                banned in param_name or banned in module_name or banned in module_cls_name
                for banned in wd_ban_list
            ):
                names_without_wd.append(full_param_name)
    return names_without_wd
230+
231+
full_names = find_fully_qualified_names(model_or_parameter, wd_ban_list)
232+
208233
if isinstance(model_or_parameter, nn.Module):
209234
model_or_parameter = list(model_or_parameter.named_parameters())
210-
235+
211236
return [
212237
{
213-
'params': [p for n, p in model_or_parameter if p.requires_grad and not any(nd in n for nd in wd_ban_list)],
238+
'params': [p for n, p in model_or_parameter if p.requires_grad and not any(nd in n for nd in full_names)],
214239
'weight_decay': weight_decay,
215240
},
216241
{
217-
'params': [p for n, p in model_or_parameter if p.requires_grad and any(nd in n for nd in wd_ban_list)],
242+
'params': [p for n, p in model_or_parameter if p.requires_grad and any(nd in n for nd in full_names)],
218243
'weight_decay': 0.0,
219244
},
220245
]

tests/test_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def test_neuron_mean_norm():
9898

9999
def test_get_optimizer_parameters():
100100
model: nn.Module = Example()
101-
wd_ban_list: List[str] = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
101+
wd_ban_list: List[str] = ['bias', 'LayerNorm.bias', 'LayerNorm.weight', 'LayerNorm']
102102

103103
before_parameters = list(model.named_parameters())
104104
after_parameters = get_optimizer_parameters(model, weight_decay=1e-3, wd_ban_list=wd_ban_list)

0 commit comments

Comments
 (0)