
Commit 779c17b

feat: support loading loras into 4bit quantized models.
1 parent: be62c85

File tree: 2 files changed (+46, -3 lines)

src/diffusers/loaders/lora_pipeline.py

Lines changed: 22 additions & 3 deletions
@@ -1982,9 +1982,19 @@ def _maybe_expand_transformer_param_shape_or_error_(
                 out_features = state_dict[lora_B_weight_name].shape[0]
 
                 # This means there's no need for an expansion in the params, so we simply skip.
-                if tuple(module_weight.shape) == (out_features, in_features):
+                module_weight_shape = module_weight.shape
+                expansion_shape = (out_features, in_features)
+                quantization_config = getattr(transformer, "quantization_config", None)
+                if quantization_config and quantization_config.quant_method == "bitsandbytes":
+                    if quantization_config.load_in_4bit:
+                        expansion_shape = torch.Size(expansion_shape).numel()
+                        expansion_shape = ((expansion_shape + 1) // 2, 1)
+
+                if tuple(module_weight_shape) == expansion_shape:
                     continue
 
+                # TODO (sayakpaul): We still need to consider if the module we're expanding is
+                # quantized and handle it accordingly if that is the case.
                 module_out_features, module_in_features = module_weight.shape
                 debug_message = ""
                 if in_features > module_in_features:
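The branch added above mirrors how bitsandbytes stores a 4-bit weight: two 4-bit values are packed per byte and the tensor is flattened, so an (out_features, in_features) parameter is reported as ((numel + 1) // 2, 1). A quick illustration of the arithmetic (not part of the commit; the layer dimensions below are made up):

import torch

# Hypothetical Linear dimensions, e.g. a projection inside a transformer block.
out_features, in_features = 3072, 64
expansion_shape = (out_features, in_features)

# Shape the module's weight actually reports under load_in_4bit:
# two 4-bit values per byte, flattened to a single packed column.
numel = torch.Size(expansion_shape).numel()
packed_shape = ((numel + 1) // 2, 1)

print(expansion_shape)  # (3072, 64)
print(packed_shape)     # (98304, 1) -- what module_weight.shape looks like for the bnb 4-bit param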
@@ -2080,13 +2090,22 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
                 base_weight_param = transformer_state_dict[base_param_name]
                 lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"]
 
-                if base_weight_param.shape[1] > lora_A_param.shape[1]:
+                # TODO (sayakpaul): Handle the cases when we actually need to expand.
+                base_out_feature_shape = base_weight_param.shape[1]
+                lora_A_out_feature_shape = lora_A_param.shape[1]
+                quantization_config = getattr(transformer, "quantization_config", None)
+                if quantization_config and quantization_config.quant_method == "bitsandbytes":
+                    if quantization_config.load_in_4bit:
+                        lora_A_out_feature_shape = lora_A_param.shape.numel()
+                        lora_A_out_feature_shape = ((lora_A_out_feature_shape + 1) // 2, 1)[1]
+
+                if base_out_feature_shape > lora_A_out_feature_shape:
                     shape = (lora_A_param.shape[0], base_weight_param.shape[1])
                     expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device)
                     expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param)
                     lora_state_dict[f"{prefix}{k}.lora_A.weight"] = expanded_state_dict_weight
                     expanded_module_names.add(k)
-                elif base_weight_param.shape[1] < lora_A_param.shape[1]:
+                elif base_out_feature_shape < lora_A_out_feature_shape:
                     raise NotImplementedError(
                         f"This LoRA param ({k}.lora_A.weight) has an incompatible shape {lora_A_param.shape}. Please open an issue to file for a feature request - https://github.com/huggingface/diffusers/issues/new."
                     )

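In _maybe_expand_lora_state_dict, the same packing rule means both sides of the new comparison collapse to 1 for a 4-bit base weight, so neither the expansion branch nor the NotImplementedError is hit for quantized modules. A small sketch with assumed dimensions (not from the commit):

import torch

# Hypothetical rank-16 LoRA on a 3072x3072 Linear that has been 4-bit quantized.
rank, in_features, out_features = 16, 3072, 3072

# bnb 4-bit base weight: flattened, two values per byte -> shape ((numel + 1) // 2, 1).
base_weight_shape = ((out_features * in_features + 1) // 2, 1)
lora_A_shape = torch.Size((rank, in_features))

base_out_feature_shape = base_weight_shape[1]                        # 1
lora_A_out_feature_shape = ((lora_A_shape.numel() + 1) // 2, 1)[1]   # also 1

# Neither branch of the new comparison fires for the quantized module.
assert not (base_out_feature_shape > lora_A_out_feature_shape)
assert not (base_out_feature_shape < lora_A_out_feature_shape)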
tests/quantization/bnb/test_4bit.py

Lines changed: 24 additions & 0 deletions
@@ -20,6 +20,7 @@
 import numpy as np
 import pytest
 import safetensors.torch
+from huggingface_hub import hf_hub_download
 
 from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel
 from diffusers.utils import is_accelerate_version, logging
@@ -32,6 +33,7 @@
     numpy_cosine_similarity_distance,
     require_accelerate,
     require_bitsandbytes_version_greater,
+    require_peft_version_greater,
     require_torch,
     require_torch_gpu,
     require_transformers_version_greater,
@@ -568,6 +570,28 @@ def test_quality(self):
         max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
         self.assertTrue(max_diff < 1e-3)
 
+    @require_peft_version_greater("0.14.0")
+    def test_lora_loading_works(self):
+        self.pipeline_4bit.load_lora_weights(
+            hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd"
+        )
+        self.pipeline_4bit.set_adapters("hyper-sd", adapter_weights=0.125)
+
+        output = self.pipeline_4bit(
+            prompt=self.prompt,
+            height=256,
+            width=256,
+            max_sequence_length=64,
+            output_type="np",
+            num_inference_steps=8,
+            generator=torch.Generator().manual_seed(42),
+        ).images
+        out_slice = output[0, -3:, -3:, -1].flatten()
+        expected_slice = np.array([0.5347, 0.5342, 0.5283, 0.5093, 0.4988, 0.5093, 0.5044, 0.5015, 0.4946])
+
+        max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
+        self.assertTrue(max_diff < 1e-3)
+
 
 @slow
 class BaseBnb4BitSerializationTests(Base4bitTests):

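For context, the new test exercises the feature roughly the way user code would. A minimal usage sketch, assuming the black-forest-labs/FLUX.1-dev checkpoint, a CUDA GPU, bitsandbytes installed, and peft >= 0.14.0 (as the new decorator requires); none of this is prescribed by the commit itself:

import torch
from huggingface_hub import hf_hub_download
from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel

model_id = "black-forest-labs/FLUX.1-dev"  # assumed base checkpoint

# Quantize only the transformer to 4-bit NF4.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16
)
transformer = FluxTransformer2DModel.from_pretrained(
    model_id, subfolder="transformer", quantization_config=quant_config, torch_dtype=torch.float16
)
pipe = DiffusionPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()

# The code path this commit enables: loading a LoRA into the 4-bit quantized transformer.
pipe.load_lora_weights(
    hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd"
)
pipe.set_adapters("hyper-sd", adapter_weights=0.125)

image = pipe(
    "a photo of a cat", num_inference_steps=8, guidance_scale=3.5,
    generator=torch.Generator().manual_seed(42),
).images[0]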