diff --git a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py index 60c2f495fef8..a9cd107d102e 100644 --- a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py +++ b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py @@ -61,8 +61,6 @@ def __init__(self, quantization_config, **kwargs): self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules def validate_environment(self, *args, **kwargs): - if not torch.cuda.is_available(): - raise RuntimeError("No GPU found. A GPU is needed for quantization.") if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"): raise ImportError( "Using `bitsandbytes` 4-bit quantization requires Accelerate: `pip install 'accelerate>=0.26.0'`" @@ -72,6 +70,12 @@ def validate_environment(self, *args, **kwargs): "Using `bitsandbytes` 4-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`" ) + from ...utils import is_bitsandbytes_multi_backend_available + from .utils import validate_bnb_backend_availability + + bnb_multibackend_is_enabled = is_bitsandbytes_multi_backend_available() + validate_bnb_backend_availability(raise_exception=True) + if kwargs.get("from_flax", False): raise ValueError( "Converting into 4-bit weights from flax weights is currently not supported, please make" @@ -87,7 +91,9 @@ def validate_environment(self, *args, **kwargs): device_map_without_no_convert = { key: device_map[key] for key in device_map.keys() if key not in self.modules_to_not_convert } - if "cpu" in device_map_without_no_convert.values() or "disk" in device_map_without_no_convert.values(): + if set(device_map.values()) == {"cpu"} and bnb_multibackend_is_enabled: + pass + elif "cpu" in device_map_without_no_convert.values() or "disk" in device_map_without_no_convert.values(): raise ValueError( "Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit the " "quantized model. If you want to dispatch the model on the CPU or the disk while keeping these modules " @@ -240,10 +246,15 @@ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": # Commenting this for discussions on the PR. # def update_device_map(self, device_map): # if device_map is None: - # device_map = {"": torch.cuda.current_device()} + # if torch.cuda.is_available(): + # device_map = {"": torch.cuda.current_device()} + # elif is_torch_xpu_available(): + # device_map = {"": f"xpu:{torch.xpu.current_device()}"} + # else: + # device_map = {"": "cpu"} # logger.info( # "The device_map was not initialized. " - # "Setting device_map to {'':torch.cuda.current_device()}. " + # f"Setting device_map to {device_map}. " # "If you want to use the model for inference, please set device_map ='auto' " # ) # return device_map @@ -344,8 +355,6 @@ def __init__(self, quantization_config, **kwargs): self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules def validate_environment(self, *args, **kwargs): - if not torch.cuda.is_available(): - raise RuntimeError("No GPU found. 
A GPU is needed for quantization.") if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"): raise ImportError( "Using `bitsandbytes` 8-bit quantization requires Accelerate: `pip install 'accelerate>=0.26.0'`" @@ -355,6 +364,12 @@ def validate_environment(self, *args, **kwargs): "Using `bitsandbytes` 8-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`" ) + from ...utils import is_bitsandbytes_multi_backend_available + from .utils import validate_bnb_backend_availability + + bnb_multibackend_is_enabled = is_bitsandbytes_multi_backend_available() + validate_bnb_backend_availability(raise_exception=True) + if kwargs.get("from_flax", False): raise ValueError( "Converting into 8-bit weights from flax weights is currently not supported, please make" @@ -370,7 +385,9 @@ def validate_environment(self, *args, **kwargs): device_map_without_no_convert = { key: device_map[key] for key in device_map.keys() if key not in self.modules_to_not_convert } - if "cpu" in device_map_without_no_convert.values() or "disk" in device_map_without_no_convert.values(): + if set(device_map.values()) == {"cpu"} and bnb_multibackend_is_enabled: + pass + elif "cpu" in device_map_without_no_convert.values() or "disk" in device_map_without_no_convert.values(): raise ValueError( "Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit the " "quantized model. If you want to dispatch the model on the CPU or the disk while keeping these modules " @@ -403,10 +420,15 @@ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": # # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.update_device_map # def update_device_map(self, device_map): # if device_map is None: - # device_map = {"": torch.cuda.current_device()} + # if torch.cuda.is_available(): + # device_map = {"": torch.cuda.current_device()} + # elif is_torch_xpu_available(): + # device_map = {"": f"xpu:{torch.xpu.current_device()}"} + # else: + # device_map = {"": "cpu"} # logger.info( # "The device_map was not initialized. " - # "Setting device_map to {'':torch.cuda.current_device()}. " + # f"Setting device_map to {device_map}. 
" # "If you want to use the model for inference, please set device_map ='auto' " # ) # return device_map diff --git a/src/diffusers/quantizers/bitsandbytes/utils.py b/src/diffusers/quantizers/bitsandbytes/utils.py index a9771b368a86..df5a2bcd7e72 100644 --- a/src/diffusers/quantizers/bitsandbytes/utils.py +++ b/src/diffusers/quantizers/bitsandbytes/utils.py @@ -16,11 +16,22 @@ https://github.com/huggingface/transformers/blob/c409cd81777fb27aadc043ed3d8339dbc020fb3b/src/transformers/integrations/bitsandbytes.py """ +import importlib import inspect from inspect import signature from typing import Union -from ...utils import is_accelerate_available, is_bitsandbytes_available, is_torch_available, logging +from packaging import version + +from ...utils import ( + get_available_devices, + is_accelerate_available, + is_bitsandbytes_available, + is_bitsandbytes_multi_backend_available, + is_ipex_available, + is_torch_available, + logging, +) from ..quantization_config import QuantizationMethod @@ -172,7 +183,7 @@ def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None, dtype: "torc logger.warning_once( f"The model is going to be dequantized in {output_tensor.dtype} - if you want to upcast it to another dtype, make sure to pass the desired dtype when quantizing the model through `bnb_4bit_quant_type` argument of `BitsAndBytesConfig`" ) - return output_tensor + return output_tensor.to(dtype) if state.SCB is None: state.SCB = weight.SCB @@ -310,3 +321,80 @@ def _check_bnb_status(module) -> Union[bool, bool]: and getattr(module, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES ) return is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb + + +def _validate_bnb_multi_backend_availability(raise_exception): + import bitsandbytes as bnb + + bnb_supported_devices = getattr(bnb, "supported_torch_devices", set()) + available_devices = get_available_devices() + + if available_devices == {"cpu"} and not is_ipex_available(): + from importlib.util import find_spec + + if find_spec("intel_extension_for_pytorch"): + logger.warning( + "You have Intel IPEX installed but if you're intending to use it for CPU, it might not have the right version. Be sure to double check that your PyTorch and IPEX installs are compatible." + ) + + available_devices.discard("cpu") # Only Intel CPU is supported by BNB at the moment + + if not available_devices.intersection(bnb_supported_devices): + if raise_exception: + bnb_supported_devices_with_info = set( # noqa: C401 + '"cpu" (needs an Intel CPU and intel_extension_for_pytorch installed and compatible with the PyTorch version)' + if device == "cpu" + else device + for device in bnb_supported_devices + ) + err_msg = ( + f"None of the available devices `available_devices = {available_devices or None}` are supported by the bitsandbytes version you have installed: `bnb_supported_devices = {bnb_supported_devices_with_info}`. 
" + "Please check the docs to see if the backend you intend to use is available and how to install it: https://huggingface.co/docs/bitsandbytes/main/en/installation#multi-backend" + ) + + logger.error(err_msg) + raise RuntimeError(err_msg) + + logger.warning("No supported devices found for bitsandbytes multi-backend.") + return False + + logger.debug("Multi-backend validation successful.") + return True + + +def _validate_bnb_cuda_backend_availability(raise_exception): + if not is_torch_available(): + return False + + import torch + + if not torch.cuda.is_available(): + log_msg = ( + "CUDA is required but not available for bitsandbytes. Please consider installing the multi-platform enabled version of bitsandbytes, which is currently a work in progress. " + "Please check currently supported platforms and installation instructions at https://huggingface.co/docs/bitsandbytes/main/en/installation#multi-backend" + ) + if raise_exception: + logger.error(log_msg) + raise RuntimeError(log_msg) + + logger.warning(log_msg) + return False + + logger.debug("CUDA backend validation successful.") + return True + + +def validate_bnb_backend_availability(raise_exception=False): + """ + Validates if the available devices are supported by bitsandbytes, optionally raising an exception if not. + """ + if not is_bitsandbytes_available(): + if importlib.util.find_spec("bitsandbytes") and version.parse( + importlib.metadata.version("bitsandbytes") + ) < version.parse("0.43.1"): + return _validate_bnb_cuda_backend_availability(raise_exception) + return False + + if is_bitsandbytes_multi_backend_available(): + return _validate_bnb_multi_backend_availability(raise_exception) + return _validate_bnb_cuda_backend_availability(raise_exception) diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index d82aded4c435..f0a1ff0bbcb6 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -14,6 +14,8 @@ import os +from functools import lru_cache +from typing import FrozenSet from packaging import version @@ -63,6 +65,7 @@ is_accelerate_available, is_accelerate_version, is_bitsandbytes_available, + is_bitsandbytes_multi_backend_available, is_bitsandbytes_version, is_bs4_available, is_flax_available, @@ -73,6 +76,7 @@ is_hf_hub_version, is_inflect_available, is_invisible_watermark_available, + is_ipex_available, is_k_diffusion_available, is_k_diffusion_version, is_librosa_available, @@ -87,10 +91,15 @@ is_tensorboard_available, is_timm_available, is_torch_available, + is_torch_cuda_available, + is_torch_mlu_available, + is_torch_mps_available, + is_torch_musa_available, is_torch_npu_available, is_torch_version, is_torch_xla_available, is_torch_xla_version, + is_torch_xpu_available, is_torchao_available, is_torchsde_available, is_torchvision_available, @@ -139,3 +148,31 @@ def check_min_version(min_version): error_message = f"This example requires a minimum version of {min_version}," error_message += f" but the version found is {__version__}.\n" raise ImportError(error_message) + + +@lru_cache() +def get_available_devices() -> FrozenSet[str]: + """ + Returns a frozenset of devices available for the current PyTorch installation. 
+ """ + devices = {"cpu"} # `cpu` is always supported as a device in PyTorch + + if is_torch_cuda_available(): + devices.add("cuda") + + if is_torch_mps_available(): + devices.add("mps") + + if is_torch_xpu_available(): + devices.add("xpu") + + if is_torch_npu_available(): + devices.add("npu") + + if is_torch_mlu_available(): + devices.add("mlu") + + if is_torch_musa_available(): + devices.add("musa") + + return frozenset(devices) diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 37535366ed44..8f3a168b46f6 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -20,9 +20,10 @@ import os import sys from collections import OrderedDict +from functools import lru_cache from itertools import chain from types import ModuleType -from typing import Any, Union +from typing import Any, Optional, Union from huggingface_hub.utils import is_jinja_available # noqa: F401 from packaging import version @@ -364,11 +365,134 @@ def is_timm_available(): except importlib_metadata.PackageNotFoundError: _is_torchao_available = False +_ipex_available = importlib.util.find_spec("intel_extension_for_pytorch") is not None +if _ipex_available: + try: + _ipex_version = importlib_metadata.version("intel_extension_for_pytorch") + logger.debug(f"Successfully import intel_extension_for_pytorch version {_ipex_version}") + except importlib_metadata.PackageNotFoundError: + _ipex_available = False + def is_torch_available(): return _torch_available +def is_torch_cuda_available(): + if is_torch_available(): + import torch + + return torch.cuda.is_available() + else: + return False + + +def is_torch_mps_available(min_version: Optional[str] = None): + if is_torch_available(): + import torch + + if hasattr(torch.backends, "mps"): + backend_available = torch.backends.mps.is_available() and torch.backends.mps.is_built() + if min_version is not None: + flag = version.parse(_torch_version) >= version.parse(min_version) + backend_available = backend_available and flag + return backend_available + return False + + +def is_ipex_available(): + def get_major_and_minor_from_version(full_version): + return str(version.parse(full_version).major) + "." + str(version.parse(full_version).minor) + + if not is_torch_available() or not _ipex_available: + return False + + torch_major_and_minor = get_major_and_minor_from_version(_torch_version) + ipex_major_and_minor = get_major_and_minor_from_version(_ipex_version) + if torch_major_and_minor != ipex_major_and_minor: + logger.warning( + f"Intel Extension for PyTorch {ipex_major_and_minor} needs to work with PyTorch {ipex_major_and_minor}.*," + f" but PyTorch {_torch_version} is found. Please switch to the matching version and run again." 
+ ) + return False + return True + + +@lru_cache +def is_torch_xpu_available(check_device=False): + """ + Checks if XPU acceleration is available either via `intel_extension_for_pytorch` or via stock PyTorch (>=2.4) and + potentially if a XPU is in the environment + """ + if not is_torch_available(): + return False + + torch_version = version.parse(_torch_version) + if is_ipex_available(): + import intel_extension_for_pytorch # noqa: F401 + elif torch_version.major < 2 or (torch_version.major == 2 and torch_version.minor < 4): + return False + + import torch + + if check_device: + try: + # Will raise a RuntimeError if no XPU is found + _ = torch.xpu.device_count() + return torch.xpu.is_available() + except RuntimeError: + return False + return hasattr(torch, "xpu") and torch.xpu.is_available() + + +@lru_cache() +def is_torch_mlu_available(check_device=False): + """ + Checks if `mlu` is available via an `cndev-based` check which won't trigger the drivers and leave mlu + uninitialized. + """ + if not _torch_available or importlib.util.find_spec("torch_mlu") is None: + return False + + import torch + import torch_mlu # noqa: F401 + + pytorch_cndev_based_mlu_check_previous_value = os.environ.get("PYTORCH_CNDEV_BASED_MLU_CHECK") + try: + os.environ["PYTORCH_CNDEV_BASED_MLU_CHECK"] = str(1) + available = torch.mlu.is_available() + finally: + if pytorch_cndev_based_mlu_check_previous_value: + os.environ["PYTORCH_CNDEV_BASED_MLU_CHECK"] = pytorch_cndev_based_mlu_check_previous_value + else: + os.environ.pop("PYTORCH_CNDEV_BASED_MLU_CHECK", None) + + return available + + +@lru_cache() +def is_torch_musa_available(check_device=False): + "Checks if `torch_musa` is installed and potentially if a MUSA is in the environment" + if not _torch_available or importlib.util.find_spec("torch_musa") is None: + return False + + import torch + import torch_musa # noqa: F401 + + torch_musa_min_version = "0.33.0" + if _accelerate_available and version.parse(_accelerate_version) < version.parse(torch_musa_min_version): + return False + + if check_device: + try: + # Will raise a RuntimeError if no MUSA is found + _ = torch.musa.device_count() + return torch.musa.is_available() + except RuntimeError: + return False + return hasattr(torch, "musa") and torch.musa.is_available() + + def is_torch_xla_available(): return _torch_xla_available @@ -473,6 +597,15 @@ def is_bitsandbytes_available(): return _bitsandbytes_available +def is_bitsandbytes_multi_backend_available() -> bool: + if not is_bitsandbytes_available(): + return False + + import bitsandbytes as bnb + + return "multi_backend" in getattr(bnb, "features", set()) + + def is_google_colab(): return _is_google_colab diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 7eda13716025..abb45b0b69a2 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -30,6 +30,7 @@ BACKENDS_MAPPING, is_accelerate_available, is_bitsandbytes_available, + is_bitsandbytes_multi_backend_available, is_compel_available, is_flax_available, is_gguf_available, @@ -40,6 +41,7 @@ is_timm_available, is_torch_available, is_torch_version, + is_torch_xpu_available, is_torchao_available, is_torchsde_available, is_transformers_available, @@ -201,6 +203,17 @@ def parse_flag_from_env(key, default=False): _run_compile_tests = parse_flag_from_env("RUN_COMPILE", default=False) +def get_device_count(): + import torch + + if is_torch_xpu_available(): + num_devices = torch.xpu.device_count() + else: + num_devices = 
torch.cuda.device_count() + + return num_devices + + def floats_tensor(shape, scale=1.0, rng=None, name=None): """Creates a random float32 tensor""" if rng is None: @@ -299,9 +312,9 @@ def require_torch_multi_gpu(test_case): if not is_torch_available(): return unittest.skip("test requires PyTorch")(test_case) - import torch + device_count = get_device_count() - return unittest.skipUnless(torch.cuda.device_count() > 1, "test requires multiple GPUs")(test_case) + return unittest.skipUnless(device_count > 1, "test requires multiple GPUs")(test_case) def require_torch_accelerator_with_fp16(test_case): @@ -346,6 +359,38 @@ def require_torch_accelerator_with_training(test_case): )(test_case) +def require_torch_gpu_if_bnb_not_multi_backend_enabled(test_case): + """ + Decorator marking a test that requires a GPU if bitsandbytes multi-backend feature is not enabled. + """ + if is_bitsandbytes_available() and is_bitsandbytes_multi_backend_available(): + return test_case + return require_torch_gpu(test_case) + + +def skip_if_not_implemented(test_func): + @functools.wraps(test_func) + def wrapper(*args, **kwargs): + try: + return test_func(*args, **kwargs) + except NotImplementedError as e: + raise unittest.SkipTest(f"Test skipped due to NotImplementedError: {e}") + + return wrapper + + +def apply_skip_if_not_implemented(cls): + """ + Class decorator to apply @skip_if_not_implemented to all test methods. + """ + for attr_name in dir(cls): + if attr_name.startswith("test_"): + attr = getattr(cls, attr_name) + if callable(attr): + setattr(cls, attr_name, skip_if_not_implemented(attr)) + return cls + + def skip_mps(test_case): """Decorator marking a test to skip if torch_device is 'mps'""" return unittest.skipUnless(torch_device != "mps", "test requires non 'mps' device")(test_case) diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py index a9b9ab753084..f82881f03fa9 100644 --- a/tests/quantization/bnb/test_4bit.py +++ b/tests/quantization/bnb/test_4bit.py @@ -34,7 +34,7 @@ require_accelerate, require_bitsandbytes_version_greater, require_torch, - require_torch_gpu, + require_torch_gpu_if_bnb_not_multi_backend_enabled, require_transformers_version_greater, slow, torch_device, @@ -86,7 +86,7 @@ def forward(self, input, *args, **kwargs): @require_bitsandbytes_version_greater("0.43.2") @require_accelerate @require_torch -@require_torch_gpu +@require_torch_gpu_if_bnb_not_multi_backend_enabled @slow class Base4bitTests(unittest.TestCase): # We need to test on relatively large models (aka >1b parameters otherwise the quantiztion may not work as expected) @@ -124,7 +124,8 @@ def get_dummy_inputs(self): class BnB4BitBasicTests(Base4bitTests): def setUp(self): gc.collect() - torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() # Models self.model_fp16 = SD3Transformer2DModel.from_pretrained( @@ -144,7 +145,8 @@ def tearDown(self): del self.model_4bit gc.collect() - torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() def test_quantization_num_parameters(self): r""" @@ -214,7 +216,7 @@ def test_keep_modules_in_fp32(self): self.assertTrue(module.weight.dtype == torch.uint8) # test if inference works. 
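# Sketch: the change just below, and its 8-bit twin later in this diff, swaps the CUDA-only
# torch.amp.autocast("cuda", ...) for the device-agnostic torch.autocast(torch_device, ...).
# The toy Linear module and inputs here are stand-ins for the quantized transformer (they are
# not part of this diff), and the context managers are chained with a comma so both are entered.
import torch

# Pick whichever accelerator is present; "cpu" is the fallback. These names mirror the values
# diffusers' torch_device helper can resolve to.
if torch.cuda.is_available():
    device_type = "cuda"
elif hasattr(torch, "xpu") and torch.xpu.is_available():
    device_type = "xpu"
else:
    device_type = "cpu"

model = torch.nn.Linear(8, 8).to(device_type)   # stand-in for the 4-bit transformer
inputs = torch.randn(2, 8, device=device_type)  # stand-in for the dummy model inputs

# torch.autocast accepts a device_type string ("cuda", "xpu", "cpu", ...), so the same test body
# runs on any backend; fp16 autocast is not available on every CPU build, hence the bfloat16 fallback.
autocast_dtype = torch.float16 if device_type != "cpu" else torch.bfloat16
with torch.no_grad(), torch.autocast(device_type, dtype=autocast_dtype):
    _ = model(inputs)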
- with torch.no_grad() and torch.amp.autocast("cuda", dtype=torch.float16): + with torch.no_grad() and torch.autocast(torch_device, dtype=torch.float16): input_dict_for_transformer = self.get_dummy_inputs() model_inputs = { k: v.to(device=torch_device) for k, v in input_dict_for_transformer.items() if not isinstance(v, bool) @@ -255,15 +257,16 @@ def test_device_assignment(self): self.assertEqual(self.model_4bit.device.type, "cpu") self.assertAlmostEqual(self.model_4bit.get_memory_footprint(), mem_before) - # Move back to CUDA device - for device in [0, "cuda", "cuda:0", "call()"]: - if device == "call()": - self.model_4bit.cuda(0) - else: - self.model_4bit.to(device) - self.assertEqual(self.model_4bit.device, torch.device(0)) - self.assertAlmostEqual(self.model_4bit.get_memory_footprint(), mem_before) - self.model_4bit.to("cpu") + if torch.cuda.is_available(): + # Move back to CUDA device + for device in [0, "cuda", "cuda:0", "call()"]: + if device == "call()": + self.model_4bit.cuda(0) + else: + self.model_4bit.to(device) + self.assertEqual(self.model_4bit.device, torch.device(0)) + self.assertAlmostEqual(self.model_4bit.get_memory_footprint(), mem_before) + self.model_4bit.to("cpu") def test_device_and_dtype_assignment(self): r""" @@ -274,10 +277,6 @@ def test_device_and_dtype_assignment(self): # Tries with a `dtype` self.model_4bit.to(torch.float16) - with self.assertRaises(ValueError): - # Tries with a `device` and `dtype` - self.model_4bit.to(device="cuda:0", dtype=torch.float16) - with self.assertRaises(ValueError): # Tries with a cast self.model_4bit.float() @@ -287,7 +286,7 @@ def test_device_and_dtype_assignment(self): self.model_4bit.half() # This should work - self.model_4bit.to("cuda") + self.model_4bit.to(torch.device(torch_device)) # Test if we did not break anything self.model_fp16 = self.model_fp16.to(dtype=torch.float32, device=torch_device) @@ -310,8 +309,9 @@ def test_device_and_dtype_assignment(self): # Check this does not throw an error _ = self.model_fp16.float() - # Check that this does not throw an error - _ = self.model_fp16.cuda() + if torch.cuda.is_available(): + # Check that this does not throw an error + _ = self.model_fp16.cuda() def test_bnb_4bit_wrong_config(self): r""" @@ -354,7 +354,8 @@ def test_bnb_4bit_errors_loading_incorrect_state_dict(self): class BnB4BitTrainingTests(Base4bitTests): def setUp(self): gc.collect() - torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() nf4_config = BitsAndBytesConfig( load_in_4bit=True, @@ -388,7 +389,7 @@ def test_training(self): model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs}) # Step 4: Check if the gradient is not None - with torch.amp.autocast("cuda", dtype=torch.float16): + with torch.autocast(torch_device, dtype=torch.float16): out = self.model_4bit(**model_inputs)[0] out.norm().backward() @@ -402,7 +403,8 @@ def test_training(self): class SlowBnb4BitTests(Base4bitTests): def setUp(self) -> None: gc.collect() - torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() nf4_config = BitsAndBytesConfig( load_in_4bit=True, @@ -421,7 +423,8 @@ def tearDown(self): del self.pipeline_4bit gc.collect() - torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() def test_quality(self): output = self.pipeline_4bit( @@ -514,13 +517,14 @@ def test_pipeline_cuda_placement_works_with_nf4(self): quantization_config=text_encoder_3_nf4_config, torch_dtype=torch.float16, ) - # CUDA device placement 
works. - pipeline_4bit = DiffusionPipeline.from_pretrained( - self.model_name, - transformer=transformer_4bit, - text_encoder_3=text_encoder_3_4bit, - torch_dtype=torch.float16, - ).to("cuda") + if torch.cuda.is_available(): + # CUDA device placement works. + pipeline_4bit = DiffusionPipeline.from_pretrained( + self.model_name, + transformer=transformer_4bit, + text_encoder_3=text_encoder_3_4bit, + torch_dtype=torch.float16, + ).to("cuda") # Check if inference works. _ = pipeline_4bit("table", max_sequence_length=20, num_inference_steps=2) @@ -532,7 +536,8 @@ def test_pipeline_cuda_placement_works_with_nf4(self): class SlowBnb4BitFluxTests(Base4bitTests): def setUp(self) -> None: gc.collect() - torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() model_id = "hf-internal-testing/flux.1-dev-nf4-pkg" t5_4bit = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2") @@ -549,7 +554,8 @@ def tearDown(self): del self.pipeline_4bit gc.collect() - torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() def test_quality(self): # keep the resolution and max tokens to a lower number for faster execution. @@ -595,7 +601,8 @@ def test_lora_loading(self): class BaseBnb4BitSerializationTests(Base4bitTests): def tearDown(self): gc.collect() - torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() def test_serialization(self, quant_type="nf4", double_quant=True, safe_serialization=True): r""" diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py index d1404a2f8929..6ff3fb763813 100644 --- a/tests/quantization/bnb/test_mixed_int8.py +++ b/tests/quantization/bnb/test_mixed_int8.py @@ -40,7 +40,7 @@ require_bitsandbytes_version_greater, require_peft_version_greater, require_torch, - require_torch_gpu, + require_torch_gpu_if_bnb_not_multi_backend_enabled, require_transformers_version_greater, slow, torch_device, @@ -92,7 +92,7 @@ def forward(self, input, *args, **kwargs): @require_bitsandbytes_version_greater("0.43.2") @require_accelerate @require_torch -@require_torch_gpu +@require_torch_gpu_if_bnb_not_multi_backend_enabled @slow class Base8bitTests(unittest.TestCase): # We need to test on relatively large models (aka >1b parameters otherwise the quantiztion may not work as expected) @@ -130,7 +130,8 @@ def get_dummy_inputs(self): class BnB8bitBasicTests(Base8bitTests): def setUp(self): gc.collect() - torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() # Models self.model_fp16 = SD3Transformer2DModel.from_pretrained( @@ -146,7 +147,8 @@ def tearDown(self): del self.model_8bit gc.collect() - torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() def test_quantization_num_parameters(self): r""" @@ -212,7 +214,7 @@ def test_keep_modules_in_fp32(self): self.assertTrue(module.weight.dtype == torch.int8) # test if inference works. 
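# Sketch: how the new testing helpers compose with the decorator swap above. The test class and
# its body are illustrative assumptions, not part of this diff; only the imported helpers are
# introduced by the patch.
import unittest

import torch

from diffusers.utils.testing_utils import (
    apply_skip_if_not_implemented,
    get_device_count,
    require_torch_gpu_if_bnb_not_multi_backend_enabled,
    torch_device,
)


@apply_skip_if_not_implemented  # turn NotImplementedError from an unsupported backend into a skip
@require_torch_gpu_if_bnb_not_multi_backend_enabled  # require CUDA only when bnb is CUDA-only
class IllustrativeBnbSmokeTest(unittest.TestCase):
    def test_devices_are_visible(self):
        # get_device_count() counts XPU devices when is_torch_xpu_available(), else CUDA devices.
        self.assertGreaterEqual(get_device_count(), 0)
        # torch_device resolves to "cuda", "xpu", "mps", "cpu", ... for the current environment.
        self.assertIsInstance(torch.device(torch_device), torch.device)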
- with torch.no_grad() and torch.amp.autocast("cuda", dtype=torch.float16): + with torch.no_grad() and torch.autocast(torch_device, dtype=torch.float16): input_dict_for_transformer = self.get_dummy_inputs() model_inputs = { k: v.to(device=torch_device) for k, v in input_dict_for_transformer.items() if not isinstance(v, bool) @@ -272,10 +274,6 @@ def test_device_and_dtype_assignment(self): # Tries with a `dtype`` self.model_8bit.to(torch.float16) - with self.assertRaises(ValueError): - # Tries with a `device` - self.model_8bit.to(torch.device("cuda:0")) - with self.assertRaises(ValueError): # Tries with a `device` self.model_8bit.float() @@ -305,8 +303,9 @@ def test_device_and_dtype_assignment(self): # Check this does not throw an error _ = self.model_fp16.float() - # Check that this does not throw an error - _ = self.model_fp16.cuda() + if torch.cuda.is_available(): + # Check that this does not throw an error + _ = self.model_fp16.cuda() class Bnb8bitDeviceTests(Base8bitTests): @@ -339,7 +338,8 @@ def test_buffers_device_assignment(self): class BnB8bitTrainingTests(Base8bitTests): def setUp(self): gc.collect() - torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True) self.model_8bit = SD3Transformer2DModel.from_pretrained( @@ -369,7 +369,7 @@ def test_training(self): model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs}) # Step 4: Check if the gradient is not None - with torch.amp.autocast("cuda", dtype=torch.float16): + with torch.autocast(torch_device, dtype=torch.float16): out = self.model_8bit(**model_inputs)[0] out.norm().backward() @@ -383,7 +383,8 @@ def test_training(self): class SlowBnb8bitTests(Base8bitTests): def setUp(self) -> None: gc.collect() - torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True) model_8bit = SD3Transformer2DModel.from_pretrained( @@ -398,7 +399,8 @@ def tearDown(self): del self.pipeline_8bit gc.collect() - torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() def test_quality(self): output = self.pipeline_8bit( @@ -462,7 +464,7 @@ def test_generate_quality_dequantize(self): self.assertTrue(max_diff < 1e-2) # 8bit models cannot be offloaded to CPU. - self.assertTrue(self.pipeline_8bit.transformer.device.type == "cuda") + self.assertTrue(self.pipeline_8bit.transformer.device.type == torch.device(torch_device).type) # calling it again shouldn't be a problem _ = self.pipeline_8bit( prompt=self.prompt, @@ -491,13 +493,14 @@ def test_pipeline_cuda_placement_works_with_mixed_int8(self): quantization_config=text_encoder_3_8bit_config, torch_dtype=torch.float16, ) - # CUDA device placement works. - pipeline_8bit = DiffusionPipeline.from_pretrained( - self.model_name, - transformer=transformer_8bit, - text_encoder_3=text_encoder_3_8bit, - torch_dtype=torch.float16, - ).to("cuda") + if torch.cuda.is_available(): + # CUDA device placement works. + pipeline_8bit = DiffusionPipeline.from_pretrained( + self.model_name, + transformer=transformer_8bit, + text_encoder_3=text_encoder_3_8bit, + torch_dtype=torch.float16, + ).to("cuda") # Check if inference works. 
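# Sketch: the guarded cache clearing repeated in the setUp/tearDown methods above, folded into one
# backend-aware helper. The XPU branch is an assumption for illustration; the tests in this diff
# only guard the CUDA call.
import gc

import torch


def flush_accelerator_cache():
    # Collect Python garbage first, then release cached device memory only on backends
    # that actually expose an empty_cache() API.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.empty_cache()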
_ = pipeline_8bit("table", max_sequence_length=20, num_inference_steps=2) @@ -509,7 +512,8 @@ def test_pipeline_cuda_placement_works_with_mixed_int8(self): class SlowBnb8bitFluxTests(Base8bitTests): def setUp(self) -> None: gc.collect() - torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() model_id = "hf-internal-testing/flux.1-dev-int8-pkg" t5_8bit = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2") @@ -526,7 +530,8 @@ def tearDown(self): del self.pipeline_8bit gc.collect() - torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() def test_quality(self): # keep the resolution and max tokens to a lower number for faster execution. @@ -573,7 +578,8 @@ def test_lora_loading(self): class BaseBnb8bitSerializationTests(Base8bitTests): def setUp(self): gc.collect() - torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() quantization_config = BitsAndBytesConfig( load_in_8bit=True, @@ -586,7 +592,8 @@ def tearDown(self): del self.model_0 gc.collect() - torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() def test_serialization(self): r"""
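# Sketch: the multi-backend path these changes enable, end to end. The checkpoint id is
# illustrative; the imports and keyword arguments follow the public diffusers API touched
# in this diff.
import torch

from diffusers import BitsAndBytesConfig, SD3Transformer2DModel
from diffusers.quantizers.bitsandbytes.utils import validate_bnb_backend_availability
from diffusers.utils import get_available_devices, is_bitsandbytes_multi_backend_available

# Raises (with a pointer to the multi-backend install docs) if the installed bitsandbytes build
# supports none of the devices PyTorch can see; returns a bool when raise_exception=False.
validate_bnb_backend_availability(raise_exception=True)
print("available devices:", get_available_devices())
print("multi-backend bitsandbytes:", is_bitsandbytes_multi_backend_available())

nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
model_4bit = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",  # illustrative checkpoint
    subfolder="transformer",
    quantization_config=nf4_config,
    torch_dtype=torch.float16,
)
# With CUDA-only bitsandbytes this lands on the GPU; with a multi-backend build the relaxed
# device_map check in validate_environment also accepts an all-CPU or XPU placement.
print(model_4bit.device, model_4bit.dtype)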