We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent e143cae commit a2bf8d6Copy full SHA for a2bf8d6
pytorch_optimizer/__init__.py
@@ -111,6 +111,13 @@
111
)
112
from pytorch_optimizer.optimizer.yogi import Yogi
113
114
+try:
115
+ import bitsandbytes
116
+
117
+ HAS_BNB: bool = True
118
+except ImportError:
119
+ HAS_BNB: bool = False
120
121
OPTIMIZER_LIST: List[OPTIMIZER] = [
122
AdaBelief,
123
AdaBound,
@@ -240,10 +247,9 @@ def create_optimizer(
240
247
"""
241
248
optimizer_name = optimizer_name.lower()
242
249
243
- if weight_decay > 0.0:
244
- parameters = get_optimizer_parameters(model, weight_decay, wd_ban_list)
245
- else:
246
- parameters = model.parameters()
250
+ parameters = (
251
+ get_optimizer_parameters(model, weight_decay, wd_ban_list) if weight_decay > 0.0 else model.parameters()
252
+ )
253
254
optimizer = load_optimizer(optimizer_name)
255
0 commit comments