
Commit 35b4cf2

allow device placement when using bnb quantization.
1 parent: a98a839

1 file changed (+10 -4 lines):

src/diffusers/pipelines/pipeline_utils.py

```diff
@@ -410,10 +410,14 @@ def module_is_offloaded(module):
         pipeline_is_sequentially_offloaded = any(
             module_is_sequentially_offloaded(module) for _, module in self.components.items()
         )
+        pipeline_has_bnb = any(
+            (_check_bnb_status(module)[1] or _check_bnb_status(module)[-1]) for _, module in self.components.items()
+        )
         if pipeline_is_sequentially_offloaded and device and torch.device(device).type == "cuda":
-            raise ValueError(
-                "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
-            )
+            if not pipeline_has_bnb:
+                raise ValueError(
+                    "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
+                )
 
         is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
         if is_pipeline_device_mapped:
```
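Reading the two hunks together, `_check_bnb_status(module)` appears to return a tuple whose index 1 flags a 4-bit bnb load and whose last element flags an 8-bit one (an inference from the `is_loaded_in_4bit_bnb` / `is_loaded_in_8bit_bnb` names in the second hunk, not something the commit states). A minimal standalone sketch of the new guard:

```python
# Minimal standalone sketch of the guard added above -- not the diffusers
# source. `check_bnb_status` stands in for the private `_check_bnb_status`,
# assumed to return (is_loaded_in_bnb, is_loaded_in_4bit, is_loaded_in_8bit).
import torch


def blocks_gpu_move(components, device, is_sequentially_offloaded, check_bnb_status):
    # The pipeline counts as bnb-quantized if any component is 4-bit or 8-bit.
    pipeline_has_bnb = any(
        check_bnb_status(m)[1] or check_bnb_status(m)[-1] for m in components
    )
    # After this change, the sequential-offload error fires only for
    # non-bnb pipelines being moved to CUDA.
    return (
        is_sequentially_offloaded
        and device is not None
        and torch.device(device).type == "cuda"
        and not pipeline_has_bnb
    )
```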
```diff
@@ -449,7 +453,9 @@ def module_is_offloaded(module):
             # This can happen for `transformer` models. CPU placement was added in
             # https://github.com/huggingface/transformers/pull/33122. So, we guard this accordingly.
             if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"):
-                module.to(device=device)
+                # Since it's already supposed to be on CUDA, skip the move.
+                if torch.device(device).type != "cuda":
+                    module.to(device=device)
             elif not is_loaded_in_4bit_bnb and not is_loaded_in_8bit_bnb:
                 module.to(device, dtype)
```
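For context, a usage sketch of what the relaxed guard permits: moving a pipeline with a 4-bit bnb-quantized component via `.to("cuda")` instead of hitting the offloading `ValueError`. The model id, pipeline classes, and config values below are illustrative assumptions, not taken from the commit.

```python
# Illustrative sketch of the behavior this commit enables; the model id,
# classes, and config values are assumptions, not part of the commit.
import torch
from diffusers import BitsAndBytesConfig, SD3Transformer2DModel, StableDiffusion3Pipeline

model_id = "stabilityai/stable-diffusion-3-medium-diffusers"  # hypothetical choice

# Load only the transformer in 4-bit via bitsandbytes.
transformer = SD3Transformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    torch_dtype=torch.float16,
)
pipe = StableDiffusion3Pipeline.from_pretrained(
    model_id, transformer=transformer, torch_dtype=torch.float16
)

# With this commit, the GPU move is no longer rejected for bnb pipelines,
# and the already-on-CUDA 4-bit module is not moved redundantly.
pipe.to("cuda")
```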
