Flux quantized with lora #10990
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Fantastic! I left some nits, let me know if they make sense.
Somewhat related to #10588, which I had planned to work on, but glad you beat me to it. Do we want to fold that into this PR?
Also, let's add a test to test_4bit.py?
Test also added; it needs the expected slice updated from the runner.
Very cool to merge after the slices have been updated. Thanks, @hlky!
@bot /style
Style fixes have been applied. View the workflow run here.
What's the easiest way to trigger the slow bnb tests?
Would have been this workflow: … Easier would be to use … with the big GPU (…).
While trying to gather slices, I ran into:

```
src/diffusers/loaders/lora_pipeline.py:1552: in load_lora_weights
    has_param_with_expanded_shape = self._maybe_expand_transformer_param_shape_or_error_(
src/diffusers/loaders/lora_pipeline.py:1981: in _maybe_expand_transformer_param_shape_or_error_
    module_weight = dequantize_bnb_weight(module.weight, state=module.weight.quant_state).data
src/diffusers/quantizers/bitsandbytes/utils.py:171: in dequantize_bnb_weight
    output_tensor = bnb.functional.dequantize_4bit(weight.data, weight.quant_state)
../.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/bitsandbytes/functional.py:1363: in dequantize_4bit
    is_on_gpu([A, absmax, out])
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

tensors = [tensor([[153],
        [230],
        [145],
        ...,
        [ 54],
        [ 85],
        [ 39]], dtype=torch.u....., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=torch.float16)]

    def is_on_gpu(tensors: Iterable[Optional[torch.Tensor]]):
        """Verifies that the input tensors are all on the same device.

        An input tensor may also be marked as `paged`, in which case the device placement is ignored.

        Args:
            tensors (`Iterable[Optional[torch.Tensor]]`): A list of tensors to verify.

        Raises:
            `RuntimeError`: Raised when the verification fails.

        Returns:
            `Literal[True]`
        """
        on_gpu = True
        gpu_ids = set()

        for t in tensors:
            # NULL pointers and paged tensors are OK.
            if t is not None and not getattr(t, "is_paged", False):
                on_gpu &= t.is_cuda
                gpu_ids.add(t.device.index)

        if not on_gpu:
>           raise RuntimeError(
                f"All input tensors need to be on the same GPU, but found some tensors to not be on a GPU:\n {[(t.shape, t.device) for t in tensors]}",
            )
E           RuntimeError: All input tensors need to be on the same GPU, but found some tensors to not be on a GPU:
E            [(torch.Size([393216, 1]), device(type='cpu')), (torch.Size([12288]), device(type='cpu')), (torch.Size([3072, 256]), device(type='cpu'))]

../.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/bitsandbytes/functional.py:464: RuntimeError
```

@hlky possible to look into it? It passes with …
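For context, a minimal repro sketch assembled from the test diff later in this thread (the checkpoint and LoRA names are taken from there; this is not the exact command from the report): with model CPU offload, the packed 4-bit weights sit on the CPU when the Control LoRA tries to expand, and therefore dequantize, them.

```python
# Minimal repro sketch, assuming the same checkpoint and LoRA used elsewhere in
# this thread; not verbatim from the report above.
import torch
from diffusers import FluxControlPipeline

pipe = FluxControlPipeline.from_pretrained("eramth/flux-4bit", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()  # 4-bit weights stay on CPU until needed

# Before the fix discussed below, this hits bitsandbytes' is_on_gpu check,
# because dequantize_4bit ends up being called on CPU-resident tensors.
pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")
```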
@hlky I made it possible to do …

```diff
diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index a38b38774..10d37d000 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -1978,7 +1978,16 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
                         "The checkpoint seems to have been quantized with `bitsandbytes` (4bits). Install `bitsandbytes` to load quantized checkpoints."
                     )
                 elif is_bnb_4bit_quantized:
-                    module_weight = dequantize_bnb_weight(module.weight, state=module.weight.quant_state).data
+                    weight_on_cpu = False
+                    if not module.weight.is_cuda:
+                        weight_on_cpu = True
+                    module_weight = dequantize_bnb_weight(
+                        module.weight.cuda() if weight_on_cpu else module.weight,
+                        state=module.weight.quant_state,
+                        dtype=transformer.dtype
+                    ).data
+                    if weight_on_cpu:
+                        module_weight = module_weight.cpu()
                 else:
                     module_weight = module.weight.data
                 module_bias = module.bias.data if module.bias is not None else None
diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py
index 9b1f78acb..e2a307cad 100644
--- a/tests/quantization/bnb/test_4bit.py
+++ b/tests/quantization/bnb/test_4bit.py
@@ -21,8 +21,9 @@ import numpy as np
 import pytest
 import safetensors.torch
 from huggingface_hub import hf_hub_download
+from PIL import Image

-from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel
+from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel, FluxControlPipeline
 from diffusers.utils import is_accelerate_version, logging
 from diffusers.utils.testing_utils import (
     CaptureLogger,
@@ -702,10 +703,7 @@ class SlowBnb4BitFluxControlWithLoraTests(Base4bitTests):
         gc.collect()
         torch.cuda.empty_cache()

-        self.pipeline_4bit = DiffusionPipeline.from_pretrained(
-            "eramth/flux-4bit",
-            torch_dtype=torch.float16,
-        )
+        self.pipeline_4bit = FluxControlPipeline.from_pretrained("eramth/flux-4bit", torch_dtype=torch.float16)
         self.pipeline_4bit.enable_model_cpu_offload()

     def tearDown(self):
@@ -719,6 +717,7 @@ class SlowBnb4BitFluxControlWithLoraTests(Base4bitTests):
         output = self.pipeline_4bit(
             prompt=self.prompt,
+            control_image=Image.new(mode="RGB", size=(256, 256)),
             height=256,
             width=256,
             max_sequence_length=64,
@@ -727,8 +726,7 @@ class SlowBnb4BitFluxControlWithLoraTests(Base4bitTests):
             generator=torch.Generator().manual_seed(42),
         ).images
         out_slice = output[0, -3:, -3:, -1].flatten()
-        # TODO: update slice
-        expected_slice = np.array([0.5347, 0.5342, 0.5283, 0.5093, 0.4988, 0.5093, 0.5044, 0.5015, 0.4946])
+        expected_slice = np.array([0.1636, 0.1675, 0.1982, 0.1743, 0.1809, 0.1936, 0.1743, 0.2095, 0.2139])
         max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
         self.assertTrue(max_diff < 1e-3, msg=f"{out_slice=} != {expected_slice=}")
```

Okay to update this branch so we can move forward?
@DN6 could you give this a check? It enables loading Flux Control LoRAs into quantized checkpoints.
@hlky @DN6 when loading a LoRA into a quantized model, I would expect the expanded module to also have a quantized linear layer. For example, here we replace the expanded module with … (diffusers/src/diffusers/loaders/lora_pipeline.py, line 2019 in f424b1b).
Or am I thinking incorrectly?
The current fix to just dequantize the expanded layer should be okay, I think? The expanded module is a single layer, no? It wouldn't hurt memory much. But yes, technically we should recreate the quantized linear layer. Because of the shape differences in 4-bit BnB weights, I think it would have to be done by dequantizing the weight, zero-padding the extra channels, and then creating the appropriate quantized linear layer again. For 8-bit, I think you can just zero-pad the quantized weight as is and recreate the quantized layer.

I think the issue with expanded modules would affect GGUF and Quanto as well. GGUF has the same problem as 4-bit BnB; the quantized weight shapes are different from the original, so I don't think we can naively zero-pad the weight. For Quanto, the way the weight tensor parameters are stored in model layers is a bit different from the rest of the backends. We can just fix BnB in this PR and do a more thorough one for all the quant backends.
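For illustration only, a rough sketch of that dequantize, zero-pad, re-quantize idea for a 4-bit BnB linear. This is not code from the PR; the `bitsandbytes` constructor arguments (`Linear4bit`, `Params4bit`) are assumptions and may need adjusting:

```python
import bitsandbytes as bnb
import torch


def expand_linear4bit(module: bnb.nn.Linear4bit, new_in_features: int) -> bnb.nn.Linear4bit:
    # 1. Dequantize the packed 4-bit weight back to its logical (out, in) shape.
    weight = bnb.functional.dequantize_4bit(module.weight.data, module.weight.quant_state)

    # 2. Zero-pad the extra input channels.
    out_features, in_features = weight.shape
    expanded = torch.zeros(out_features, new_in_features, dtype=weight.dtype, device=weight.device)
    expanded[:, :in_features] = weight

    # 3. Recreate a quantized linear layer with the expanded shape.
    #    NOTE: constructor arguments are assumptions about the bitsandbytes API.
    new_module = bnb.nn.Linear4bit(new_in_features, out_features, bias=module.bias is not None, quant_type="nf4")
    new_module.weight = bnb.nn.Params4bit(expanded.cpu(), requires_grad=False, quant_type="nf4")
    if module.bias is not None:
        new_module.bias = module.bias
    # Moving to CUDA triggers the actual 4-bit quantization of the padded weight.
    return new_module.cuda()
```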
@hlky @sayakpaul Could you take another look at this? It adds support to handle GGUF quants as well (which run into the same issue as BnB 4-bit). Additionally, the current …
Taking a look. |
This is true and in line with what I had in mind. But don't you think that if we ship the PR without this change, it will create a breaking change when we eventually ship the recreation of the quantized layer for the expanded module since the values would change? @DN6 |
Do you have a reproducer for this? @DN6
@DN6 some updates. When I say "works", it means loading works and it results in a reasonable image.

BnB params + Control LoRA works:

```python
from diffusers import FluxControlPipeline
from diffusers.utils import load_image
from controlnet_aux import CannyDetector
import torch

pipeline_4bit = FluxControlPipeline.from_pretrained("eramth/flux-4bit", torch_dtype=torch.float16).to("cuda")
pipeline_4bit.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")

prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")

processor = CannyDetector()
control_image = processor(control_image, low_threshold=50, high_threshold=200, detect_resolution=1024, image_resolution=1024)

image = pipeline_4bit(
    prompt=prompt,
    control_image=control_image,
    height=1024,
    width=1024,
    num_inference_steps=50,
    guidance_scale=30.0,
).images[0]
image.save("output.png")
```

GGUF + Control LoRA works:

```python
import torch
from diffusers.utils import load_image
from controlnet_aux import CannyDetector
from diffusers import FluxControlPipeline, FluxTransformer2DModel, GGUFQuantizationConfig

ckpt_path = (
    "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
)
transformer = FluxTransformer2DModel.from_single_file(
    ckpt_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)
pipe = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
).to("cuda")
pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")

prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")

processor = CannyDetector()
control_image = processor(control_image, low_threshold=50, high_threshold=200, detect_resolution=1024, image_resolution=1024)

image = pipe(
    prompt=prompt,
    control_image=control_image,
    height=1024,
    width=1024,
    num_inference_steps=50,
    guidance_scale=30.0,
).images[0]
image.save("output_gguf.png")
```

GGUF + regular LoRA works:

```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, GGUFQuantizationConfig

ckpt_path = (
    "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
)
transformer = FluxTransformer2DModel.from_single_file(
    ckpt_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
).to("cuda")
pipe.load_lora_weights("ByteDance/Hyper-SD", weight_name="Hyper-FLUX.1-dev-8steps-lora.safetensors")

prompt = "A cat holding a sign that says hello world"
image = pipe(
    prompt,
    num_inference_steps=8,
    generator=torch.manual_seed(0),
    joint_attention_kwargs={"scale": 0.125},
).images[0]
image.save("flux-gguf-lora.png")
```

What am I missing? Additionally, I pushed some changes here in 9c12c30. It is mainly about renaming the function to …
For GGUF it should matter, since the weight is dynamically dequantized, so it recovers the exact same tensor at inference. For BnB it should work the same, no? We're padding dequantized weights with zeros, so quantizing them again should give you the same weights + padding?

If this was regarding GGUF + regular LoRAs, I've pushed the fix in this PR. If you run your GGUF + regular LoRA snippet on main, it will fail. Tests and renaming LGTM 👍🏽
Hmm, a bit confused. The current solution first dequantizes if the underlying module weight is quantized and then creates the expanded module with … Usually, dequantization + quantization will incur some info loss. But there's a way to create a quantized layer with quantized module weights directly with …
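As a toy illustration of the information loss being referred to (not from the thread; assumes a CUDA device with bitsandbytes installed):

```python
# Quantize an fp16 tensor to NF4 and back; the non-zero reconstruction error is
# the kind of loss being discussed.
import bitsandbytes as bnb
import torch

w = torch.randn(3072, 64, dtype=torch.float16, device="cuda")
packed, quant_state = bnb.functional.quantize_4bit(w, quant_type="nf4")
w_roundtrip = bnb.functional.dequantize_4bit(packed, quant_state)
print((w - w_roundtrip).abs().max())  # non-zero: the 4-bit round trip is lossy
```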
Wouldn't you always have to dequantize? If the shape of the 4-bit param is different, how do we zero pad?
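To make the shape mismatch concrete, an illustration (not from the thread; assumes CUDA and bitsandbytes, and that `quant_state.shape` is exposed as in recent bitsandbytes releases): the stored 4-bit parameter is a flat packed `uint8` buffer, not the logical weight matrix, which is why it cannot be zero-padded directly.

```python
import bitsandbytes as bnb
import torch

linear = bnb.nn.Linear4bit(64, 3072, bias=False, quant_type="nf4")
linear = linear.cuda()  # quantization happens on the move to CUDA

print(linear.weight.shape)              # packed storage, e.g. torch.Size([98304, 1])
print(linear.weight.quant_state.shape)  # logical shape, torch.Size([3072, 64])
```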
You're right.
Will merge once the CI is through.
What does this PR do?
Fixes #10989
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.