Skip to content

Commit a2bf8d6

Browse files
committed
update: create_optimizer
1 parent e143cae commit a2bf8d6

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

pytorch_optimizer/__init__.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,13 @@
111111
)
112112
from pytorch_optimizer.optimizer.yogi import Yogi
113113

114+
try:
115+
import bitsandbytes
116+
117+
HAS_BNB: bool = True
118+
except ImportError:
119+
HAS_BNB: bool = False
120+
114121
OPTIMIZER_LIST: List[OPTIMIZER] = [
115122
AdaBelief,
116123
AdaBound,
@@ -240,10 +247,9 @@ def create_optimizer(
240247
"""
241248
optimizer_name = optimizer_name.lower()
242249

243-
if weight_decay > 0.0:
244-
parameters = get_optimizer_parameters(model, weight_decay, wd_ban_list)
245-
else:
246-
parameters = model.parameters()
250+
parameters = (
251+
get_optimizer_parameters(model, weight_decay, wd_ban_list) if weight_decay > 0.0 else model.parameters()
252+
)
247253

248254
optimizer = load_optimizer(optimizer_name)
249255

0 commit comments

Comments
 (0)