Commit c55ec22

feature: load_optimizers

1 parent bdd887d · commit c55ec22

File tree

1 file changed: +48 −0 lines changed


pytorch_optimizer/optimizers.py

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
from pytorch_optimizer.adabelief import AdaBelief
from pytorch_optimizer.adabound import AdaBound
from pytorch_optimizer.adahessian import AdaHessian
from pytorch_optimizer.adamp import AdamP
from pytorch_optimizer.diffgrad import DiffGrad
from pytorch_optimizer.diffrgrad import DiffRGrad
from pytorch_optimizer.fp16 import SafeFP16Optimizer
from pytorch_optimizer.madgrad import MADGRAD
from pytorch_optimizer.radam import RAdam
from pytorch_optimizer.ranger import Ranger
from pytorch_optimizer.ranger21 import Ranger21
from pytorch_optimizer.sgdp import SGDP


def load_optimizers(optimizer: str, use_fp16: bool = False):
    """Return the optimizer class matching the given name (case-insensitive),
    optionally wrapped in SafeFP16Optimizer."""
    optimizer = optimizer.lower()

    if optimizer == 'adamp':
        opt = AdamP
    elif optimizer == 'ranger':
        opt = Ranger
    elif optimizer == 'ranger21':
        opt = Ranger21
    elif optimizer == 'sgdp':
        opt = SGDP
    elif optimizer == 'radam':
        opt = RAdam
    elif optimizer == 'adabelief':
        opt = AdaBelief
    elif optimizer == 'adabound':
        opt = AdaBound
    elif optimizer == 'madgrad':
        opt = MADGRAD
    elif optimizer == 'diffrgrad':
        opt = DiffRGrad
    elif optimizer == 'diffgrad':
        opt = DiffGrad
    elif optimizer == 'adahessian':
        opt = AdaHessian
    else:
        raise NotImplementedError(f'[-] not implemented optimizer : {optimizer}')

    if use_fp16:
        opt = SafeFP16Optimizer(opt)

    return opt
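
For reference, a minimal usage sketch of the new helper. It assumes the module is importable as pytorch_optimizer.optimizers (as the file path in this commit suggests) and that the returned class takes the usual params/lr constructor arguments, as AdamP does; the model and learning rate below are illustrative only.

import torch

from pytorch_optimizer.optimizers import load_optimizers

# Toy model; any nn.Module works here.
model = torch.nn.Linear(10, 2)

# load_optimizers returns the optimizer *class*, not an instance,
# so it still has to be constructed with the model's parameters.
optimizer_class = load_optimizers('adamp')
optimizer = optimizer_class(model.parameters(), lr=1e-3)

Because the lookup lower-cases its argument, 'AdamP', 'ADAMP', and 'adamp' all resolve to the same class; any unknown name raises NotImplementedError.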
