Skip to content

Commit ea80942

Browse files
committed
make it slightly more concise
1 parent 905fca7 commit ea80942

File tree

1 file changed

+16
-25
lines changed

1 file changed

+16
-25
lines changed

pytorch_optimizer/optimizer/utils.py

Lines changed: 16 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -206,40 +206,31 @@ def get_optimizer_parameters(
206206
:returns: PARAMETERS. new parameter list.
207207
"""
208208

209-
def find_fully_qualified_names(
210-
model: nn.Module,
211-
wd_ban_list: List[str] = ("bias", "LayerNorm.weight", "LayerNorm.bias"),
212-
):
213-
names_without_wd = []
214-
215-
for module_name, module in model.named_modules():
216-
for param_name, param in module.named_parameters(recurse=False):
217-
# Full parameter name includes module and parameter names
218-
full_param_name = (
219-
f"{module_name}.{param_name}" if module_name else param_name
220-
)
221-
# Check if any ban list substring is in the parameter name or module name
222-
if (
223-
any(banned in param_name for banned in wd_ban_list)
224-
or any(banned in module_name for banned in wd_ban_list)
225-
or any(banned in module._get_name() for banned in wd_ban_list)
226-
):
227-
names_without_wd.append(full_param_name)
228-
229-
return names_without_wd
230-
231-
full_names = find_fully_qualified_names(model_or_parameter, wd_ban_list)
209+
fully_qualified_names = []
210+
for module_name, module in model_or_parameter.named_modules():
211+
for param_name, param in module.named_parameters(recurse=False):
212+
# Full parameter name includes module and parameter names
213+
full_param_name = (
214+
f"{module_name}.{param_name}" if module_name else param_name
215+
)
216+
# Check if any ban list substring is in the parameter name, the module name, or the module's class name
217+
if (
218+
any(banned in param_name for banned in wd_ban_list)
219+
or any(banned in module_name for banned in wd_ban_list)
220+
or any(banned in module._get_name() for banned in wd_ban_list)
221+
):
222+
fully_qualified_names.append(full_param_name)
232223

233224
if isinstance(model_or_parameter, nn.Module):
234225
model_or_parameter = list(model_or_parameter.named_parameters())
235226

236227
return [
237228
{
238-
'params': [p for n, p in model_or_parameter if p.requires_grad and not any(nd in n for nd in full_names)],
229+
'params': [p for n, p in model_or_parameter if p.requires_grad and not any(nd in n for nd in fully_qualified_names)],
239230
'weight_decay': weight_decay,
240231
},
241232
{
242-
'params': [p for n, p in model_or_parameter if p.requires_grad and any(nd in n for nd in full_names)],
233+
'params': [p for n, p in model_or_parameter if p.requires_grad and any(nd in n for nd in fully_qualified_names)],
243234
'weight_decay': 0.0,
244235
},
245236
]

0 commit comments

Comments
 (0)