Commit 16f2372

fix(optim): catch ModuleNotFound & add warning
1 parent 0d6ec22 commit 16f2372

1 file changed: +54 -40 lines changed

cellseg_models_pytorch/optimizers/__init__.py

Lines changed: 54 additions & 40 deletions
@@ -1,3 +1,5 @@
+import warnings
+
 from torch.optim import ASGD, SGD, Adadelta, Adagrad, Adam, Adamax, AdamW, RMSprop
 from torch.optim.lr_scheduler import (
     CosineAnnealingLR,
@@ -7,30 +9,60 @@
     LambdaLR,
     ReduceLROnPlateau,
 )
-from torch_optimizer import (
-    PID,
-    QHM,
-    SGDW,
-    AccSGD,
-    AdaBelief,
-    AdaBound,
-    AdaMod,
-    AdamP,
-    Apollo,
-    DiffGrad,
-    Lamb,
-    Lookahead,
-    NovoGrad,
-    QHAdam,
-    RAdam,
-    Ranger,
-    RangerQH,
-    RangerVA,
-    Yogi,
-)

 from .utils import adjust_optim_params

+EXTRA_OPTIM_LOOKUP = {}
+try:
+    from torch_optimizer import (
+        PID,
+        QHM,
+        SGDW,
+        AccSGD,
+        AdaBelief,
+        AdaBound,
+        AdaMod,
+        AdamP,
+        Apollo,
+        DiffGrad,
+        Lamb,
+        Lookahead,
+        NovoGrad,
+        QHAdam,
+        RAdam,
+        Ranger,
+        RangerQH,
+        RangerVA,
+        Yogi,
+    )
+
+    EXTRA_OPTIM_LOOKUP = {
+        "accsgd": AccSGD,
+        "adabound": AdaBound,
+        "adabelief": AdaBelief,
+        "adamp": AdamP,
+        "apollo": Apollo,
+        "adamod": AdaMod,
+        "diffgrad": DiffGrad,
+        "lamb": Lamb,
+        "novograd": NovoGrad,
+        "pid": PID,
+        "qhadam": QHAdam,
+        "qhm": QHM,
+        "radam": RAdam,
+        "sgwd": SGDW,
+        "yogi": Yogi,
+        "ranger": Ranger,
+        "rangerqh": RangerQH,
+        "rangerva": RangerVA,
+        "lookahead": Lookahead,
+    }
+except ModuleNotFoundError:
+    warnings.warn(
+        "`torch_optimizer` optimizers not available. To use them, install with "
+        "`pip install torch-optimizer`."
+    )
+
 SCHED_LOOKUP = {
     "lambda": LambdaLR,
     "reduce_on_plateau": ReduceLROnPlateau,
@@ -49,25 +81,7 @@
     "adamax": Adamax,
     "adamw": AdamW,
     "asgd": ASGD,
-    "accsgd": AccSGD,
-    "adabound": AdaBound,
-    "adabelief": AdaBelief,
-    "adamp": AdamP,
-    "apollo": Apollo,
-    "adamod": AdaMod,
-    "diffgrad": DiffGrad,
-    "lamb": Lamb,
-    "novograd": NovoGrad,
-    "pid": PID,
-    "qhadam": QHAdam,
-    "qhm": QHM,
-    "radam": RAdam,
-    "sgwd": SGDW,
-    "yogi": Yogi,
-    "ranger": Ranger,
-    "rangerqh": RangerQH,
-    "rangerva": RangerVA,
-    "lookahead": Lookahead,
+    **EXTRA_OPTIM_LOOKUP,
 }

 __all__ = [
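As a usage note, here is a minimal sketch of how a name-to-class lookup like this is typically consumed downstream. The commit guarantees that core torch.optim entries always resolve while torch_optimizer extras appear only when that optional dependency is installed. The OPTIM_LOOKUP export name and the resolve_optimizer helper are assumptions for illustration; only EXTRA_OPTIM_LOOKUP and SCHED_LOOKUP are visible in this diff.

    import torch.nn as nn

    # Assumed export: the optimizer dict that spreads **EXTRA_OPTIM_LOOKUP
    # in this diff. Its actual name is not shown in the changed lines.
    from cellseg_models_pytorch.optimizers import OPTIM_LOOKUP


    def resolve_optimizer(name, params, **kwargs):
        """Hypothetical helper: map a string name to an optimizer instance."""
        try:
            optim_cls = OPTIM_LOOKUP[name.lower()]
        except KeyError:
            # Extras are absent from the dict when torch_optimizer is not
            # installed, so an unknown name may just mean a missing extra.
            raise ValueError(
                f"Unknown optimizer '{name}'. If it comes from "
                "`torch_optimizer`, install it with "
                "`pip install torch-optimizer`."
            )
        return optim_cls(params, **kwargs)


    model = nn.Linear(16, 4)
    # "adamw" is a core torch.optim entry, so it resolves with or without
    # the optional torch_optimizer dependency.
    optim = resolve_optimizer("adamw", model.parameters(), lr=1e-3)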
