Skip to content

Commit 251c5ce

Browse files
committed
feature: load_bnb_optimizer
1 parent a2bf8d6 commit 251c5ce

File tree

1 file changed

+27
-1
lines changed

1 file changed

+27
-1
lines changed

pytorch_optimizer/__init__.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# ruff: noqa
22
from typing import Dict, List
33

4+
import torch.cuda
45
from torch import nn
56

67
from pytorch_optimizer.base.types import OPTIMIZER, PARAMETERS, SCHEDULER
@@ -112,7 +113,7 @@
112113
from pytorch_optimizer.optimizer.yogi import Yogi
113114

114115
try:
115-
import bitsandbytes
116+
import bitsandbytes as bnb
116117

117118
HAS_BNB: bool = True
118119
except ImportError:
@@ -218,9 +219,34 @@
218219
}
219220

220221

222+
def load_bnb_optimizer(optimizer: str) -> OPTIMIZER:
223+
r"""load bnb optimizer instance."""
224+
if 'sgd8bit' in optimizer:
225+
return bnb.optim.SGD8bit
226+
if 'adam8bit' in optimizer:
227+
return bnb.optim.Adam8bit
228+
if 'adamw8bit' in optimizer:
229+
return bnb.optim.AdamW8bit
230+
if 'lamb8bit' in optimizer:
231+
return bnb.optim.LAMB8bit
232+
if 'lars8bit' in optimizer:
233+
return bnb.optim.LARS8bit
234+
if 'lion8bit' in optimizer:
235+
return bnb.optim.Lion8bit
236+
if 'adagrad8bit' in optimizer:
237+
return bnb.optim.Adagrad8bit
238+
if 'rmsprop8bit' in optimizer:
239+
return bnb.optim.RMSprop8bit
240+
raise NotImplementedError(f'[-] not implemented optimizer : {optimizer}')
241+
242+
221243
def load_optimizer(optimizer: str) -> OPTIMIZER:
222244
optimizer: str = optimizer.lower()
223245

246+
if optimizer.startswith('bnb'):
247+
if HAS_BNB and torch.cuda.is_available():
248+
return load_bnb_optimizer(optimizer)
249+
raise ImportError(f'[-] bitsandbytes and CUDA required for bnb optimizers : {optimizer}')
224250
if optimizer not in OPTIMIZERS:
225251
raise NotImplementedError(f'[-] not implemented optimizer : {optimizer}')
226252

0 commit comments

Comments
 (0)