|
1 | 1 | # ruff: noqa |
2 | 2 | from typing import Dict, List |
3 | 3 |
|
| 4 | +import torch.cuda |
4 | 5 | from torch import nn |
5 | 6 |
|
6 | 7 | from pytorch_optimizer.base.types import OPTIMIZER, PARAMETERS, SCHEDULER |
|
112 | 113 | from pytorch_optimizer.optimizer.yogi import Yogi |
113 | 114 |
|
114 | 115 | try: |
115 | | - import bitsandbytes |
| 116 | + import bitsandbytes as bnb |
116 | 117 |
|
117 | 118 | HAS_BNB: bool = True |
118 | 119 | except ImportError: |
|
218 | 219 | } |
219 | 220 |
|
220 | 221 |
|
def load_bnb_optimizer(optimizer: str) -> OPTIMIZER:
    r"""Resolve a bitsandbytes 8-bit optimizer class from its (lowercased) name.

    The lookup is a first-match substring scan, so e.g. 'bnb_adamw8bit'
    resolves via the 'adamw8bit' token. Attribute access on ``bnb.optim`` is
    deferred until a token matches, mirroring the lazy behavior of a plain
    if-chain.

    :param optimizer: str. lowercased optimizer name containing one of the
        supported 8-bit tokens.
    :raises NotImplementedError: when no supported token occurs in the name.
    """
    # NOTE: order matters only for readability here — no token is a substring
    # of another ('adamw8bit' does not contain 'adam8bit').
    token_to_attr = (
        ('sgd8bit', 'SGD8bit'),
        ('adam8bit', 'Adam8bit'),
        ('adamw8bit', 'AdamW8bit'),
        ('lamb8bit', 'LAMB8bit'),
        ('lars8bit', 'LARS8bit'),
        ('lion8bit', 'Lion8bit'),
        ('adagrad8bit', 'Adagrad8bit'),
        ('rmsprop8bit', 'RMSprop8bit'),
    )

    for token, attr in token_to_attr:
        if token in optimizer:
            return getattr(bnb.optim, attr)

    raise NotImplementedError(f'[-] not implemented optimizer : {optimizer}')
| 241 | + |
| 242 | + |
221 | 243 | def load_optimizer(optimizer: str) -> OPTIMIZER: |
222 | 244 | optimizer: str = optimizer.lower() |
223 | 245 |
|
| 246 | + if optimizer.startswith('bnb'): |
| 247 | + if HAS_BNB and torch.cuda.is_available(): |
| 248 | + return load_bnb_optimizer(optimizer) |
| 249 | + raise ImportError(f'[-] bitsandbytes and CUDA required for bnb optimizers : {optimizer}') |
224 | 250 | if optimizer not in OPTIMIZERS: |
225 | 251 | raise NotImplementedError(f'[-] not implemented optimizer : {optimizer}') |
226 | 252 |
|
|
0 commit comments