Commit 7d47364
Message: fixes
Parent: a799ba8

3 files changed: +21 additions, -13 deletions

src/diffusers/pipelines/pipeline_utils.py (7 additions, 9 deletions)

```diff
@@ -392,6 +392,13 @@ def to(self, *args, **kwargs):
 
         device = device or device_arg
 
+        pipeline_has_bnb = any(any((_check_bnb_status(module))) for _, module in self.components.items())
+        # PR: https://github.com/huggingface/accelerate/pull/3223/
+        if pipeline_has_bnb and torch.device(device).type == "cuda" and is_accelerate_version("<", "1.1.0.dev0"):
+            raise ValueError(
+                "You are trying to call `.to('cuda')` on a pipeline that has models quantized with `bitsandbytes`. Your current `accelerate` installation does not support it. Please upgrade the installation."
+            )
+
         # throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
         def module_is_sequentially_offloaded(module):
             if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
@@ -424,15 +431,6 @@ def module_is_offloaded(module):
                 "It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` first and then call `to()`."
             )
 
-        pipeline_has_bnb = any(
-            (_check_bnb_status(module)[1] or _check_bnb_status(module)[-1]) for _, module in self.components.items()
-        )
-        # PR: https://github.com/huggingface/accelerate/pull/3223/
-        if pipeline_has_bnb and torch.device(device).type == "cuda" and is_accelerate_version("<", "1.1.0.dev0"):
-            raise ValueError(
-                "You are trying to call `.to('cuda')` on a pipeline that has models quantized with `bitsandbytes`. Your current `accelerate` installation does not support it. Please upgrade the installation."
-            )
-
         # Display a warning in this case (the operation succeeds but the benefits are lost)
         pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items())
         if pipeline_is_offloaded and device and torch.device(device).type == "cuda":
```
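Note: this moves the bitsandbytes guard so it runs before the offloading and device-map checks in `to()`, and it simplifies the predicate: rather than calling `_check_bnb_status(module)` twice and testing only its 4-bit and 8-bit flags, it calls the helper once per module and fires if any element of the returned status tuple is truthy. A minimal sketch of the user-facing behavior the guard enforces; the checkpoint ID and the choice of quantized component are illustrative assumptions, not part of this diff:

```python
# Sketch only: assumes bitsandbytes is installed and accelerate < 1.1.0.dev0.
import torch
from diffusers import BitsAndBytesConfig, DiffusionPipeline, SD3Transformer2DModel

model_id = "stabilityai/stable-diffusion-3-medium-diffusers"  # assumed checkpoint
nf4_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4")
transformer = SD3Transformer2DModel.from_pretrained(
    model_id, subfolder="transformer", quantization_config=nf4_config
)
pipe = DiffusionPipeline.from_pretrained(
    model_id, transformer=transformer, torch_dtype=torch.float16
)

# With an old accelerate, the guard now raises here, before any offload checks run:
try:
    pipe.to("cuda")
except ValueError as err:
    print(err)  # "You are trying to call `.to('cuda')` on a pipeline that has models quantized ..."
```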

tests/quantization/bnb/test_4bit.py (7 additions, 2 deletions)

```diff
@@ -18,19 +18,20 @@
 import unittest
 
 import numpy as np
+import pytest
 import safetensors.torch
 
 from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel
 from diffusers.utils import logging
 from diffusers.utils.testing_utils import (
     CaptureLogger,
+    is_accelerate_version,
     is_bitsandbytes_available,
     is_torch_available,
     is_transformers_available,
     load_pt,
     numpy_cosine_similarity_distance,
     require_accelerate,
-    require_accelerate_version_greater,
     require_bitsandbytes_version_greater,
     require_torch,
     require_torch_gpu,
@@ -485,7 +486,11 @@ def test_moving_to_cpu_throws_warning(self):
 
         assert "Pipelines loaded with `dtype=torch.float16`" in cap_logger.out
 
-    @require_accelerate_version_greater("1.0.0")
+    @pytest.mark.xfail(
+        condition=is_accelerate_version("<=", "1.1.1"),
+        reason="Test will pass after https://github.com/huggingface/accelerate/pull/3223 is in a release.",
+        strict=True,
+    )
     def test_pipeline_cuda_placement_works_with_nf4(self):
         transformer_nf4_config = BitsAndBytesConfig(
             load_in_4bit=True,
```
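Note: swapping `@require_accelerate_version_greater("1.0.0")` for a strict `xfail` changes the semantics: on released `accelerate` versions up to 1.1.1 the test is expected to fail and is reported as XFAIL, and `strict=True` turns an unexpected pass into a test failure, so CI flags the moment a release containing the fix is picked up and the mark can be removed. The same decorator change is applied to the mixed-int8 test below. A standalone sketch of strict-xfail behavior, with an illustrative `HAS_FIX` flag standing in for the version check:

```python
# Minimal pytest example; HAS_FIX is a stand-in for is_accelerate_version(">", "1.1.1").
import pytest

HAS_FIX = False

@pytest.mark.xfail(condition=not HAS_FIX, reason="needs accelerate#3223", strict=True)
def test_cuda_placement():
    # Fails today -> counted as XFAIL; if it starts passing while still
    # marked, pytest reports XPASS(strict) as a test failure.
    assert HAS_FIX
```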

tests/quantization/bnb/test_mixed_int8.py (7 additions, 2 deletions)

```diff
@@ -17,17 +17,18 @@
 import unittest
 
 import numpy as np
+import pytest
 
 from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel, logging
 from diffusers.utils.testing_utils import (
     CaptureLogger,
+    is_accelerate_version,
     is_bitsandbytes_available,
     is_torch_available,
     is_transformers_available,
     load_pt,
     numpy_cosine_similarity_distance,
     require_accelerate,
-    require_accelerate_version_greater,
     require_bitsandbytes_version_greater,
     require_torch,
     require_torch_gpu,
@@ -434,7 +435,11 @@ def test_generate_quality_dequantize(self):
             output_type="np",
         ).images
 
-    @require_accelerate_version_greater("1.0.0")
+    @pytest.mark.xfail(
+        condition=is_accelerate_version("<=", "1.1.1"),
+        reason="Test will pass after https://github.com/huggingface/accelerate/pull/3223 is in a release.",
+        strict=True,
+    )
     def test_pipeline_cuda_placement_works_with_mixed_int8(self):
         transformer_8bit_config = BitsAndBytesConfig(load_in_8bit=True)
         transformer_8bit = SD3Transformer2DModel.from_pretrained(
```
