 # ruff: noqa
 from typing import Dict, List

+from torch import nn
+
 from pytorch_optimizer.base.types import OPTIMIZER, PARAMETERS, SCHEDULER
 from pytorch_optimizer.experimental.deberta_v3_lr_scheduler import deberta_v3_large_lr_scheduler
 from pytorch_optimizer.lr_scheduler import (
@@ -132,18 +134,29 @@ def load_optimizer(optimizer: str) -> OPTIMIZER:


 def create_optimizer(
-    parameters: PARAMETERS,
+    model: nn.Module,
     optimizer_name: str,
     lr: float = 1e-3,
     weight_decay: float = 0.0,
     wd_ban_list: List[str] = ('bias', 'LayerNorm.bias', 'LayerNorm.weight'),
     use_lookahead: bool = False,
     **kwargs,
 ):
+    r"""Build an optimizer from its name, with optional weight-decay grouping and Lookahead.
+
+    :param model: nn.Module. the model whose parameters will be optimized.
+    :param optimizer_name: str. name of the optimizer to build.
+    :param lr: float. learning rate.
+    :param weight_decay: float. weight decay (L2 penalty).
+    :param wd_ban_list: List[str]. parameter/layer names excluded from weight decay.
+    :param use_lookahead: bool. whether to wrap the optimizer with Lookahead.
+    """
     optimizer_name = optimizer_name.lower()

     if weight_decay > 0.0:
-        parameters = get_optimizer_parameters(parameters, weight_decay, wd_ban_list)
+        parameters = get_optimizer_parameters(model, weight_decay, wd_ban_list)
+    else:
+        parameters = model.parameters()

     optimizer = load_optimizer(optimizer_name)

@@ -153,7 +166,12 @@ def create_optimizer(
     optimizer = optimizer(parameters, lr=lr, **kwargs)

     if use_lookahead:
-        optimizer = Lookahead(optimizer, **kwargs)
+        optimizer = Lookahead(
+            optimizer,
+            k=kwargs['k'] if 'k' in kwargs else 5,
+            alpha=kwargs['alpha'] if 'alpha' in kwargs else 0.5,
+            pullback_momentum=kwargs['pullback_momentum'] if 'pullback_momentum' in kwargs else 'none',
+        )

     return optimizer

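Usage sketch (not part of the diff): a minimal call with the new model-based signature, assuming create_optimizer is exported at the package top level and that 'adamp' is one of the optimizer names accepted by load_optimizer.

import torch.nn as nn

from pytorch_optimizer import create_optimizer

# toy model; any nn.Module works since parameters are now pulled from the model itself
model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 2))

optimizer = create_optimizer(
    model,
    optimizer_name='adamp',
    lr=1e-3,
    weight_decay=1e-2,    # > 0.0, so bias/LayerNorm parameters are grouped without decay
    use_lookahead=True,   # wraps the base optimizer in Lookahead(k=5, alpha=0.5) by default
)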