Merged
Changes from 32 commits
Commits
40 commits
35b4cf2
allow device placement when using bnb quantization.
sayakpaul Nov 1, 2024
ec4d422
warning.
sayakpaul Nov 2, 2024
2afa9b0
tests
sayakpaul Nov 2, 2024
3679ebd
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Nov 2, 2024
79633ee
fixes
sayakpaul Nov 5, 2024
876cd13
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Nov 5, 2024
a28c702
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Nov 5, 2024
ad1584d
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Nov 5, 2024
34d0925
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Nov 7, 2024
d713c41
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Nov 11, 2024
e9ef6ea
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Nov 15, 2024
6ce560e
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Nov 16, 2024
329b32e
docs.
sayakpaul Nov 16, 2024
2f6b07d
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Nov 18, 2024
fdeb500
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Nov 19, 2024
53bc502
require accelerate version.
sayakpaul Nov 19, 2024
f81b71e
remove print.
sayakpaul Nov 19, 2024
8e1b6f5
revert to()
sayakpaul Nov 21, 2024
e3e3a96
tests
sayakpaul Nov 21, 2024
9e9561b
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Nov 21, 2024
2ddcbf1
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Nov 24, 2024
5130cc3
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Nov 26, 2024
e76f93a
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Nov 29, 2024
1963b5c
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Dec 2, 2024
a799ba8
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Dec 2, 2024
7d47364
fixes
sayakpaul Dec 2, 2024
ebfec45
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Dec 3, 2024
1fe8a79
fix: missing AutoencoderKL lora adapter (#9807)
beniz Dec 3, 2024
f05d81d
fixes
sayakpaul Dec 3, 2024
6e17cad
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Dec 4, 2024
ea09eb2
fix condition test
sayakpaul Dec 4, 2024
1779093
updates
sayakpaul Dec 4, 2024
6ff53e3
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Dec 4, 2024
7b73dc2
updates
sayakpaul Dec 4, 2024
729acea
remove is_offloaded.
sayakpaul Dec 4, 2024
3d3aab4
fixes
sayakpaul Dec 4, 2024
c033816
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Dec 4, 2024
b5cffab
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Dec 4, 2024
662868b
better
sayakpaul Dec 4, 2024
3fc15fe
empty
sayakpaul Dec 4, 2024
18 changes: 17 additions & 1 deletion src/diffusers/pipelines/pipeline_utils.py
@@ -55,6 +55,7 @@
is_accelerate_version,
is_torch_npu_available,
is_torch_version,
is_transformers_available,
is_transformers_version,
logging,
numpy_to_pil,
@@ -66,6 +67,8 @@
if is_torch_npu_available():
import torch_npu # noqa: F401

if is_transformers_available():
pass

from .pipeline_loading_utils import (
ALL_IMPORTABLE_CLASSES,
@@ -428,6 +431,19 @@ def module_is_offloaded(module):
f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading."
)

pipeline_has_bnb = any(any((_check_bnb_status(module))) for _, module in self.components.items())
# PR: https://github.com/huggingface/accelerate/pull/3223/
if (
not pipeline_is_offloaded
and not pipeline_is_sequentially_offloaded
and pipeline_has_bnb
and torch.device(device).type == "cuda"
and is_accelerate_version("<", "1.1.0.dev0")
):
raise ValueError(
Collaborator:

This is the error message you want to throw for this scenario, no?

  1. accelerate < 1.1.0.dev0
  2. you call `pipeline.to("cuda")` on a pipeline that has bnb

But if these 2 conditions are met (older accelerate version + bnb):

  1. `not pipeline_is_sequentially_offloaded` will be `False` here and you will not reach the value error
  2. you will reach this check first and get an error message - this is the wrong error message I was talking about:

        if pipeline_is_sequentially_offloaded and device and torch.device(device).type == "cuda":

    For reference, the new condition being discussed:

        if (
            not pipeline_is_offloaded
            and not pipeline_is_sequentially_offloaded
            and pipeline_has_bnb
            and torch.device(device).type == "cuda"
            and is_accelerate_version("<", "1.1.0.dev0")
        ):

Member Author:

Yeah, this makes a ton of sense. Thanks for the elaborate clarification. I have reflected this in my latest commits.

I have also run most of the SLOW tests and they are passing. This is to ensure the existing functionality doesn't break with the current changes.

LMK.

"You are trying to call `.to('cuda')` on a pipeline that has models quantized with `bitsandbytes`. Your current `accelerate` installation does not support it. Please upgrade the installation."
)

module_names, _ = self._get_signature_keys(self)
modules = [getattr(self, n, None) for n in module_names]
modules = [m for m in modules if isinstance(m, torch.nn.Module)]
@@ -441,7 +457,7 @@ def module_is_offloaded(module):
f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {'4bit' if is_loaded_in_4bit_bnb else '8bit'} and conversion to {dtype} is not supported. Module is still in {'4bit' if is_loaded_in_4bit_bnb else '8bit'} precision."
)

if is_loaded_in_8bit_bnb and device is not None:
if is_loaded_in_8bit_bnb and not is_offloaded and device is not None:
logger.warning(
f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` 8bit and moving it to {device} via `.to()` is not supported. Module is still on {module.device}."
)
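For readers following the review thread above: the concern is which guard inside `DiffusionPipeline.to()` a user actually hits when moving a `bitsandbytes`-quantized pipeline to CUDA on an older `accelerate`. The following is a simplified, standalone sketch of that ordering problem, not the actual diffusers implementation; the boolean flags are hypothetical stand-ins for the values computed inside `to()`.

    # Simplified sketch of the guard ordering discussed in the review thread.
    # The boolean flags are hypothetical stand-ins for values computed in `to()`.
    def validate_device_move(
        device_type: str,
        pipeline_has_bnb: bool,
        pipeline_is_sequentially_offloaded: bool,
        accelerate_supports_bnb_placement: bool,
    ) -> None:
        # With an older accelerate, a bnb-quantized pipeline can also be detected as
        # sequentially offloaded, so the bnb-specific guard must be reachable first
        # to surface the actionable "upgrade accelerate" message.
        if device_type == "cuda" and pipeline_has_bnb and not accelerate_supports_bnb_placement:
            raise ValueError(
                "You are trying to call `.to('cuda')` on a pipeline that has models quantized "
                "with `bitsandbytes`. Your current `accelerate` installation does not support it. "
                "Please upgrade the installation."
            )
        if device_type == "cuda" and pipeline_is_sequentially_offloaded:
            raise ValueError(
                "It seems like you have activated sequential model offloading; move the pipeline "
                "back to CPU instead of calling `.to()` on it."
            )

With this ordering, the combination flagged by the reviewer (older accelerate + bnb + CUDA target) raises the upgrade hint even though the pipeline also looks sequentially offloaded; with the checks reversed, the same call would surface the generic offloading error, i.e. "the wrong error message".
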
45 changes: 44 additions & 1 deletion tests/quantization/bnb/test_4bit.py
@@ -18,10 +18,11 @@
import unittest

import numpy as np
import pytest
import safetensors.torch

from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel
from diffusers.utils import logging
from diffusers.utils import is_accelerate_version, logging
from diffusers.utils.testing_utils import (
CaptureLogger,
is_bitsandbytes_available,
@@ -47,6 +48,7 @@ def get_some_linear_layer(model):


if is_transformers_available():
from transformers import BitsAndBytesConfig as BnbConfig
from transformers import T5EncoderModel

if is_torch_available():
@@ -483,6 +485,47 @@ def test_moving_to_cpu_throws_warning(self):

assert "Pipelines loaded with `dtype=torch.float16`" in cap_logger.out

@pytest.mark.xfail(
condition=is_accelerate_version("<=", "1.1.1"),
reason="Test will pass after https://github.com/huggingface/accelerate/pull/3223 is in a release.",
strict=True,
)
def test_pipeline_cuda_placement_works_with_nf4(self):
transformer_nf4_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
)
transformer_4bit = SD3Transformer2DModel.from_pretrained(
self.model_name,
subfolder="transformer",
quantization_config=transformer_nf4_config,
torch_dtype=torch.float16,
)
text_encoder_3_nf4_config = BnbConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
)
text_encoder_3_4bit = T5EncoderModel.from_pretrained(
self.model_name,
subfolder="text_encoder_3",
quantization_config=text_encoder_3_nf4_config,
torch_dtype=torch.float16,
)
# CUDA device placement works.
pipeline_4bit = DiffusionPipeline.from_pretrained(
self.model_name,
transformer=transformer_4bit,
text_encoder_3=text_encoder_3_4bit,
torch_dtype=torch.float16,
).to("cuda")

# Check if inference works.
_ = pipeline_4bit("table", max_sequence_length=20, num_inference_steps=2)

del pipeline_4bit


@require_transformers_version_greater("4.44.0")
class SlowBnb4BitFluxTests(Base4bitTests):
36 changes: 36 additions & 0 deletions tests/quantization/bnb/test_mixed_int8.py
@@ -17,8 +17,10 @@
import unittest

import numpy as np
import pytest

from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel, logging
from diffusers.utils import is_accelerate_version
from diffusers.utils.testing_utils import (
CaptureLogger,
is_bitsandbytes_available,
@@ -44,6 +46,7 @@ def get_some_linear_layer(model):


if is_transformers_available():
from transformers import BitsAndBytesConfig as BnbConfig
from transformers import T5EncoderModel

if is_torch_available():
@@ -432,6 +435,39 @@ def test_generate_quality_dequantize(self):
output_type="np",
).images

@pytest.mark.xfail(
condition=is_accelerate_version("<=", "1.1.1"),
reason="Test will pass after https://github.com/huggingface/accelerate/pull/3223 is in a release.",
strict=True,
)
def test_pipeline_cuda_placement_works_with_mixed_int8(self):
transformer_8bit_config = BitsAndBytesConfig(load_in_8bit=True)
transformer_8bit = SD3Transformer2DModel.from_pretrained(
self.model_name,
subfolder="transformer",
quantization_config=transformer_8bit_config,
torch_dtype=torch.float16,
)
text_encoder_3_8bit_config = BnbConfig(load_in_8bit=True)
text_encoder_3_8bit = T5EncoderModel.from_pretrained(
self.model_name,
subfolder="text_encoder_3",
quantization_config=text_encoder_3_8bit_config,
torch_dtype=torch.float16,
)
# CUDA device placement works.
pipeline_8bit = DiffusionPipeline.from_pretrained(
self.model_name,
transformer=transformer_8bit,
text_encoder_3=text_encoder_3_8bit,
torch_dtype=torch.float16,
).to("cuda")

# Check if inference works.
_ = pipeline_8bit("table", max_sequence_length=20, num_inference_steps=2)

del pipeline_8bit


@require_transformers_version_greater("4.44.0")
class SlowBnb8bitFluxTests(Base8bitTests):
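For context on what the two new slow tests exercise from a user's point of view, here is a minimal usage sketch. The checkpoint name and prompt are placeholders, and it assumes a CUDA device plus `bitsandbytes` installed; with an `accelerate` older than the version gated above, the `.to("cuda")` call is expected to raise the new `ValueError` instead of moving the pipeline.

    import torch

    from diffusers import BitsAndBytesConfig, DiffusionPipeline, SD3Transformer2DModel

    # Placeholder checkpoint; the slow tests use an SD3-style model with a `transformer` subfolder.
    model_id = "stabilityai/stable-diffusion-3-medium-diffusers"

    # Quantize only the transformer to NF4 with bitsandbytes.
    nf4_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )
    transformer_4bit = SD3Transformer2DModel.from_pretrained(
        model_id,
        subfolder="transformer",
        quantization_config=nf4_config,
        torch_dtype=torch.float16,
    )

    pipe = DiffusionPipeline.from_pretrained(
        model_id,
        transformer=transformer_4bit,
        torch_dtype=torch.float16,
    )

    # The behavior this PR enables: explicit device placement on a bnb-quantized pipeline.
    try:
        pipe.to("cuda")
    except ValueError as err:
        # Raised when the installed accelerate is too old to support this.
        print(f"Device placement rejected: {err}")
    else:
        image = pipe("a photo of a table", max_sequence_length=20, num_inference_steps=2).images[0]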