Skip to content

Commit a2d474d

Browse files
committed
update: create_optimizer
1 parent 8e9db2c commit a2d474d

File tree

1 file changed

+28
-1
lines changed

1 file changed

+28
-1
lines changed

pytorch_optimizer/__init__.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# ruff: noqa
22
from typing import Dict, List
33

4-
from pytorch_optimizer.base.types import OPTIMIZER, SCHEDULER
4+
from pytorch_optimizer.base.types import OPTIMIZER, PARAMETERS, SCHEDULER
55
from pytorch_optimizer.experimental.deberta_v3_lr_scheduler import deberta_v3_large_lr_scheduler
66
from pytorch_optimizer.lr_scheduler import (
77
ConstantLR,
@@ -131,6 +131,33 @@ def load_optimizer(optimizer: str) -> OPTIMIZER:
131131
return OPTIMIZERS[optimizer]
132132

133133

134+
def create_optimizer(
    parameters: PARAMETERS,
    optimizer_name: str,
    lr: float = 1e-3,
    weight_decay: float = 0.0,
    wd_ban_list: List[str] = ('bias', 'LayerNorm.bias', 'LayerNorm.weight'),
    use_lookahead: bool = False,
    **kwargs,
) -> 'OPTIMIZER':
    """Build an optimizer by name, optionally wrapped with Lookahead.

    :param parameters: PARAMETERS. model parameters (or parameter groups) to optimize.
    :param optimizer_name: str. optimizer name, matched case-insensitively via load_optimizer.
    :param lr: float. learning rate. forwarded as ``max_lr`` for ``'alig'`` (that optimizer
        takes ``max_lr`` rather than ``lr``), and as ``lr`` for every other optimizer.
    :param weight_decay: float. when > 0.0, parameters are regrouped with
        get_optimizer_parameters so that names matching ``wd_ban_list`` get no decay.
    :param wd_ban_list: List[str]. parameter-name fragments excluded from weight decay.
        (note: the default is a tuple on purpose — a mutable list default would be shared
        across calls.)
    :param use_lookahead: bool. wrap the built optimizer with Lookahead.
    :param kwargs: extra keyword arguments. Lookahead-specific ones are routed to the
        Lookahead wrapper; everything else goes to the base optimizer constructor.
    """
    optimizer_name = optimizer_name.lower()

    if weight_decay > 0.0:
        parameters = get_optimizer_parameters(parameters, weight_decay, wd_ban_list)

    optimizer_class = load_optimizer(optimizer_name)

    # BUGFIX: the original forwarded the SAME **kwargs to both the base optimizer and
    # Lookahead(optimizer, **kwargs), so any optimizer-specific kwarg (e.g. betas, eps)
    # raised TypeError inside Lookahead, and any Lookahead-specific kwarg crashed the
    # base optimizer. Split kwargs by Lookahead's actual signature before building either.
    lookahead_kwargs = {}
    if use_lookahead:
        import inspect  # local import: only needed on this path

        accepted = inspect.signature(Lookahead).parameters
        lookahead_kwargs = {name: kwargs.pop(name) for name in list(kwargs) if name in accepted}

    if optimizer_name == 'alig':
        optimizer = optimizer_class(parameters, max_lr=lr, **kwargs)
    else:
        optimizer = optimizer_class(parameters, lr=lr, **kwargs)

    if use_lookahead:
        optimizer = Lookahead(optimizer, **lookahead_kwargs)

    return optimizer
159+
160+
134161
def load_lr_scheduler(lr_scheduler: str) -> SCHEDULER:
135162
lr_scheduler: str = lr_scheduler.lower()
136163

0 commit comments

Comments
 (0)