File tree Expand file tree Collapse file tree 2 files changed +16
-4
lines changed Expand file tree Collapse file tree 2 files changed +16
-4
lines changed Original file line number Diff line number Diff line change 1+ dependencies = ["torch" ]
2+
3+ from functools import partial
4+
5+ from pytorch_optimizer import get_supported_optimizers , load_optimizer
6+
7+ for optimizer in get_supported_optimizers ():
8+ name = optimizer .__name__
9+ for n in (name , name .lower ()):
10+ globals ()[n ] = partial (load_optimizer , optimizer = n )
Original file line number Diff line number Diff line change 11# pylint: disable=unused-import
2- from typing import Callable , Dict , List
2+ from typing import Dict , List , Type
3+
4+ from torch .optim import Optimizer
35
46from pytorch_optimizer .adabelief import AdaBelief
57from pytorch_optimizer .adabound import AdaBound
5456 SGDP ,
5557 Shampoo ,
5658]
57- OPTIMIZERS : Dict [str , Callable ] = {str (optimizer .__name__ ).lower (): optimizer for optimizer in OPTIMIZER_LIST }
59+ OPTIMIZERS : Dict [str , Type [ Optimizer ] ] = {str (optimizer .__name__ ).lower (): optimizer for optimizer in OPTIMIZER_LIST }
5860
5961
60- def load_optimizer (optimizer : str ) -> Callable :
62+ def load_optimizer (optimizer : str ) -> Type [ Optimizer ] :
6163 optimizer : str = optimizer .lower ()
6264
6365 if optimizer not in OPTIMIZERS :
@@ -66,5 +68,5 @@ def load_optimizer(optimizer: str) -> Callable:
6668 return OPTIMIZERS [optimizer ]
6769
6870
69- def get_supported_optimizers () -> List :
71+ def get_supported_optimizers () -> List [ Type [ Optimizer ]] :
7072 return OPTIMIZER_LIST
You can’t perform that action at this time.
0 commit comments