Skip to content

Commit 9896bf5

Browse files
committed
update: get_optimizer_parameters
1 parent 0b960fa commit 9896bf5

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

pytorch_optimizer/optimizer/utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,10 +184,13 @@ def get_optimizer_parameters(
184184

185185
return [
186186
{
187-
'params': [p for n, p in model_or_parameter if not any(nd in n for nd in wd_ban_list)],
187+
'params': [p for n, p in model_or_parameter if p.requires_grad and not any(nd in n for nd in wd_ban_list)],
188188
'weight_decay': weight_decay,
189189
},
190-
{'params': [p for n, p in model_or_parameter if any(nd in n for nd in wd_ban_list)], 'weight_decay': 0.0},
190+
{
191+
'params': [p for n, p in model_or_parameter if p.requires_grad and any(nd in n for nd in wd_ban_list)],
192+
'weight_decay': 0.0,
193+
},
191194
]
192195

193196

0 commit comments

Comments
 (0)