
Commit 6018b07

Error message to inform bitsandbytes is only supported on CUDA (#19360)
1 parent bcc8de8 commit 6018b07

4 files changed: +28 -0 lines changed

src/lightning/fabric/connector.py

Lines changed: 5 additions & 0 deletions
@@ -24,6 +24,7 @@
 from lightning.fabric.accelerators.mps import MPSAccelerator
 from lightning.fabric.accelerators.xla import XLAAccelerator
 from lightning.fabric.plugins import (
+    BitsandbytesPrecision,
     CheckpointIO,
     DeepSpeedPrecision,
     HalfPrecision,
@@ -448,6 +449,10 @@ def _init_strategy(self) -> None:
 
     def _check_and_init_precision(self) -> Precision:
         if isinstance(self._precision_instance, Precision):
+            if isinstance(self._precision_instance, BitsandbytesPrecision) and not isinstance(
+                self.accelerator, CUDAAccelerator
+            ):
+                raise RuntimeError("Bitsandbytes is only supported on CUDA GPUs.")
             return self._precision_instance
         if isinstance(self.strategy, (SingleDeviceXLAStrategy, XLAStrategy, XLAFSDPStrategy)):
             return XLAPrecision(self._precision_input)  # type: ignore
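
For reference, this is how the new guard surfaces in Fabric (a minimal sketch, assuming bitsandbytes is installed; the mode="int8" value mirrors the tests below and is only illustrative):

from lightning.fabric import Fabric
from lightning.fabric.plugins import BitsandbytesPrecision

# BitsandbytesPrecision is CUDA-only, so the connector now fails fast with a
# clear message instead of erroring later inside bitsandbytes itself.
try:
    Fabric(accelerator="cpu", plugins=BitsandbytesPrecision(mode="int8"))
except RuntimeError as err:
    print(err)  # Bitsandbytes is only supported on CUDA GPUs.

# On a machine with a CUDA GPU, the same plugin passes the check:
# Fabric(accelerator="cuda", devices=1, plugins=BitsandbytesPrecision(mode="int8"))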

src/lightning/pytorch/trainer/connectors/accelerator_connector.py

Lines changed: 5 additions & 0 deletions
@@ -37,6 +37,7 @@
 from lightning.pytorch.accelerators.xla import XLAAccelerator
 from lightning.pytorch.plugins import (
     _PLUGIN_INPUT,
+    BitsandbytesPrecision,
     CheckpointIO,
     DeepSpeedPrecision,
     DoublePrecision,
@@ -565,6 +566,10 @@ def _check_and_init_precision(self) -> Precision:
 
     def _validate_precision_choice(self) -> None:
         """Validate the combination of choices for precision, AMP type, and accelerator."""
+        if isinstance(self._precision_plugin_flag, BitsandbytesPrecision) and not isinstance(
+            self.accelerator, CUDAAccelerator
+        ):
+            raise RuntimeError("Bitsandbytes is only supported on CUDA GPUs.")
         if _habana_available_and_importable():
             from lightning_habana import HPUAccelerator
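
The Trainer path gets the same fail-fast behavior through _validate_precision_choice(); a minimal sketch, again assuming bitsandbytes is installed:

from lightning.pytorch import Trainer
from lightning.pytorch.plugins import BitsandbytesPrecision

# The accelerator connector now rejects the CUDA-only plugin during Trainer
# construction when a non-CUDA accelerator is selected.
try:
    Trainer(accelerator="cpu", plugins=BitsandbytesPrecision(mode="int8"))
except RuntimeError as err:
    print(err)  # Bitsandbytes is only supported on CUDA GPUs.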

tests/tests_fabric/test_connector.py

Lines changed: 9 additions & 0 deletions
@@ -13,6 +13,7 @@
 # limitations under the License
 import inspect
 import os
+import sys
 from typing import Any, Dict
 from unittest import mock
 from unittest.mock import Mock
@@ -29,6 +30,7 @@
 from lightning.fabric.accelerators.mps import MPSAccelerator
 from lightning.fabric.connector import _Connector
 from lightning.fabric.plugins import (
+    BitsandbytesPrecision,
     DeepSpeedPrecision,
     DoublePrecision,
     FSDPPrecision,
@@ -864,6 +866,13 @@ def test_precision_selection_amp_ddp(strategy, devices, is_custom_plugin, plugin
     assert isinstance(connector.precision, plugin_cls)
 
 
+def test_bitsandbytes_precision_cuda_required(monkeypatch):
+    monkeypatch.setattr(lightning.fabric.plugins.precision.bitsandbytes, "_BITSANDBYTES_AVAILABLE", True)
+    monkeypatch.setitem(sys.modules, "bitsandbytes", Mock())
+    with pytest.raises(RuntimeError, match="Bitsandbytes is only supported on CUDA GPUs"):
+        _Connector(accelerator="cpu", plugins=BitsandbytesPrecision(mode="int8"))
+
+
 @pytest.mark.parametrize(("strategy", "strategy_cls"), [("DDP", DDPStrategy), ("Ddp", DDPStrategy)])
 @mock.patch("lightning.fabric.accelerators.mps.MPSAccelerator.is_available", return_value=False)
 def test_strategy_str_passed_being_case_insensitive(_, strategy, strategy_cls):

tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py

Lines changed: 9 additions & 0 deletions
@@ -18,6 +18,7 @@
 from unittest import mock
 from unittest.mock import Mock
 
+import lightning.fabric
 import lightning.pytorch
 import pytest
 import torch
@@ -35,6 +36,7 @@
 from lightning.pytorch.plugins.io import TorchCheckpointIO
 from lightning.pytorch.plugins.layer_sync import LayerSync, TorchSyncBatchNorm
 from lightning.pytorch.plugins.precision import (
+    BitsandbytesPrecision,
     DeepSpeedPrecision,
     DoublePrecision,
     FSDPPrecision,
@@ -1115,3 +1117,10 @@ def test_connector_num_nodes_input_validation():
 def test_precision_selection(precision_str, strategy_str, expected_precision_cls):
     connector = _AcceleratorConnector(precision=precision_str, strategy=strategy_str)
     assert isinstance(connector.precision_plugin, expected_precision_cls)
+
+
+def test_bitsandbytes_precision_cuda_required(monkeypatch):
+    monkeypatch.setattr(lightning.fabric.plugins.precision.bitsandbytes, "_BITSANDBYTES_AVAILABLE", True)
+    monkeypatch.setitem(sys.modules, "bitsandbytes", Mock())
+    with pytest.raises(RuntimeError, match="Bitsandbytes is only supported on CUDA GPUs"):
+        _AcceleratorConnector(accelerator="cpu", plugins=BitsandbytesPrecision(mode="int8"))
