@@ -1,42 +1,42 @@
 # pylint: disable=unused-import
-from typing import Dict, List, Type
+from typing import Dict, List
 
-from torch.optim import Optimizer
-
-from pytorch_optimizer.adabelief import AdaBelief
-from pytorch_optimizer.adabound import AdaBound
-from pytorch_optimizer.adamp import AdamP
-from pytorch_optimizer.adan import Adan
-from pytorch_optimizer.adapnm import AdaPNM
-from pytorch_optimizer.agc import agc
-from pytorch_optimizer.chebyshev_schedule import get_chebyshev_schedule
-from pytorch_optimizer.diffgrad import DiffGrad
-from pytorch_optimizer.diffrgrad import DiffRGrad
-from pytorch_optimizer.fp16 import DynamicLossScaler, SafeFP16Optimizer
-from pytorch_optimizer.gc import centralize_gradient
-from pytorch_optimizer.lamb import Lamb
-from pytorch_optimizer.lars import LARS
-from pytorch_optimizer.lookahead import Lookahead
-from pytorch_optimizer.madgrad import MADGRAD
-from pytorch_optimizer.nero import Nero
-from pytorch_optimizer.pcgrad import PCGrad
-from pytorch_optimizer.pnm import PNM
-from pytorch_optimizer.radam import RAdam
-from pytorch_optimizer.ralamb import RaLamb
-from pytorch_optimizer.ranger import Ranger
-from pytorch_optimizer.ranger21 import Ranger21
-from pytorch_optimizer.sam import SAM
-from pytorch_optimizer.sgdp import SGDP
-from pytorch_optimizer.shampoo import Shampoo
-from pytorch_optimizer.utils import (
+from pytorch_optimizer.base.types import OPTIMIZER, SCHEDULER
+from pytorch_optimizer.lr_scheduler.chebyshev import get_chebyshev_schedule
+from pytorch_optimizer.lr_scheduler.cosine_anealing import CosineAnnealingWarmupRestarts
+from pytorch_optimizer.optimizer.adabelief import AdaBelief
+from pytorch_optimizer.optimizer.adabound import AdaBound
+from pytorch_optimizer.optimizer.adamp import AdamP
+from pytorch_optimizer.optimizer.adan import Adan
+from pytorch_optimizer.optimizer.adapnm import AdaPNM
+from pytorch_optimizer.optimizer.agc import agc
+from pytorch_optimizer.optimizer.diffgrad import DiffGrad
+from pytorch_optimizer.optimizer.diffrgrad import DiffRGrad
+from pytorch_optimizer.optimizer.fp16 import DynamicLossScaler, SafeFP16Optimizer
+from pytorch_optimizer.optimizer.gc import centralize_gradient
+from pytorch_optimizer.optimizer.lamb import Lamb
+from pytorch_optimizer.optimizer.lars import LARS
+from pytorch_optimizer.optimizer.lookahead import Lookahead
+from pytorch_optimizer.optimizer.madgrad import MADGRAD
+from pytorch_optimizer.optimizer.nero import Nero
+from pytorch_optimizer.optimizer.pcgrad import PCGrad
+from pytorch_optimizer.optimizer.pnm import PNM
+from pytorch_optimizer.optimizer.radam import RAdam
+from pytorch_optimizer.optimizer.ralamb import RaLamb
+from pytorch_optimizer.optimizer.ranger import Ranger
+from pytorch_optimizer.optimizer.ranger21 import Ranger21
+from pytorch_optimizer.optimizer.sam import SAM
+from pytorch_optimizer.optimizer.sgdp import SGDP
+from pytorch_optimizer.optimizer.shampoo import Shampoo
+from pytorch_optimizer.optimizer.utils import (
     clip_grad_norm,
     get_optimizer_parameters,
     matrix_power,
     normalize_gradient,
     unit_norm,
 )
 
-OPTIMIZER_LIST: List[Type[Optimizer]] = [
+OPTIMIZER_LIST: List[OPTIMIZER] = [
     AdaBelief,
     AdaBound,
     AdamP,
@@ -56,10 +56,17 @@
     SGDP,
     Shampoo,
 ]
-OPTIMIZERS: Dict[str, Type[Optimizer]] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}
+OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}
+
+LR_SCHEDULER_LIST: List[SCHEDULER] = [
+    CosineAnnealingWarmupRestarts,
+]
+LR_SCHEDULERS: Dict[str, SCHEDULER] = {
+    str(lr_scheduler.__name__).lower(): lr_scheduler for lr_scheduler in LR_SCHEDULER_LIST
+}
 
 
-def load_optimizer(optimizer: str) -> Type[Optimizer]:
+def load_optimizer(optimizer: str) -> OPTIMIZER:
     optimizer: str = optimizer.lower()
 
     if optimizer not in OPTIMIZERS:
@@ -68,5 +75,18 @@ def load_optimizer(optimizer: str) -> Type[Optimizer]:
     return OPTIMIZERS[optimizer]
 
 
-def get_supported_optimizers() -> List[Type[Optimizer]]:
+def load_lr_scheduler(lr_scheduler: str) -> SCHEDULER:
+    lr_scheduler: str = lr_scheduler.lower()
+
+    if lr_scheduler not in LR_SCHEDULERS:
+        raise NotImplementedError(f'[-] not implemented lr_scheduler : {lr_scheduler}')
+
+    return LR_SCHEDULERS[lr_scheduler]
+
+
+def get_supported_optimizers() -> List[OPTIMIZER]:
     return OPTIMIZER_LIST
+
+
+def get_supported_lr_schedulers() -> List[SCHEDULER]:
+    return LR_SCHEDULER_LIST
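
For reference, a minimal usage sketch of the loader helpers in this change, assuming the file above is the package's top-level __init__. The model, learning rate, and the CosineAnnealingWarmupRestarts constructor arguments are placeholder assumptions, not part of this diff:

import torch

from pytorch_optimizer import load_lr_scheduler, load_optimizer

model = torch.nn.Linear(10, 2)  # placeholder model, for illustration only

# both loaders look up classes by their lowercased class names
optimizer_class = load_optimizer('adamp')
lr_scheduler_class = load_lr_scheduler('cosineannealingwarmuprestarts')

optimizer = optimizer_class(model.parameters(), lr=1e-3)
# assumed constructor arguments for CosineAnnealingWarmupRestarts
lr_scheduler = lr_scheduler_class(optimizer, first_cycle_steps=100, max_lr=1e-3, min_lr=1e-6)

Unknown names raise NotImplementedError, and get_supported_optimizers() / get_supported_lr_schedulers() return the registered classes.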