Commit 7b73dc2

updates

1 parent: 6ff53e3

1 file changed: +10 -16 lines changed

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 10 additions & 16 deletions
@@ -391,6 +391,7 @@ def to(self, *args, **kwargs):
             )
 
         device = device or device_arg
+        pipeline_has_bnb = any(any((_check_bnb_status(module))) for _, module in self.components.items())
 
         # throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
         def module_is_sequentially_offloaded(module):
@@ -414,9 +415,15 @@ def module_is_offloaded(module):
             module_is_sequentially_offloaded(module) for _, module in self.components.items()
         )
         if pipeline_is_sequentially_offloaded and device and torch.device(device).type == "cuda":
-            raise ValueError(
-                "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
-            )
+            if not pipeline_has_bnb:
+                raise ValueError(
+                    "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
+                )
+            # PR: https://github.com/huggingface/accelerate/pull/3223/
+            elif pipeline_has_bnb and is_accelerate_version("<", "1.1.0.dev0"):
+                raise ValueError(
+                    "You are trying to call `.to('cuda')` on a pipeline that has models quantized with `bitsandbytes`. Your current `accelerate` installation does not support it. Please upgrade the installation."
+                )
 
         is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
         if is_pipeline_device_mapped:
@@ -431,19 +438,6 @@ def module_is_offloaded(module):
                 f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading."
             )
 
-        pipeline_has_bnb = any(any((_check_bnb_status(module))) for _, module in self.components.items())
-        # PR: https://github.com/huggingface/accelerate/pull/3223/
-        if (
-            not pipeline_is_offloaded
-            and not pipeline_is_sequentially_offloaded
-            and pipeline_has_bnb
-            and torch.device(device).type == "cuda"
-            and is_accelerate_version("<", "1.1.0.dev0")
-        ):
-            raise ValueError(
-                "You are trying to call `.to('cuda')` on a pipeline that has models quantized with `bitsandbytes`. Your current `accelerate` installation does not support it. Please upgrade the installation."
-            )
-
         module_names, _ = self._get_signature_keys(self)
         modules = [getattr(self, n, None) for n in module_names]
         modules = [m for m in modules if isinstance(m, torch.nn.Module)]
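
For context, here is a minimal usage sketch of the path this commit changes. The checkpoint ID and quantization config below are illustrative assumptions for the example, not part of the commit; it assumes a diffusers build with bitsandbytes quantization support installed.

# Sketch only: assumes diffusers with bitsandbytes support and an
# illustrative SD3 checkpoint; none of this code is from the commit itself.
import torch
from diffusers import BitsAndBytesConfig, SD3Transformer2DModel, StableDiffusion3Pipeline

model_id = "stabilityai/stable-diffusion-3-medium-diffusers"  # illustrative
nf4_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4")

# Quantize one component with bitsandbytes so that _check_bnb_status() in
# pipeline_utils.py reports the pipeline as quantized.
transformer = SD3Transformer2DModel.from_pretrained(
    model_id, subfolder="transformer", quantization_config=nf4_config, torch_dtype=torch.float16
)
pipe = StableDiffusion3Pipeline.from_pretrained(
    model_id, transformer=transformer, torch_dtype=torch.float16
)
pipe.enable_sequential_cpu_offload()

# Before this commit, `.to("cuda")` always raised once sequential offloading
# was enabled. After it, a bitsandbytes-quantized pipeline passes the guard
# when accelerate >= 1.1.0.dev0 (per the linked accelerate PR #3223), while
# older accelerate versions raise the explicit "please upgrade" error.
pipe.to("cuda")

Non-quantized pipelines are unaffected: with sequential offloading enabled they still hit the original `ValueError` advising `.to('cpu')`.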
