Skip to content

Commit 74caee6

Browse files
committed
feat: support torch.hub
1 parent 569bf89 commit 74caee6

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

hubconf.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,10 @@
"""torch.hub entrypoint file for pytorch_optimizer.

Registers every supported optimizer (under both its original name and its
lowercase alias) as a hub entrypoint, e.g.::

    opt_cls = torch.hub.load('kozistr/pytorch_optimizer', 'adamp')
"""
dependencies = ["torch"]

from functools import partial

from pytorch_optimizer import get_supported_optimizers, load_optimizer

for optimizer in get_supported_optimizers():
    name = optimizer.__name__
    # Expose each optimizer under its original and lowercase names;
    # load_optimizer lowercases its argument, so both aliases resolve.
    for n in (name, name.lower()):
        globals()[n] = partial(load_optimizer, optimizer=n)

# Remove the loop variables so they are not exposed at module level:
# torch.hub.list() reports module-level callables, and the leftover
# ``optimizer`` (an optimizer class, hence callable) would otherwise
# show up as a bogus extra entrypoint.
del optimizer, name, n

pytorch_optimizer/__init__.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,7 @@
11
# pylint: disable=unused-import
2-
from typing import Callable, Dict, List
2+
from typing import Dict, List, Type
3+
4+
from torch.optim import Optimizer
35

46
from pytorch_optimizer.adabelief import AdaBelief
57
from pytorch_optimizer.adabound import AdaBound
@@ -54,10 +56,10 @@
5456
SGDP,
5557
Shampoo,
5658
]
57-
OPTIMIZERS: Dict[str, Callable] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}
59+
OPTIMIZERS: Dict[str, Type[Optimizer]] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}
5860

5961

60-
def load_optimizer(optimizer: str) -> Callable:
62+
def load_optimizer(optimizer: str) -> Type[Optimizer]:
6163
optimizer: str = optimizer.lower()
6264

6365
if optimizer not in OPTIMIZERS:
@@ -66,5 +68,5 @@ def load_optimizer(optimizer: str) -> Callable:
6668
return OPTIMIZERS[optimizer]
6769

6870

def get_supported_optimizers() -> List[Type[Optimizer]]:
    """Return the list of optimizer classes bundled with pytorch_optimizer.

    The returned object is the module-level ``OPTIMIZER_LIST`` itself
    (not a copy), so callers should treat it as read-only.
    """
    return OPTIMIZER_LIST

0 commit comments

Comments (0)