Skip to content

Commit 546531c

Browse files
committed
fix some makefile issues
1 parent ea80942 commit 546531c

File tree

1 file changed

+14
-9
lines changed

1 file changed

+14
-9
lines changed

pytorch_optimizer/optimizer/utils.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -198,21 +198,20 @@ def get_optimizer_parameters(
198198
weight_decay: float,
199199
wd_ban_list: List[str] = ('bias', 'LayerNorm.bias', 'LayerNorm.weight'),
200200
) -> PARAMETERS:
201-
r"""Get optimizer parameters while filtering specified modules.
202-
201+
r"""
202+
Get optimizer parameters while filtering specified modules.
203203
:param model_or_parameter: Union[nn.Module, List]. model or parameters.
204204
:param weight_decay: float. weight_decay.
205205
:param wd_ban_list: List[str]. ban list not to set weight decay.
206206
:returns: PARAMETERS. new parameter list.
207207
"""
208208

209+
209210
fully_qualified_names = []
210211
for module_name, module in model_or_parameter.named_modules():
211-
for param_name, param in module.named_parameters(recurse=False):
212+
for param_name, _param in module.named_parameters(recurse=False):
212213
# Full parameter name includes module and parameter names
213-
full_param_name = (
214-
f"{module_name}.{param_name}" if module_name else param_name
215-
)
214+
full_param_name = f'{module_name}.{param_name}' if module_name else param_name
216215
# Check if any ban list substring is in the parameter name or module name
217216
if (
218217
any(banned in param_name for banned in wd_ban_list)
@@ -223,14 +222,20 @@ def get_optimizer_parameters(
223222

224223
if isinstance(model_or_parameter, nn.Module):
225224
model_or_parameter = list(model_or_parameter.named_parameters())
226-
225+
227226
return [
228227
{
229-
'params': [p for n, p in model_or_parameter if p.requires_grad and not any(nd in n for nd in fully_qualified_names)],
228+
'params': [
229+
p
230+
for n, p in model_or_parameter
231+
if p.requires_grad and not any(nd in n for nd in fully_qualified_names)
232+
],
230233
'weight_decay': weight_decay,
231234
},
232235
{
233-
'params': [p for n, p in model_or_parameter if p.requires_grad and any(nd in n for nd in fully_qualified_names)],
236+
'params': [
237+
p for n, p in model_or_parameter if p.requires_grad and any(nd in n for nd in fully_qualified_names)
238+
],
234239
'weight_decay': 0.0,
235240
},
236241
]

0 commit comments

Comments
 (0)