
Conversation

hlky (Contributor) commented Mar 6, 2025

What does this PR do?

Fixes #10989

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

hlky requested a review from sayakpaul on March 6, 2025 16:06
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

sayakpaul (Member) left a comment

Fantastic! I left some nits, let me know if they make sense.

Somewhat related to #10588, which I had planned to work on, but glad that you beat me to it. Do we want to club that into this PR?

Also, let's add a test to test_4bit.py?

hlky (Contributor, Author) commented Mar 7, 2025

Test also added; the expected slice needs updating from the runner.

sayakpaul (Member) left a comment

Very cool to merge after the slices have been updated. Thanks, @hlky!

hlky (Contributor, Author) commented Mar 7, 2025

@bot /style

github-actions bot commented Mar 7, 2025

Style fixes have been applied. View the workflow run here.

hlky (Contributor, Author) commented Mar 7, 2025

What's the easiest way to trigger the slow bnb tests?

sayakpaul (Member) commented

It would have been this workflow: https://github.com/huggingface/diffusers/actions/workflows/run_tests_from_a_pr.yml, but it won't have bitsandbytes installed, so the tests will get skipped.

Easier would be to use the big GPU runner (aws-g6e-xlarge-plus): https://github.com/huggingface/diffusers/actions/workflows/ssh-runner.yml, following this internal doc.

sayakpaul (Member) commented Mar 18, 2025

While trying to gather slices, I ran into:

src/diffusers/loaders/lora_pipeline.py:1552: in load_lora_weights
    has_param_with_expanded_shape = self._maybe_expand_transformer_param_shape_or_error_(
src/diffusers/loaders/lora_pipeline.py:1981: in _maybe_expand_transformer_param_shape_or_error_
    module_weight = dequantize_bnb_weight(module.weight, state=module.weight.quant_state).data
src/diffusers/quantizers/bitsandbytes/utils.py:171: in dequantize_bnb_weight
    output_tensor = bnb.functional.dequantize_4bit(weight.data, weight.quant_state)
../.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/bitsandbytes/functional.py:1363: in dequantize_4bit
    is_on_gpu([A, absmax, out])
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

tensors = [tensor([[153],
        [230],
        [145],
        ...,
        [ 54],
        [ 85],
        [ 39]], dtype=torch.u....., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], dtype=torch.float16)]

    def is_on_gpu(tensors: Iterable[Optional[torch.Tensor]]):
        """Verifies that the input tensors are all on the same device.
    
        An input tensor may also be marked as `paged`, in which case the device placement is ignored.
    
        Args:
            tensors (`Iterable[Optional[torch.Tensor]]`): A list of tensors to verify.
    
        Raises:
            `RuntimeError`: Raised when the verification fails.
    
        Returns:
            `Literal[True]`
        """
    
        on_gpu = True
        gpu_ids = set()
    
        for t in tensors:
            # NULL pointers and paged tensors are OK.
            if t is not None and not getattr(t, "is_paged", False):
                on_gpu &= t.is_cuda
                gpu_ids.add(t.device.index)
    
        if not on_gpu:
>           raise RuntimeError(
                f"All input tensors need to be on the same GPU, but found some tensors to not be on a GPU:\n {[(t.shape, t.device) for t in tensors]}",
            )
E           RuntimeError: All input tensors need to be on the same GPU, but found some tensors to not be on a GPU:
E            [(torch.Size([393216, 1]), device(type='cpu')), (torch.Size([12288]), device(type='cpu')), (torch.Size([3072, 256]), device(type='cpu'))]

../.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/bitsandbytes/functional.py:464: RuntimeError

@hlky possible to look into it? It passes with main and your PR #11039.

sayakpaul (Member) commented

@hlky I made it possible to do enable_model_cpu_offload() on this, with the diff below:

diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index a38b38774..10d37d000 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -1978,7 +1978,16 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
                         "The checkpoint seems to have been quantized with `bitsandbytes` (4bits). Install `bitsandbytes` to load quantized checkpoints."
                     )
                 elif is_bnb_4bit_quantized:
-                    module_weight = dequantize_bnb_weight(module.weight, state=module.weight.quant_state).data
+                    weight_on_cpu = False
+                    if not module.weight.is_cuda:
+                        weight_on_cpu = True
+                    module_weight = dequantize_bnb_weight(
+                        module.weight.cuda() if weight_on_cpu else module.weight, 
+                        state=module.weight.quant_state, 
+                        dtype=transformer.dtype
+                    ).data
+                    if weight_on_cpu:
+                        module_weight = module_weight.cpu()
                 else:
                     module_weight = module.weight.data
                 module_bias = module.bias.data if module.bias is not None else None
diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py
index 9b1f78acb..e2a307cad 100644
--- a/tests/quantization/bnb/test_4bit.py
+++ b/tests/quantization/bnb/test_4bit.py
@@ -21,8 +21,9 @@ import numpy as np
 import pytest
 import safetensors.torch
 from huggingface_hub import hf_hub_download
+from PIL import Image
 
-from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel
+from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel, FluxControlPipeline
 from diffusers.utils import is_accelerate_version, logging
 from diffusers.utils.testing_utils import (
     CaptureLogger,
@@ -702,10 +703,7 @@ class SlowBnb4BitFluxControlWithLoraTests(Base4bitTests):
         gc.collect()
         torch.cuda.empty_cache()
 
-        self.pipeline_4bit = DiffusionPipeline.from_pretrained(
-            "eramth/flux-4bit",
-            torch_dtype=torch.float16,
-        )
+        self.pipeline_4bit = FluxControlPipeline.from_pretrained("eramth/flux-4bit", torch_dtype=torch.float16)
         self.pipeline_4bit.enable_model_cpu_offload()
 
     def tearDown(self):
@@ -719,6 +717,7 @@ class SlowBnb4BitFluxControlWithLoraTests(Base4bitTests):
 
         output = self.pipeline_4bit(
             prompt=self.prompt,
+            control_image=Image.new(mode="RGB", size=(256, 256)),
             height=256,
             width=256,
             max_sequence_length=64,
@@ -727,8 +726,7 @@ class SlowBnb4BitFluxControlWithLoraTests(Base4bitTests):
             generator=torch.Generator().manual_seed(42),
         ).images
         out_slice = output[0, -3:, -3:, -1].flatten()
-        # TODO: update slice
-        expected_slice = np.array([0.5347, 0.5342, 0.5283, 0.5093, 0.4988, 0.5093, 0.5044, 0.5015, 0.4946])
+        expected_slice = np.array([0.1636, 0.1675, 0.1982, 0.1743, 0.1809, 0.1936, 0.1743, 0.2095, 0.2139])
 
         max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
         self.assertTrue(max_diff < 1e-3, msg=f"{out_slice=} != {expected_slice=}")

Okay to update this branch so we can move forward?
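
For context, here is a minimal sketch of the flow the diff above enables, stitched together from the test change and the snippets later in this thread (the prompt and step counts are placeholders, not values from the PR):

import torch
from PIL import Image
from diffusers import FluxControlPipeline

# 4-bit quantized Flux checkpoint used by the test, kept under CPU offload.
pipe = FluxControlPipeline.from_pretrained("eramth/flux-4bit", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()

# Loading the Control LoRA now works even though the quantized weights sit on CPU
# at load time; previously this hit the is_on_gpu() RuntimeError shown earlier.
pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")

image = pipe(
    prompt="A robot made of exotic candies.",
    control_image=Image.new(mode="RGB", size=(256, 256)),
    height=256,
    width=256,
    max_sequence_length=64,
    num_inference_steps=8,
).images[0]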

sayakpaul requested a review from DN6 on March 20, 2025 09:47
sayakpaul (Member) commented

@DN6 could you give this a check? It enables loading Flux Control LoRAs into quantized checkpoints.

sayakpaul (Member) commented

@hlky @DN6 when loading a LoRA into a quantized model, I would expect the expanded module to also be a quantized linear. For example, here we replace the expanded module with torch.nn.Linear, but it should be a 4bit Linear when the underlying base module is quantized:

expanded_module = torch.nn.Linear(

Or am I thinking incorrectly?

DN6 (Collaborator) commented Apr 4, 2025

when loading a LoRA into a quantized model, I would expect the expanded module to also be a quantized linear. For example, here we replace the expanded module with torch.nn.Linear, but it should be a 4bit Linear when the underlying base module is quantized:

The current fix to just dequantize the expanded layer should be okay I think? The expanded module is a single layer no? Wouldn't hurt memory much?

But yes, technically we should recreate the quantized linear layer. Because of the shape differences in 4bit BnB weights, I think it would have to be done by dequantizing the weight, zero padding the extra channels, and then creating the appropriate quantized linear layer again. For 8bit, I think you can just zero pad the quantized weight as-is and recreate the quantized layer.
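
For illustration, a rough sketch of that BnB 4-bit route (dequantize, zero-pad the input channels, requantize), assuming module is a bnb.nn.Linear4bit and a CUDA device is available. The helper name and defaults are made up, it ignores matching the original quant_type/blocksize, and it is not what this PR ships:

import torch
import bitsandbytes as bnb
from diffusers.quantizers.bitsandbytes.utils import dequantize_bnb_weight

def expand_bnb_4bit_linear(module, new_in_features, dtype=torch.float16):
    # 1. Dequantize to recover the full (out_features, in_features) weight.
    weight = dequantize_bnb_weight(module.weight, state=module.weight.quant_state, dtype=dtype)
    out_features, in_features = weight.shape
    # 2. Zero-pad the extra input channels.
    expanded = torch.zeros(out_features, new_in_features, dtype=dtype)
    expanded[:, :in_features] = weight.cpu()
    # 3. Recreate a 4-bit linear layer; Params4bit quantizes its data when moved to CUDA.
    new_module = bnb.nn.Linear4bit(
        new_in_features, out_features, bias=module.bias is not None, compute_dtype=dtype
    )
    new_module.weight = bnb.nn.Params4bit(expanded, requires_grad=False).cuda()
    if module.bias is not None:
        new_module.bias = torch.nn.Parameter(module.bias.data)
    return new_module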

I think the issue for expanded modules would affect GGUF and Quanto as well. GGUF has the same problem as 4bit BnB; the quant weight shapes are different from the original, so I don't think we can naively zero pad the weight. For Quanto, the way the weight tensor parameters are stored in model layers is a bit different from the rest of the backends.

We can just fix BnB in this PR and do a more thorough one for all the quant backends.

DN6 (Collaborator) commented Apr 8, 2025

@hlky @sayakpaul Could you take another look at this? It adds support to handle GGUF quants as well (which run into the same issue as BnB 4bit). Additionally, the current _maybe_expand_transformer_param_shape method errors out when loading GGUF with regular LoRAs as well.

sayakpaul (Member) commented

Taking a look.

sayakpaul (Member) commented

But yes, technically we should recreate the quantized linear layer. Because of the shape differences in 4bit BnB weights, I think it would have to be done by dequantizing the weight, zero padding the extra channels, and then creating the appropriate quantized linear layer again. For 8bit, I think you can just zero pad the quantized weight as-is and recreate the quantized layer.

This is true and in line with what I had in mind. But don't you think that if we ship the PR without this change, it will create a breaking change when we eventually ship the recreation of the quantized layer for the expanded module since the values would change? @DN6

sayakpaul (Member) commented

Additionally, the current _maybe_expand_transformer_param_shape method errors out when loading GGUF with regular LoRAs as well.

Do you have a reproducer for this? @DN6

sayakpaul (Member) commented Apr 8, 2025

@DN6 some updates. When I say "works", it means loading works and it results in a reasonable image.

BnB params + Control LoRA works.

Code
from diffusers import FluxControlPipeline 
from diffusers.utils import load_image
from controlnet_aux import CannyDetector
import torch 

pipeline_4bit = FluxControlPipeline.from_pretrained("eramth/flux-4bit", torch_dtype=torch.float16).to("cuda")
pipeline_4bit.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")

prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")

processor = CannyDetector()
control_image = processor(control_image, low_threshold=50, high_threshold=200, detect_resolution=1024, image_resolution=1024)

image = pipeline_4bit(
    prompt=prompt,
    control_image=control_image,
    height=1024,
    width=1024,
    num_inference_steps=50,
    guidance_scale=30.0,
).images[0]
image.save("output.png")

GGUF + Control LoRA works

Code
import torch
from diffusers.utils import load_image 
from controlnet_aux import CannyDetector
from diffusers import FluxControlPipeline, FluxTransformer2DModel, GGUFQuantizationConfig

ckpt_path = (
    "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
)
transformer = FluxTransformer2DModel.from_single_file(
    ckpt_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)
pipe = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
).to("cuda")
pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")

prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")

processor = CannyDetector()
control_image = processor(control_image, low_threshold=50, high_threshold=200, detect_resolution=1024, image_resolution=1024)

image = pipe(
    prompt=prompt,
    control_image=control_image,
    height=1024,
    width=1024,
    num_inference_steps=50,
    guidance_scale=30.0,
).images[0]
image.save("output_gguf.png")

GGUF + regular LoRA works

Code
import torch

from diffusers import FluxPipeline, FluxTransformer2DModel, GGUFQuantizationConfig

ckpt_path = (
    "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
)
transformer = FluxTransformer2DModel.from_single_file(
    ckpt_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
).to("cuda")
pipe.load_lora_weights("ByteDance/Hyper-SD", weight_name="Hyper-FLUX.1-dev-8steps-lora.safetensors")
prompt = "A cat holding a sign that says hello world"
image = pipe(
    prompt, 
    num_inference_steps=8,
    generator=torch.manual_seed(0), 
    joint_attention_kwargs={"scale": 0.125}
).images[0]
image.save("flux-gguf-lora.png")

What am I missing?

Additionally, I pushed some changes here in 9c12c30. It is mainly about renaming the function to _maybe_dequantize_weight_for_expanded_lora() and adding a LoRA test for GGUF. I have also added peft as an additional dependency in the GGUF CI. Would you mind taking a look?

DN6 (Collaborator) commented Apr 8, 2025

But don't you think that if we ship the PR without this change, it will create a breaking change when we eventually ship the recreation of the quantized layer for the expanded module since the values would change?

For GGUF it shouldn't matter since the weight is dynamically dequantized, so it recovers the exact same tensor on inference. For BnB it should work the same, no? We're padding dequantized weights with zeros, so quantizing them again should give you the same weights + padding?

What am I missing?

If this was regarding GGUF + regular LoRAs, I've pushed the fix in this PR. If you run your GGUF + regular LoRA snippet on main it will fail.

Tests and renaming LGTM 👍🏽

sayakpaul (Member) commented

Hmm, a bit confused. The current solution first dequantizes if the underlying module weight is quantized and then creates the expanded module as a torch.nn.Linear using the dequantized module weight.

Usually, dequantization + quantization will incur some info loss. But there's a way to create a quantized layer with quantized module weights directly with bnb.nn.Params4bit.from_prequantized. WDYT?
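
For reference, a minimal sketch of that from_prequantized path, assuming module is an existing bnb.nn.Linear4bit (purely illustrative, relying on the bitsandbytes signature; not part of this PR):

import bitsandbytes as bnb

quant_state = module.weight.quant_state
# Rebuild a 4-bit parameter from the packed buffer and its quantization stats,
# without a dequantize/requantize round trip.
new_param = bnb.nn.Params4bit.from_prequantized(
    data=module.weight.data,
    quantized_stats=quant_state.as_dict(packed=True),
    requires_grad=False,
    device=module.weight.device,
)

Note that this reuses the packed buffer at its original shape, which is why the expansion case still needs a dequantize first, as DN6 points out below.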

DN6 (Collaborator) commented Apr 8, 2025

Usually, dequantization + quantization will incur some info loss. But there's a way to create a quantized layer with quantized module weights directly with bnb.nn.Params4bit.from_prequantized. WDYT?

Wouldn't you always have to dequantize? If the shape of the 4bit param is different, how do we zero pad?
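
To make the shape point concrete, a small illustration (the shapes line up with the traceback earlier in this thread; assumes a CUDA device and the default blocksize):

import torch
import bitsandbytes as bnb

weight = torch.randn(3072, 256, dtype=torch.float16, device="cuda")
packed, quant_state = bnb.functional.quantize_4bit(weight)
print(weight.shape)  # torch.Size([3072, 256]) -> 786432 values
print(packed.shape)  # torch.Size([393216, 1]) -> two 4-bit values packed per uint8 byte
# The packed buffer is flat and block-wise, so extra input channels would have to be
# padded on the dequantized (3072, 256) tensor, not on the packed storage.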

sayakpaul (Member) commented

You're right.

sayakpaul (Member) commented

Will merge once the CI is through

sayakpaul merged commit 5d49b3e into huggingface:main on Apr 8, 2025 (28 of 29 checks passed).


Development

Successfully merging this pull request may close these issues.

Flux Controlnet Lora Fails To Load When Transformers Are Quantized.

4 participants