Skip to content

Commit 3c491f3

Browse files
authored
Merge pull request #81 from Bing-su/hubconf
[Fix] update hubconf.py
2 parents 2258885 + becf7cd commit 3c491f3

File tree

1 file changed

+14
-5
lines changed

1 file changed

+14
-5
lines changed

hubconf.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,23 @@
77
from functools import partial as _partial
88
from functools import update_wrapper as _update_wrapper
99

10+
from pytorch_optimizer import get_supported_lr_schedulers as _get_supported_lr_schedulers
1011
from pytorch_optimizer import get_supported_optimizers as _get_supported_optimizers
12+
from pytorch_optimizer import load_lr_scheduler as _load_lr_scheduler
1113
from pytorch_optimizer import load_optimizer as _load_optimizer
1214

1315
dependencies = ['torch']
1416

15-
for optimizer in _get_supported_optimizers():
16-
name: str = optimizer.__name__
17+
for _optimizer in _get_supported_optimizers():
18+
name: str = _optimizer.__name__
19+
_func = _partial(_load_optimizer, optimizer=name)
20+
_update_wrapper(_func, _optimizer.__init__)
1721
for n in (name, name.lower(), name.upper()):
18-
func = _partial(_load_optimizer, optimizer=n)
19-
_update_wrapper(func, optimizer)
20-
globals()[n] = func
22+
globals()[n] = _func
23+
24+
for _scheduler in _get_supported_lr_schedulers():
25+
name: str = _scheduler.__name__
26+
_func = _partial(_load_lr_scheduler, lr_scheduler=name)
27+
_update_wrapper(_func, _scheduler.__init__)
28+
for n in (name, name.lower(), name.upper()):
29+
globals()[n] = _func

0 commit comments

Comments
 (0)