Merged
Commits (24)
36eb48b - Flux quantized with lora (hlky, Mar 6, 2025)
695ad14 - fix (hlky, Mar 6, 2025)
bc912fc - changes (hlky, Mar 7, 2025)
f950380 - Apply suggestions from code review (hlky, Mar 7, 2025)
67bc7c0 - Apply style fixes (github-actions[bot], Mar 7, 2025)
316d52f - Merge branch 'main' into flux-quantized-w-lora (hlky, Mar 7, 2025)
9df7c94 - Merge branch 'main' into flux-quantized-w-lora (sayakpaul, Mar 17, 2025)
ffbc7c0 - Merge branch 'main' into flux-quantized-w-lora (sayakpaul, Mar 18, 2025)
b1e752a - Merge branch 'main' into flux-quantized-w-lora (sayakpaul, Mar 20, 2025)
2c21d34 - Merge branch 'main' into flux-quantized-w-lora (sayakpaul, Mar 20, 2025)
d39497b - enable model cpu offload() (sayakpaul, Mar 20, 2025)
3ce35c9 - Merge pull request #1 from huggingface/hlky-flux-quantized-w-lora (hlky, Mar 20, 2025)
b504f61 - Merge branch 'main' into flux-quantized-w-lora (sayakpaul, Mar 20, 2025)
514f1d7 - Merge branch 'main' into flux-quantized-w-lora (sayakpaul, Mar 21, 2025)
572c5fe - Merge branch 'main' into flux-quantized-w-lora (DN6, Mar 31, 2025)
299c6ab - Update src/diffusers/loaders/lora_pipeline.py (DN6, Apr 2, 2025)
12a837b - update (DN6, Apr 7, 2025)
de9d3b7 - Merge branch 'main' into flux-quantized-w-lora (sayakpaul, Apr 8, 2025)
0a71d38 - Apply suggestions from code review (sayakpaul, Apr 8, 2025)
9c12c30 - update (sayakpaul, Apr 8, 2025)
7cfadf6 - add peft as an additional dependency for gguf (sayakpaul, Apr 8, 2025)
eadbaac - Merge branch 'main' into flux-quantized-w-lora (sayakpaul, Apr 8, 2025)
16098be - Merge branch 'main' into flux-quantized-w-lora (sayakpaul, Apr 8, 2025)
d980148 - Merge branch 'main' into flux-quantized-w-lora (sayakpaul, Apr 8, 2025)
2 changes: 1 addition & 1 deletion .github/workflows/nightly_tests.yml
@@ -417,7 +417,7 @@ jobs:
additional_deps: ["peft"]
- backend: "gguf"
test_location: "gguf"
additional_deps: []
additional_deps: ["peft"]
- backend: "torchao"
test_location: "torchao"
additional_deps: []
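For readability, here is how the gguf matrix entry reads after this change; the nesting is assumed from the hunk above. peft is added because the new GGUF LoRA tests introduced later in this PR are gated on the PEFT backend.

```yaml
# nightly quantization matrix entry after this PR (sketch; only fields visible in the hunk above)
- backend: "gguf"
  test_location: "gguf"
  additional_deps: ["peft"]  # needed by the new FluxControlLoRAGGUFTests, which require the PEFT backend
```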
63 changes: 58 additions & 5 deletions src/diffusers/loaders/lora_pipeline.py
@@ -22,6 +22,8 @@
USE_PEFT_BACKEND,
deprecate,
get_submodule_by_name,
is_bitsandbytes_available,
is_gguf_available,
is_peft_available,
is_peft_version,
is_torch_version,
@@ -68,6 +70,49 @@
_MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX = {"x_embedder": "in_channels"}


def _maybe_dequantize_weight_for_expanded_lora(model, module):
if is_bitsandbytes_available():
from ..quantizers.bitsandbytes import dequantize_bnb_weight

if is_gguf_available():
from ..quantizers.gguf.utils import dequantize_gguf_tensor

is_bnb_4bit_quantized = module.weight.__class__.__name__ == "Params4bit"
is_gguf_quantized = module.weight.__class__.__name__ == "GGUFParameter"

if is_bnb_4bit_quantized and not is_bitsandbytes_available():
raise ValueError(
"The checkpoint seems to have been quantized with `bitsandbytes` (4bits). Install `bitsandbytes` to load quantized checkpoints."
)
if is_gguf_quantized and not is_gguf_available():
raise ValueError(
"The checkpoint seems to have been quantized with `gguf`. Install `gguf` to load quantized checkpoints."
)

weight_on_cpu = False
if not module.weight.is_cuda:
weight_on_cpu = True

if is_bnb_4bit_quantized:
module_weight = dequantize_bnb_weight(
module.weight.cuda() if weight_on_cpu else module.weight,
state=module.weight.quant_state,
dtype=model.dtype,
).data
elif is_gguf_quantized:
module_weight = dequantize_gguf_tensor(
module.weight.cuda() if weight_on_cpu else module.weight,
)
module_weight = module_weight.to(model.dtype)
else:
module_weight = module.weight.data

if weight_on_cpu:
module_weight = module_weight.cpu()

return module_weight


class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into Stable Diffusion [`UNet2DConditionModel`] and
@@ -2267,6 +2312,7 @@ def _maybe_expand_transformer_param_shape_or_error_(
overwritten_params = {}

is_peft_loaded = getattr(transformer, "peft_config", None) is not None
is_quantized = hasattr(transformer, "hf_quantizer")
for name, module in transformer.named_modules():
if isinstance(module, torch.nn.Linear):
module_weight = module.weight.data
@@ -2291,9 +2337,7 @@
if tuple(module_weight_shape) == (out_features, in_features):
continue

# TODO (sayakpaul): We still need to consider if the module we're expanding is
# quantized and handle it accordingly if that is the case.
module_out_features, module_in_features = module_weight.shape
module_out_features, module_in_features = module_weight_shape
debug_message = ""
if in_features > module_in_features:
debug_message += (
@@ -2316,6 +2360,10 @@
parent_module_name, _, current_module_name = name.rpartition(".")
parent_module = transformer.get_submodule(parent_module_name)

if is_quantized:
module_weight = _maybe_dequantize_weight_for_expanded_lora(transformer, module)

# TODO: consider if this layer needs to be a quantized layer as well if `is_quantized` is True.
with torch.device("meta"):
expanded_module = torch.nn.Linear(
in_features, out_features, bias=bias, dtype=module_weight.dtype
@@ -2327,7 +2375,7 @@
new_weight = torch.zeros_like(
expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
)
slices = tuple(slice(0, dim) for dim in module_weight.shape)
slices = tuple(slice(0, dim) for dim in module_weight_shape)
new_weight[slices] = module_weight
tmp_state_dict = {"weight": new_weight}
if module_bias is not None:
@@ -2416,7 +2464,12 @@ def _calculate_module_shape(
base_weight_param_name: str = None,
) -> "torch.Size":
def _get_weight_shape(weight: torch.Tensor):
return weight.quant_state.shape if weight.__class__.__name__ == "Params4bit" else weight.shape
if weight.__class__.__name__ == "Params4bit":
return weight.quant_state.shape
elif weight.__class__.__name__ == "GGUFParameter":
return weight.quant_shape
else:
return weight.shape

if base_module is not None:
return _get_weight_shape(base_module.weight)
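For orientation, a minimal sketch of the user-facing flow these loader changes enable. The checkpoint URL, model IDs, and call pattern are taken from the FluxControlLoRAGGUFTests added below; nothing here is new API, it is just the scenario the dequantize-and-expand path handles.

```python
import torch

from diffusers import FluxControlPipeline, FluxTransformer2DModel, GGUFQuantizationConfig

# GGUF-quantized Flux transformer (same checkpoint as the new FluxControlLoRAGGUFTests below).
ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
transformer = FluxTransformer2DModel.from_single_file(
    ckpt_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)
pipe = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")

# Loading the Control LoRA requires widening x_embedder; on quantized checkpoints,
# _maybe_expand_transformer_param_shape_or_error_ now dequantizes the packed weight via
# _maybe_dequantize_weight_for_expanded_lora before zero-padding it into the expanded Linear.
pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")
```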
2 changes: 2 additions & 0 deletions src/diffusers/quantizers/gguf/utils.py
@@ -400,6 +400,8 @@ def __new__(cls, data, requires_grad=False, quant_type=None):
data = data if data is not None else torch.empty(0)
self = torch.Tensor._make_subclass(cls, data, requires_grad)
self.quant_type = quant_type
block_size, type_size = GGML_QUANT_SIZES[quant_type]
self.quant_shape = _quant_shape_from_byte_shape(self.shape, type_size, block_size)

return self
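The new quant_shape attribute records the logical tensor shape, since a GGUFParameter's .shape reflects the packed byte layout. Below is a rough sketch of the arithmetic, assuming _quant_shape_from_byte_shape follows standard GGML packing (block_size elements stored in type_size bytes along the last axis); the concrete numbers are illustrative only.

```python
# Sketch (assumed semantics): recover the logical shape from the packed byte shape.
# GGML packs the last dimension into blocks: type_size bytes encode block_size elements.
def quant_shape_from_byte_shape(byte_shape, type_size, block_size):
    *leading, packed_bytes = byte_shape
    return (*leading, packed_bytes // type_size * block_size)

# Illustrative Q2_K-style sizes (block_size=256, type_size=84): a (3072, 1008)-byte tensor
# maps back to a logical (3072, 3072) weight, which is what _calculate_module_shape now
# reads through GGUFParameter.quant_shape instead of the packed .shape.
assert quant_shape_from_byte_shape((3072, 1008), 84, 256) == (3072, 3072)
```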

47 changes: 45 additions & 2 deletions tests/quantization/bnb/test_4bit.py
@@ -21,8 +21,15 @@
import pytest
import safetensors.torch
from huggingface_hub import hf_hub_download

from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel
from PIL import Image

from diffusers import (
BitsAndBytesConfig,
DiffusionPipeline,
FluxControlPipeline,
FluxTransformer2DModel,
SD3Transformer2DModel,
)
from diffusers.utils import is_accelerate_version, logging
from diffusers.utils.testing_utils import (
CaptureLogger,
@@ -696,6 +703,42 @@ def test_lora_loading(self):
self.assertTrue(max_diff < 1e-3)


@require_transformers_version_greater("4.44.0")
@require_peft_backend
class SlowBnb4BitFluxControlWithLoraTests(Base4bitTests):
def setUp(self) -> None:
gc.collect()
torch.cuda.empty_cache()

self.pipeline_4bit = FluxControlPipeline.from_pretrained("eramth/flux-4bit", torch_dtype=torch.float16)
self.pipeline_4bit.enable_model_cpu_offload()

def tearDown(self):
del self.pipeline_4bit

gc.collect()
torch.cuda.empty_cache()

def test_lora_loading(self):
self.pipeline_4bit.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")

output = self.pipeline_4bit(
prompt=self.prompt,
control_image=Image.new(mode="RGB", size=(256, 256)),
height=256,
width=256,
max_sequence_length=64,
output_type="np",
num_inference_steps=8,
generator=torch.Generator().manual_seed(42),
).images
out_slice = output[0, -3:, -3:, -1].flatten()
expected_slice = np.array([0.1636, 0.1675, 0.1982, 0.1743, 0.1809, 0.1936, 0.1743, 0.2095, 0.2139])

max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
self.assertTrue(max_diff < 1e-3, msg=f"{out_slice=} != {expected_slice=}")


@slow
class BaseBnb4BitSerializationTests(Base4bitTests):
def tearDown(self):
46 changes: 46 additions & 0 deletions tests/quantization/gguf/test_gguf.py
@@ -8,19 +8,22 @@
from diffusers import (
AuraFlowPipeline,
AuraFlowTransformer2DModel,
FluxControlPipeline,
FluxPipeline,
FluxTransformer2DModel,
GGUFQuantizationConfig,
SD3Transformer2DModel,
StableDiffusion3Pipeline,
)
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
is_gguf_available,
nightly,
numpy_cosine_similarity_distance,
require_accelerate,
require_big_gpu_with_torch_cuda,
require_gguf_version_greater_or_equal,
require_peft_backend,
torch_device,
)

Expand Down Expand Up @@ -456,3 +459,46 @@ def test_pipeline_inference(self):
)
max_diff = numpy_cosine_similarity_distance(expected_slice, output_slice)
assert max_diff < 1e-4


@require_peft_backend
@nightly
@require_big_gpu_with_torch_cuda
@require_accelerate
@require_gguf_version_greater_or_equal("0.10.0")
class FluxControlLoRAGGUFTests(unittest.TestCase):
def test_lora_loading(self):
ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
transformer = FluxTransformer2DModel.from_single_file(
ckpt_path,
quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
torch_dtype=torch.bfloat16,
)
pipe = FluxControlPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
transformer=transformer,
torch_dtype=torch.bfloat16,
).to("cuda")
pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")

prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
control_image = load_image(
"https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/control_image_robot_canny.png"
)

output = pipe(
prompt=prompt,
control_image=control_image,
height=256,
width=256,
num_inference_steps=10,
guidance_scale=30.0,
output_type="np",
generator=torch.manual_seed(0),
).images

out_slice = output[0, -3:, -3:, -1].flatten()
expected_slice = np.array([0.8047, 0.8359, 0.8711, 0.6875, 0.7070, 0.7383, 0.5469, 0.5820, 0.6641])

max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
self.assertTrue(max_diff < 1e-3)