Skip to content

Commit f691cf6

Browse files
authored
Merge pull request #73 from Bing-su/hubconf
[Feature] support torch.hub.load
2 parents 569bf89 + b1f97fd commit f691cf6

File tree

2 files changed

+21
-4
lines changed

2 files changed

+21
-4
lines changed

hubconf.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
dependencies = ['torch']
2+
3+
from functools import partial as _partial, update_wrapper as _update_wrapper
4+
5+
from pytorch_optimizer import (
6+
get_supported_optimizers as _get_supported_optimizers,
7+
load_optimizer as _load_optimizer,
8+
)
9+
10+
for optimizer in _get_supported_optimizers():
11+
name = optimizer.__name__
12+
for n in (name, name.lower()):
13+
func = _partial(_load_optimizer, optimizer=n)
14+
_update_wrapper(func, optimizer)
15+
globals()[n] = func

pytorch_optimizer/__init__.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# pylint: disable=unused-import
2-
from typing import Callable, Dict, List
2+
from typing import Dict, List, Type
3+
4+
from torch.optim import Optimizer
35

46
from pytorch_optimizer.adabelief import AdaBelief
57
from pytorch_optimizer.adabound import AdaBound
@@ -54,10 +56,10 @@
5456
SGDP,
5557
Shampoo,
5658
]
57-
OPTIMIZERS: Dict[str, Callable] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}
59+
OPTIMIZERS: Dict[str, Type[Optimizer]] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}
5860

5961

60-
def load_optimizer(optimizer: str) -> Callable:
62+
def load_optimizer(optimizer: str) -> Type[Optimizer]:
6163
optimizer: str = optimizer.lower()
6264

6365
if optimizer not in OPTIMIZERS:
@@ -66,5 +68,5 @@ def load_optimizer(optimizer: str) -> Callable:
6668
return OPTIMIZERS[optimizer]
6769

6870

69-
def get_supported_optimizers() -> List:
71+
def get_supported_optimizers() -> List[Type[Optimizer]]:
7072
return OPTIMIZER_LIST

0 commit comments

Comments
 (0)