25 changes: 18 additions & 7 deletions src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py
@@ -61,7 +61,7 @@ def __init__(self, quantization_config, **kwargs):
self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules

def validate_environment(self, *args, **kwargs):
if not torch.cuda.is_available():
if not (torch.cuda.is_available() or torch.xpu.is_available()):
raise RuntimeError("No GPU found. A GPU is needed for quantization.")
if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"):
raise ImportError(
@@ -237,11 +237,15 @@ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":

def update_device_map(self, device_map):
if device_map is None:
device_map = {"": f"cuda:{torch.cuda.current_device()}"}
if torch.xpu.is_available():
current_device = f"xpu:{torch.xpu.current_device()}"
else:
current_device = f"cuda:{torch.cuda.current_device()}"
device_map = {"": current_device}
logger.info(
"The device_map was not initialized. "
"Setting device_map to {"
": f`cuda:{torch.cuda.current_device()}`}. "
": {current_device}}. "
"If you want to use the model for inference, please set device_map ='auto' "
)
return device_map
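
The XPU-or-CUDA branch added here is duplicated verbatim in the 8-bit quantizer further down. A minimal sketch of how the selection could be pulled into one shared helper; the helper name is hypothetical and not something this diff adds:

import torch

def _get_current_accelerator_device() -> str:
    # Hypothetical helper mirroring the branching above: prefer an Intel XPU
    # when one is visible, otherwise fall back to CUDA.
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return f"xpu:{torch.xpu.current_device()}"
    if torch.cuda.is_available():
        return f"cuda:{torch.cuda.current_device()}"
    raise RuntimeError("No GPU found. A GPU is needed for quantization.")

# update_device_map would then reduce to:
# device_map = {"": _get_current_accelerator_device()}
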
@@ -311,7 +315,10 @@ def _dequantize(self, model):
logger.info(
"Model was found to be on CPU (could happen as a result of `enable_model_cpu_offload()`). So, moving it to GPU. After dequantization, will move the model back to CPU again to preserve the previous device."
)
model.to(torch.cuda.current_device())
if torch.xpu.is_available():
model.to(torch.xpu.current_device())
else:
model.to(torch.cuda.current_device())

model = dequantize_and_replace(
model, self.modules_to_not_convert, quantization_config=self.quantization_config
@@ -342,7 +349,7 @@ def __init__(self, quantization_config, **kwargs):
self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules

def validate_environment(self, *args, **kwargs):
if not torch.cuda.is_available():
if not (torch.cuda.is_available() or torch.xpu.is_available()):
raise RuntimeError("No GPU found. A GPU is needed for quantization.")
if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"):
raise ImportError(
@@ -401,11 +408,15 @@ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
# Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.update_device_map
def update_device_map(self, device_map):
if device_map is None:
device_map = {"": f"cuda:{torch.cuda.current_device()}"}
if torch.xpu.is_available():
current_device = f"xpu:{torch.xpu.current_device()}"
else:
current_device = f"cuda:{torch.cuda.current_device()}"
device_map = {"": current_device}
logger.info(
"The device_map was not initialized. "
"Setting device_map to {"
": f`cuda:{torch.cuda.current_device()}`}. "
": {current_device}}. "
"If you want to use the model for inference, please set device_map ='auto' "
)
return device_map
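
With validate_environment and update_device_map now accelerator-aware in both quantizers, loading a quantized model no longer hard-codes CUDA. A rough usage sketch under assumed names; the checkpoint id and loading flow are illustrative, not taken from this diff:

import torch
from diffusers import BitsAndBytesConfig, SD3Transformer2DModel

nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
model_4bit = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",  # illustrative checkpoint id
    subfolder="transformer",
    quantization_config=nf4_config,
    torch_dtype=torch.float16,
)
# With device_map left unset, update_device_map above is expected to resolve it
# to {"": "xpu:0"} on Intel GPUs and {"": "cuda:0"} on NVIDIA GPUs.
print(model_4bit.device)
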
35 changes: 18 additions & 17 deletions tests/quantization/bnb/test_4bit.py
@@ -26,6 +26,7 @@
from diffusers.utils import is_accelerate_version, logging
from diffusers.utils.testing_utils import (
CaptureLogger,
backend_empty_cache,
is_bitsandbytes_available,
is_torch_available,
is_transformers_available,
@@ -34,7 +35,7 @@
require_accelerate,
require_bitsandbytes_version_greater,
require_torch,
require_torch_gpu,
require_torch_accelerator,
require_transformers_version_greater,
slow,
torch_device,
@@ -86,7 +87,7 @@ def forward(self, input, *args, **kwargs):
@require_bitsandbytes_version_greater("0.43.2")
@require_accelerate
@require_torch
@require_torch_gpu
@require_torch_accelerator
@slow
class Base4bitTests(unittest.TestCase):
# We need to test on relatively large models (aka >1b parameters otherwise the quantiztion may not work as expected)
@@ -102,13 +103,13 @@ class Base4bitTests(unittest.TestCase):

def get_dummy_inputs(self):
prompt_embeds = load_pt(
"https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt"
"https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt", torch_device
)
pooled_prompt_embeds = load_pt(
"https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/pooled_prompt_embeds.pt"
"https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/pooled_prompt_embeds.pt", torch_device
)
latent_model_input = load_pt(
"https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/latent_model_input.pt"
"https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/latent_model_input.pt", torch_device
)

input_dict_for_transformer = {
@@ -124,7 +125,7 @@ def get_dummy_inputs(self):
class BnB4BitBasicTests(Base4bitTests):
def setUp(self):
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

# Models
self.model_fp16 = SD3Transformer2DModel.from_pretrained(
@@ -144,7 +145,7 @@ def tearDown(self):
del self.model_4bit

gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

def test_quantization_num_parameters(self):
r"""
@@ -256,9 +257,9 @@ def test_device_assignment(self):
self.assertAlmostEqual(self.model_4bit.get_memory_footprint(), mem_before)

# Move back to CUDA device
for device in [0, "cuda", "cuda:0", "call()"]:
for device in [0, f"{torch_device}", f"{torch_device}:0", "call()"]:
if device == "call()":
self.model_4bit.cuda(0)
self.model_4bit.to(f"{torch_device}:0")
else:
self.model_4bit.to(device)
self.assertEqual(self.model_4bit.device, torch.device(0))
@@ -276,7 +277,7 @@ def test_device_and_dtype_assignment(self):

with self.assertRaises(ValueError):
# Tries with a `device` and `dtype`
self.model_4bit.to(device="cuda:0", dtype=torch.float16)
self.model_4bit.to(device=f"{torch_device}:0", dtype=torch.float16)

with self.assertRaises(ValueError):
# Tries with a cast
@@ -287,7 +288,7 @@
self.model_4bit.half()

# This should work
self.model_4bit.to("cuda")
self.model_4bit.to(torch_device)

# Test if we did not break anything
self.model_fp16 = self.model_fp16.to(dtype=torch.float32, device=torch_device)
@@ -311,7 +312,7 @@ def test_device_and_dtype_assignment(self):
_ = self.model_fp16.float()

# Check that this does not throw an error
_ = self.model_fp16.cuda()
_ = self.model_fp16.to(torch_device)

def test_bnb_4bit_wrong_config(self):
r"""
@@ -402,7 +403,7 @@ def test_training(self):
class SlowBnb4BitTests(Base4bitTests):
def setUp(self) -> None:
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

nf4_config = BitsAndBytesConfig(
load_in_4bit=True,
@@ -421,7 +422,7 @@ def tearDown(self):
del self.pipeline_4bit

gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

def test_quality(self):
output = self.pipeline_4bit(
@@ -491,7 +492,7 @@ def test_moving_to_cpu_throws_warning(self):
reason="Test will pass after https://github.com/huggingface/accelerate/pull/3223 is in a release.",
strict=True,
)
def test_pipeline_cuda_placement_works_with_nf4(self):
def test_pipeline_device_placement_works_with_nf4(self):
transformer_nf4_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
@@ -522,7 +523,7 @@ def test_pipeline_cuda_placement_works_with_nf4(self):
transformer=transformer_4bit,
text_encoder_3=text_encoder_3_4bit,
torch_dtype=torch.float16,
).to("cuda")
).to(torch_device)

# Check if inference works.
_ = pipeline_4bit("table", max_sequence_length=20, num_inference_steps=2)
@@ -685,7 +686,7 @@ def test_lora_loading(self):
class BaseBnb4BitSerializationTests(Base4bitTests):
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

def test_serialization(self, quant_type="nf4", double_quant=True, safe_serialization=True):
r"""
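
The test changes above follow one pattern: hard-coded torch.cuda.* calls and "cuda" strings are swapped for the device-agnostic helpers imported from diffusers.utils.testing_utils. A condensed sketch of that pattern; the test class below is illustrative, not a test from this diff:

import gc
import unittest

from diffusers.utils.testing_utils import (
    backend_empty_cache,
    require_torch_accelerator,
    torch_device,
)

@require_torch_accelerator  # runs on CUDA or XPU machines, replacing require_torch_gpu
class ExampleAcceleratorTest(unittest.TestCase):
    def tearDown(self):
        gc.collect()
        # backend_empty_cache frees the cache of whichever backend torch_device
        # resolves to, instead of calling torch.cuda.empty_cache() directly.
        backend_empty_cache(torch_device)
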
29 changes: 15 additions & 14 deletions tests/quantization/bnb/test_mixed_int8.py
@@ -31,6 +31,7 @@
from diffusers.utils import is_accelerate_version
from diffusers.utils.testing_utils import (
CaptureLogger,
backend_empty_cache,
is_bitsandbytes_available,
is_torch_available,
is_transformers_available,
@@ -40,7 +41,7 @@
require_bitsandbytes_version_greater,
require_peft_version_greater,
require_torch,
require_torch_gpu,
require_torch_accelerator,
require_transformers_version_greater,
slow,
torch_device,
@@ -92,7 +93,7 @@ def forward(self, input, *args, **kwargs):
@require_bitsandbytes_version_greater("0.43.2")
@require_accelerate
@require_torch
@require_torch_gpu
@require_torch_accelerator
@slow
class Base8bitTests(unittest.TestCase):
# We need to test on relatively large models (aka >1b parameters otherwise the quantiztion may not work as expected)
@@ -130,7 +131,7 @@ def get_dummy_inputs(self):
class BnB8bitBasicTests(Base8bitTests):
def setUp(self):
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

# Models
self.model_fp16 = SD3Transformer2DModel.from_pretrained(
@@ -146,7 +147,7 @@ def tearDown(self):
del self.model_8bit

gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

def test_quantization_num_parameters(self):
r"""
@@ -274,7 +275,7 @@ def test_device_and_dtype_assignment(self):

with self.assertRaises(ValueError):
# Tries with a `device`
self.model_8bit.to(torch.device("cuda:0"))
self.model_8bit.to(torch.device(f"{torch_device}:0"))

with self.assertRaises(ValueError):
# Tries with a `device`
@@ -312,7 +313,7 @@
class Bnb8bitDeviceTests(Base8bitTests):
def setUp(self) -> None:
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
self.model_8bit = SanaTransformer2DModel.from_pretrained(
@@ -326,7 +327,7 @@ def tearDown(self):
del self.model_8bit

gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

def test_buffers_device_assignment(self):
for buffer_name, buffer in self.model_8bit.named_buffers():
@@ -340,7 +341,7 @@ def test_buffers_device_assignment(self):
class BnB8bitTrainingTests(Base8bitTests):
def setUp(self):
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
self.model_8bit = SD3Transformer2DModel.from_pretrained(
@@ -384,7 +385,7 @@ def test_training(self):
class SlowBnb8bitTests(Base8bitTests):
def setUp(self) -> None:
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
model_8bit = SD3Transformer2DModel.from_pretrained(
@@ -399,7 +400,7 @@ def tearDown(self):
del self.pipeline_8bit

gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

def test_quality(self):
output = self.pipeline_8bit(
@@ -611,7 +612,7 @@ def get_dummy_tensor_inputs(device=None, seed: int = 0):
class SlowBnb8bitFluxTests(Base8bitTests):
def setUp(self) -> None:
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

model_id = "hf-internal-testing/flux.1-dev-int8-pkg"
t5_8bit = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
@@ -628,7 +629,7 @@ def tearDown(self):
del self.pipeline_8bit

gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

def test_quality(self):
# keep the resolution and max tokens to a lower number for faster execution.
@@ -675,7 +676,7 @@ def test_lora_loading(self):
class BaseBnb8bitSerializationTests(Base8bitTests):
def setUp(self):
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

quantization_config = BitsAndBytesConfig(
load_in_8bit=True,
@@ -688,7 +689,7 @@ def tearDown(self):
del self.model_0

gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

def test_serialization(self):
r"""
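
The get_dummy_inputs helpers also pass torch_device through to load_pt (shown explicitly in the 4-bit tests above) so the reference tensors land on whichever accelerator is active. A minimal sketch of that call, assuming load_pt treats its second argument as the target device:

from diffusers.utils.testing_utils import load_pt, torch_device

prompt_embeds = load_pt(
    "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt",
    torch_device,  # tensors are mapped to cuda or xpu instead of implicitly assuming CUDA
)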