Skip to content

Commit 935d5ae

Browse files
committed
feature: create_optimizer
1 parent 7177865 commit 935d5ae

File tree

1 file changed

+21
-3
lines changed

1 file changed

+21
-3
lines changed

pytorch_optimizer/__init__.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# ruff: noqa
22
from typing import Dict, List
33

4+
from torch import nn
5+
46
from pytorch_optimizer.base.types import OPTIMIZER, PARAMETERS, SCHEDULER
57
from pytorch_optimizer.experimental.deberta_v3_lr_scheduler import deberta_v3_large_lr_scheduler
68
from pytorch_optimizer.lr_scheduler import (
@@ -132,18 +134,29 @@ def load_optimizer(optimizer: str) -> OPTIMIZER:
132134

133135

134136
def create_optimizer(
135-
parameters: PARAMETERS,
137+
model: nn.Module,
136138
optimizer_name: str,
137139
lr: float = 1e-3,
138140
weight_decay: float = 0.0,
139141
wd_ban_list: List[str] = ('bias', 'LayerNorm.bias', 'LayerNorm.weight'),
140142
use_lookahead: bool = False,
141143
**kwargs,
142144
):
145+
r"""Build optimizer.
146+
147+
:param model: nn.Module. model.
148+
:param optimizer_name: str. name of optimizer.
149+
:param lr: float. learning rate.
150+
:param weight_decay: float. weight decay.
151+
:param wd_ban_list: List[str]. weight decay ban list by layer.
152+
:param use_lookahead: bool. use lookahead.
153+
"""
143154
optimizer_name = optimizer_name.lower()
144155

145156
if weight_decay > 0.0:
146-
parameters = get_optimizer_parameters(parameters, weight_decay, wd_ban_list)
157+
parameters = get_optimizer_parameters(model, weight_decay, wd_ban_list)
158+
else:
159+
parameters = model.parameters()
147160

148161
optimizer = load_optimizer(optimizer_name)
149162

@@ -153,7 +166,12 @@ def create_optimizer(
153166
optimizer = optimizer(parameters, lr=lr, **kwargs)
154167

155168
if use_lookahead:
156-
optimizer = Lookahead(optimizer, **kwargs)
169+
optimizer = Lookahead(
170+
optimizer,
171+
k=kwargs['k'] if 'k' in kwargs else 5,
172+
alpha=kwargs['alpha'] if 'alpha' in kwargs else 0.5,
173+
pullback_momentum=kwargs['pullback_momentum'] if 'pullback_momentum' in kwargs else 'none',
174+
)
157175

158176
return optimizer
159177

0 commit comments

Comments (0)