|
1 | 1 | # ruff: noqa |
2 | 2 | from typing import Dict, List |
3 | 3 |
|
4 | | -from pytorch_optimizer.base.types import OPTIMIZER, SCHEDULER |
| 4 | +from pytorch_optimizer.base.types import OPTIMIZER, PARAMETERS, SCHEDULER |
5 | 5 | from pytorch_optimizer.experimental.deberta_v3_lr_scheduler import deberta_v3_large_lr_scheduler |
6 | 6 | from pytorch_optimizer.lr_scheduler import ( |
7 | 7 | ConstantLR, |
@@ -131,6 +131,33 @@ def load_optimizer(optimizer: str) -> OPTIMIZER: |
131 | 131 | return OPTIMIZERS[optimizer] |
132 | 132 |
|
133 | 133 |
|
def create_optimizer(
    parameters: PARAMETERS,
    optimizer_name: str,
    lr: float = 1e-3,
    weight_decay: float = 0.0,
    wd_ban_list: List[str] = ('bias', 'LayerNorm.bias', 'LayerNorm.weight'),
    use_lookahead: bool = False,
    **kwargs,
):
    r"""Build an optimizer instance by (case-insensitive) name.

    :param parameters: PARAMETERS. model parameters or named parameter groups to optimize.
    :param optimizer_name: str. optimizer name, resolved through ``load_optimizer``.
    :param lr: float. learning rate. forwarded as ``max_lr`` for AliG, ``lr`` for everything else.
    :param weight_decay: float. weight decay. when > 0, parameters are regrouped via
        ``get_optimizer_parameters`` so banned names are excluded from decay.
    :param wd_ban_list: List[str]. parameter-name fragments excluded from weight decay.
        the default is intentionally a tuple: a mutable list default would be shared across calls.
    :param use_lookahead: bool. wrap the built optimizer with ``Lookahead``.
    :param kwargs: extra keyword arguments forwarded to the optimizer constructor. when
        ``use_lookahead`` is True, ``k``, ``alpha`` and ``pullback_momentum`` are also read
        from here to configure the Lookahead wrapper.
    """
    optimizer_name = optimizer_name.lower()

    if weight_decay > 0.0:
        parameters = get_optimizer_parameters(parameters, weight_decay, wd_ban_list)

    optimizer = load_optimizer(optimizer_name)

    # AliG computes its own step size and exposes `max_lr` instead of `lr`.
    if optimizer_name == 'alig':
        optimizer = optimizer(parameters, max_lr=lr, **kwargs)
    else:
        optimizer = optimizer(parameters, lr=lr, **kwargs)

    if use_lookahead:
        # Forward only the Lookahead-specific options. Passing the full **kwargs here
        # (as the original did) raises TypeError for any optimizer-only keyword
        # (e.g. betas, eps), breaking use_lookahead=True in practice.
        # NOTE(review): the fallback defaults are assumed to mirror Lookahead's own
        # signature defaults - confirm against the Lookahead implementation.
        optimizer = Lookahead(
            optimizer,
            k=kwargs.get('k', 5),
            alpha=kwargs.get('alpha', 0.5),
            pullback_momentum=kwargs.get('pullback_momentum', 'none'),
        )

    return optimizer
| 160 | + |
134 | 161 | def load_lr_scheduler(lr_scheduler: str) -> SCHEDULER: |
135 | 162 | lr_scheduler: str = lr_scheduler.lower() |
136 | 163 |
|
|
0 commit comments