src/diffusers/pipelines/pipeline_utils.py (2 additions, 2 deletions)

@@ -427,7 +427,7 @@ def module_is_offloaded(module):
                 "It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` to remove the existing device map from the pipeline."
             )

-        if device_type == "cuda":
+        if device_type in ["cuda", "xpu"]:
             if pipeline_is_sequentially_offloaded and not pipeline_has_bnb:
                 raise ValueError(
                     "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
@@ -440,7 +440,7 @@ def module_is_offloaded(module):

         # Display a warning in this case (the operation succeeds but the benefits are lost)
         pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items())
-        if pipeline_is_offloaded and device_type == "cuda":
+        if pipeline_is_offloaded and device_type in ["cuda", "xpu"]:
             logger.warning(
                 f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading."
             )
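The two hunks above generalize the accelerator check in `DiffusionPipeline.to()` from CUDA-only to CUDA or XPU. A minimal, self-contained sketch of that pattern (the helper name and arguments below are illustrative, not the pipeline's actual attributes):

    import torch

    # Treat any supported accelerator backend the same way when validating an
    # explicit `.to()` call against active offloading hooks.
    ACCELERATOR_DEVICE_TYPES = ("cuda", "xpu")

    def check_device_placement(device_type: str, sequentially_offloaded: bool, has_bnb: bool) -> None:
        # Moving a sequentially offloaded, non-bnb pipeline onto an accelerator defeats offloading.
        if device_type in ACCELERATOR_DEVICE_TYPES and sequentially_offloaded and not has_bnb:
            raise ValueError(
                "Sequential CPU offloading is active; keep the pipeline on 'cpu' or drop the explicit move."
            )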
src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py (18 additions, 7 deletions)

@@ -61,7 +61,7 @@ def __init__(self, quantization_config, **kwargs):
             self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules

     def validate_environment(self, *args, **kwargs):
-        if not torch.cuda.is_available():
+        if not (torch.cuda.is_available() or torch.xpu.is_available()):
            raise RuntimeError("No GPU found. A GPU is needed for quantization.")
        if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"):
            raise ImportError(
@@ -238,11 +238,15 @@ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":

    def update_device_map(self, device_map):
        if device_map is None:
-            device_map = {"": f"cuda:{torch.cuda.current_device()}"}
+            if torch.xpu.is_available():
+                current_device = f"xpu:{torch.xpu.current_device()}"
+            else:
+                current_device = f"cuda:{torch.cuda.current_device()}"
+            device_map = {"": current_device}
            logger.info(
                "The device_map was not initialized. "
                "Setting device_map to {"
-                ": f`cuda:{torch.cuda.current_device()}`}. "
+                ": {current_device}}. "
                "If you want to use the model for inference, please set device_map ='auto' "
            )
        return device_map
@@ -312,7 +316,10 @@ def _dequantize(self, model):
            logger.info(
                "Model was found to be on CPU (could happen as a result of `enable_model_cpu_offload()`). So, moving it to GPU. After dequantization, will move the model back to CPU again to preserve the previous device."
            )
-            model.to(torch.cuda.current_device())
+            if torch.xpu.is_available():
+                model.to(torch.xpu.current_device())
+            else:
+                model.to(torch.cuda.current_device())

        model = dequantize_and_replace(
            model, self.modules_to_not_convert, quantization_config=self.quantization_config
@@ -343,7 +350,7 @@ def __init__(self, quantization_config, **kwargs):
             self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules

     def validate_environment(self, *args, **kwargs):
-        if not torch.cuda.is_available():
+        if not (torch.cuda.is_available() or torch.xpu.is_available()):
            raise RuntimeError("No GPU found. A GPU is needed for quantization.")
        if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"):
            raise ImportError(
@@ -402,11 +409,15 @@ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
    # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.update_device_map
    def update_device_map(self, device_map):
        if device_map is None:
-            device_map = {"": f"cuda:{torch.cuda.current_device()}"}
+            if torch.xpu.is_available():
+                current_device = f"xpu:{torch.xpu.current_device()}"
+            else:
+                current_device = f"cuda:{torch.cuda.current_device()}"
+            device_map = {"": current_device}
            logger.info(
                "The device_map was not initialized. "
                "Setting device_map to {"
-                ": {current_device}}. "
+                ": {current_device}}. "
                "If you want to use the model for inference, please set device_map ='auto' "
            )
        return device_map
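Both quantizers now derive the default device from whichever accelerator backend is present. A standalone sketch of that selection logic (the helper name is ours; note that the diff's log message is built from plain concatenated string literals, so `{current_device}` is printed literally there, while the sketch below uses an f-string for the log):

    import torch

    def current_accelerator_device() -> str:
        # Prefer XPU when PyTorch reports one (torch.xpu exists in recent
        # PyTorch builds); otherwise fall back to the current CUDA device.
        if hasattr(torch, "xpu") and torch.xpu.is_available():
            return f"xpu:{torch.xpu.current_device()}"
        return f"cuda:{torch.cuda.current_device()}"

    # Default device_map when the caller passes None, as in `update_device_map`:
    device = current_accelerator_device()
    device_map = {"": device}
    print(f"The device_map was not initialized. Setting device_map to {device}.")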
src/diffusers/utils/testing_utils.py (2 additions, 2 deletions)

@@ -574,10 +574,10 @@ def load_numpy(arry: Union[str, np.ndarray], local_path: Optional[str] = None) -
    return arry


-def load_pt(url: str):
+def load_pt(url: str, map_location: str):
    response = requests.get(url)
    response.raise_for_status()
-    arry = torch.load(BytesIO(response.content))
+    arry = torch.load(BytesIO(response.content), map_location=map_location)
    return arry

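With this change, `load_pt` takes a required `map_location`, so downloaded tensors can be materialized directly on the active test device rather than on the device they were saved from. A minimal usage sketch, assuming the helpers exported by diffusers.utils.testing_utils:

    from diffusers.utils.testing_utils import load_pt, torch_device

    # Fetch a serialized tensor and place it on the current test device
    # ("cpu", "cuda", or "xpu", depending on the environment).
    prompt_embeds = load_pt(
        "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt",
        map_location=torch_device,
    )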
IP-Adapter FaceID pipeline test (file path not captured in this view)

@@ -377,9 +377,10 @@ def test_text_to_image_face_id(self):
        pipeline.set_ip_adapter_scale(0.7)

        inputs = self.get_dummy_inputs()
-        id_embeds = load_pt("https://huggingface.co/datasets/fabiorigano/testing-images/resolve/main/ai_face2.ipadpt")[
-            0
-        ]
+        id_embeds = load_pt(
+            "https://huggingface.co/datasets/fabiorigano/testing-images/resolve/main/ai_face2.ipadpt",
+            map_location=torch_device,
+        )[0]
        id_embeds = id_embeds.reshape((2, 1, 1, 512))
        inputs["ip_adapter_image_embeds"] = [id_embeds]
        inputs["ip_adapter_image"] = None
tests/quantization/bnb/test_4bit.py (23 additions, 19 deletions)

@@ -26,6 +26,7 @@
 from diffusers.utils import is_accelerate_version, logging
 from diffusers.utils.testing_utils import (
     CaptureLogger,
+    backend_empty_cache,
     is_bitsandbytes_available,
     is_torch_available,
     is_transformers_available,
@@ -34,7 +35,7 @@
     require_accelerate,
     require_bitsandbytes_version_greater,
     require_torch,
-    require_torch_gpu,
+    require_torch_accelerator,
     require_transformers_version_greater,
     slow,
     torch_device,
@@ -65,7 +66,7 @@ def get_some_linear_layer(model):
 @require_bitsandbytes_version_greater("0.43.2")
 @require_accelerate
 @require_torch
-@require_torch_gpu
+@require_torch_accelerator
 @slow
 class Base4bitTests(unittest.TestCase):
     # We need to test on relatively large models (aka >1b parameters otherwise the quantiztion may not work as expected)
@@ -83,13 +84,16 @@ class Base4bitTests(unittest.TestCase):

    def get_dummy_inputs(self):
        prompt_embeds = load_pt(
-            "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt"
+            "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt",
+            torch_device,
        )
        pooled_prompt_embeds = load_pt(
-            "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/pooled_prompt_embeds.pt"
+            "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/pooled_prompt_embeds.pt",
+            torch_device,
        )
        latent_model_input = load_pt(
-            "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/latent_model_input.pt"
+            "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/latent_model_input.pt",
+            torch_device,
        )

        input_dict_for_transformer = {
@@ -105,7 +109,7 @@ def get_dummy_inputs(self):
 class BnB4BitBasicTests(Base4bitTests):
    def setUp(self):
        gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

        # Models
        self.model_fp16 = SD3Transformer2DModel.from_pretrained(
@@ -127,7 +131,7 @@ def tearDown(self):
        del self.model_4bit

        gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

    def test_quantization_num_parameters(self):
        r"""
@@ -223,7 +227,7 @@ def test_keep_modules_in_fp32(self):
                self.assertTrue(module.weight.dtype == torch.uint8)

        # test if inference works.
-        with torch.no_grad() and torch.amp.autocast("cuda", dtype=torch.float16):
+        with torch.no_grad() and torch.amp.autocast(torch_device, dtype=torch.float16):
            input_dict_for_transformer = self.get_dummy_inputs()
            model_inputs = {
                k: v.to(device=torch_device) for k, v in input_dict_for_transformer.items() if not isinstance(v, bool)
@@ -265,9 +269,9 @@ def test_device_assignment(self):
        self.assertAlmostEqual(self.model_4bit.get_memory_footprint(), mem_before)

        # Move back to CUDA device
-        for device in [0, "cuda", "cuda:0", "call()"]:
+        for device in [0, f"{torch_device}", f"{torch_device}:0", "call()"]:
            if device == "call()":
-                self.model_4bit.cuda(0)
+                self.model_4bit.to(f"{torch_device}:0")
            else:
                self.model_4bit.to(device)
            self.assertEqual(self.model_4bit.device, torch.device(0))
@@ -285,7 +289,7 @@ def test_device_and_dtype_assignment(self):

        with self.assertRaises(ValueError):
            # Tries with a `device` and `dtype`
-            self.model_4bit.to(device="cuda:0", dtype=torch.float16)
+            self.model_4bit.to(device=f"{torch_device}:0", dtype=torch.float16)

        with self.assertRaises(ValueError):
            # Tries with a cast
@@ -296,7 +300,7 @@ def test_device_and_dtype_assignment(self):
            self.model_4bit.half()

        # This should work
-        self.model_4bit.to("cuda")
+        self.model_4bit.to(torch_device)

        # Test if we did not break anything
        self.model_fp16 = self.model_fp16.to(dtype=torch.float32, device=torch_device)
@@ -320,7 +324,7 @@ def test_device_and_dtype_assignment(self):
        _ = self.model_fp16.float()

        # Check that this does not throw an error
-        _ = self.model_fp16.cuda()
+        _ = self.model_fp16.to(torch_device)

    def test_bnb_4bit_wrong_config(self):
        r"""
@@ -397,7 +401,7 @@ def test_training(self):
        model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs})

        # Step 4: Check if the gradient is not None
-        with torch.amp.autocast("cuda", dtype=torch.float16):
+        with torch.amp.autocast(torch_device, dtype=torch.float16):
            out = self.model_4bit(**model_inputs)[0]
            out.norm().backward()

@@ -411,7 +415,7 @@ def test_training(self):
 class SlowBnb4BitTests(Base4bitTests):
    def setUp(self) -> None:
        gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

        nf4_config = BitsAndBytesConfig(
            load_in_4bit=True,
@@ -430,7 +434,7 @@ def tearDown(self):
        del self.pipeline_4bit

        gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

    def test_quality(self):
        output = self.pipeline_4bit(
@@ -500,7 +504,7 @@ def test_moving_to_cpu_throws_warning(self):
        reason="Test will pass after https://github.com/huggingface/accelerate/pull/3223 is in a release.",
        strict=True,
    )
-    def test_pipeline_cuda_placement_works_with_nf4(self):
+    def test_pipeline_device_placement_works_with_nf4(self):
        transformer_nf4_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
@@ -531,7 +535,7 @@ def test_pipeline_cuda_placement_works_with_nf4(self):
            transformer=transformer_4bit,
            text_encoder_3=text_encoder_3_4bit,
            torch_dtype=torch.float16,
-        ).to("cuda")
+        ).to(torch_device)

        # Check if inference works.
        _ = pipeline_4bit("table", max_sequence_length=20, num_inference_steps=2)
@@ -694,7 +698,7 @@ def test_lora_loading(self):
 class BaseBnb4BitSerializationTests(Base4bitTests):
    def tearDown(self):
        gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

    def test_serialization(self, quant_type="nf4", double_quant=True, safe_serialization=True):
        r"""
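The test-suite changes above swap hard-coded CUDA calls for the device-agnostic helpers imported from `diffusers.utils.testing_utils`, so the same tests run on either CUDA or XPU hosts. A minimal sketch of the resulting setUp/tearDown pattern (the class name is illustrative; the imports mirror the ones added in the diff):

    import gc
    import unittest

    from diffusers.utils.testing_utils import backend_empty_cache, torch_device

    class ExampleAcceleratorTest(unittest.TestCase):
        def setUp(self):
            # Release Python references first, then clear the cache of whichever
            # backend torch_device points at ("cuda" or "xpu").
            gc.collect()
            backend_empty_cache(torch_device)

        def tearDown(self):
            gc.collect()
            backend_empty_cache(torch_device)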