Skip to content

Commit 8e9db2c

Browse files
committed
update: get_optimizer_parameters
1 parent c7cd923 commit 8e9db2c

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

pytorch_optimizer/optimizer/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,9 @@ def unit_norm(x: torch.Tensor, norm: float = 2.0) -> torch.Tensor:
168168

169169

170170
def get_optimizer_parameters(
171-
model_or_parameter: Union[nn.Module, List], weight_decay: float, wd_ban_list: List[str] = ('bias', 'LayerNorm.bias', 'LayerNorm.weight')
171+
model_or_parameter: Union[nn.Module, List],
172+
weight_decay: float,
173+
wd_ban_list: List[str] = ('bias', 'LayerNorm.bias', 'LayerNorm.weight'),
172174
) -> PARAMETERS:
173175
r"""Get optimizer parameters while filtering specified modules.
174176

0 commit comments

Comments
 (0)