
Commit d3d8ef2

Commit message: updates
1 parent f46ba42 commit d3d8ef2

File tree: 4 files changed (+47, −23 lines)

src/diffusers/loaders/lora_pipeline.py

Lines changed: 32 additions & 19 deletions
@@ -21,6 +21,7 @@
 from ..utils import (
     USE_PEFT_BACKEND,
     deprecate,
+    get_submodule_by_name,
     is_peft_available,
     is_peft_version,
     is_torch_version,
@@ -1981,16 +1982,12 @@ def _maybe_expand_transformer_param_shape_or_error_(
                 in_features = state_dict[lora_A_weight_name].shape[1]
                 out_features = state_dict[lora_B_weight_name].shape[0]

+                # The model may be loaded with different quantization schemes, which can flatten the params.
+                # `bitsandbytes`, for example, flattens the weights when using 4-bit.
+                module_weight_shape = cls._calculate_module_shape(model=transformer, base_module=module)
+
                 # This means there's no need for an expansion in the params, so we simply skip.
-                module_weight_shape = module_weight.shape
-                expansion_shape = (out_features, in_features)
-                quantization_config = getattr(transformer, "quantization_config", None)
-                if quantization_config and quantization_config.quant_method == "bitsandbytes":
-                    if quantization_config.load_in_4bit:
-                        expansion_shape = torch.Size(expansion_shape).numel()
-                        expansion_shape = ((expansion_shape + 1) // 2, 1)
-
-                if tuple(module_weight_shape) == expansion_shape:
+                if tuple(module_weight_shape) == (out_features, in_features):
                     continue

                 # TODO (sayakpaul): We still need to consider if the module we're expanding is
@@ -2090,22 +2087,16 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
             base_weight_param = transformer_state_dict[base_param_name]
             lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"]

-            # TODO (sayakpaul): Handle the cases when we actually need to expand.
-            base_out_feature_shape = base_weight_param.shape[1]
-            lora_A_out_feature_shape = lora_A_param.shape[1]
-            quantization_config = getattr(transformer, "quantization_config", None)
-            if quantization_config and quantization_config.quant_method == "bitsandbytes":
-                if quantization_config.load_in_4bit:
-                    lora_A_out_feature_shape = lora_A_param.shape.numel()
-                    lora_A_out_feature_shape = ((lora_A_out_feature_shape + 1) // 2, 1)[1]
+            # TODO (sayakpaul): Handle the cases when we actually need to expand when using quantization.
+            base_module_shape = cls._calculate_module_shape(model=transformer, base_weight_param_name=base_param_name)

-            if base_out_feature_shape > lora_A_out_feature_shape:
+            if base_module_shape[1] > lora_A_param.shape[1]:
                 shape = (lora_A_param.shape[0], base_weight_param.shape[1])
                 expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device)
                 expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param)
                 lora_state_dict[f"{prefix}{k}.lora_A.weight"] = expanded_state_dict_weight
                 expanded_module_names.add(k)
-            elif lora_A_out_feature_shape < lora_A_out_feature_shape:
+            elif base_module_shape[1] < lora_A_param.shape[1]:
                 raise NotImplementedError(
                     f"This LoRA param ({k}.lora_A.weight) has an incompatible shape {lora_A_param.shape}. Please open an issue to file for a feature request - https://github.com/huggingface/diffusers/issues/new."
                 )
@@ -2117,6 +2108,28 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):

         return lora_state_dict

+    @staticmethod
+    def _calculate_module_shape(
+        model: "torch.nn.Module",
+        base_module: "torch.nn.Linear" = None,
+        base_weight_param_name: str = None,
+    ) -> "torch.Size":
+        def _get_weight_shape(weight: torch.Tensor):
+            return weight.quant_state.shape if weight.__class__.__name__ == "Params4bit" else weight.shape
+
+        if base_module is not None:
+            return _get_weight_shape(base_module.weight)
+        elif base_weight_param_name is not None:
+            module_path = (
+                base_weight_param_name.rsplit(".weight", 1)[0]
+                if base_weight_param_name.endswith(".weight")
+                else base_weight_param_name
+            )
+            submodule = get_submodule_by_name(model, module_path)
+            return _get_weight_shape(submodule.weight)
+
+        raise ValueError("Either `base_module` or `base_weight_param_name` must be provided.")
+

 # The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
 # relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.
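For context on why the helper matters, here is a minimal, illustrative sketch (not part of the commit; the layer size below is hypothetical). bitsandbytes 4-bit quantization packs two 4-bit values per byte and stores a Linear weight as a single flattened column, so `module.weight.shape` no longer equals the logical `(out_features, in_features)`. `_calculate_module_shape` reads the logical shape from `weight.quant_state.shape` for `Params4bit` weights instead of re-deriving the packed shape by hand as the removed code did.

import torch

# Hypothetical layer size, chosen only for illustration.
out_features, in_features = 3072, 64
logical_shape = torch.Size((out_features, in_features))

# Packed 4-bit storage as the removed code modelled it: two values per byte, one column.
packed_shape = torch.Size(((logical_shape.numel() + 1) // 2, 1))

print(logical_shape)  # torch.Size([3072, 64])
print(packed_shape)   # torch.Size([98304, 1])

# With the helper returning the logical shape, the loader can simply check:
module_weight_shape = logical_shape  # what `_calculate_module_shape` would return for this module
needs_expansion = tuple(module_weight_shape) != (out_features, in_features)
print(needs_expansion)  # False -> the loader skips expansion for this module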

src/diffusers/utils/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -101,7 +101,7 @@
     is_xformers_available,
     requires_backends,
 )
-from .loading_utils import get_module_from_name, load_image, load_video
+from .loading_utils import get_module_from_name, get_submodule_by_name, load_image, load_video
 from .logging import get_logger
 from .outputs import BaseOutput
 from .peft_utils import (

src/diffusers/utils/loading_utils.py

Lines changed: 13 additions & 0 deletions
@@ -148,3 +148,16 @@ def get_module_from_name(module, tensor_name: str) -> Tuple[Any, str]:
             module = new_module
         tensor_name = splits[-1]
     return module, tensor_name
+
+
+def get_submodule_by_name(root_module, module_path: str):
+    current = root_module
+    parts = module_path.split(".")
+    for part in parts:
+        # If part is integer-like and the current module supports indexing, convert to int
+        if part.isdigit():
+            idx = int(part)
+            current = current[idx]  # e.g., for nn.ModuleList or nn.Sequential
+        else:
+            current = getattr(current, part)
+    return current
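The new helper is re-exported through `diffusers.utils` (see the `__init__.py` change above). Below is a minimal usage sketch, assuming the toy module names it defines, showing how numeric path segments index into containers such as `nn.ModuleList`:

import torch
from torch import nn
from diffusers.utils import get_submodule_by_name  # exported as of this commit

class TinyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.to_q = nn.Linear(8, 8)

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([TinyBlock(), TinyBlock()])

model = TinyModel()

# "blocks" resolves via getattr, "1" is integer-like and indexes the ModuleList,
# "to_q" resolves via getattr again.
submodule = get_submodule_by_name(model, "blocks.1.to_q")
print(submodule.weight.shape)  # torch.Size([8, 8])

As a design note, `torch.nn.Module.get_submodule` performs a similar dotted-path lookup; the local helper simply makes the integer-index case explicit.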

tests/quantization/bnb/test_4bit.py

Lines changed: 1 addition & 3 deletions
@@ -33,7 +33,6 @@
     numpy_cosine_similarity_distance,
     require_accelerate,
     require_bitsandbytes_version_greater,
-    require_peft_version_greater,
     require_torch,
     require_torch_gpu,
     require_transformers_version_greater,
@@ -570,8 +569,7 @@ def test_quality(self):
         max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
         self.assertTrue(max_diff < 1e-3)

-    @require_peft_version_greater("0.14.0")
-    def test_lora_loading_works(self):
+    def test_lora_loading(self):
         self.pipeline_4bit.load_lora_weights(
             hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd"
         )
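For reference, a hedged sketch of the flow this test exercises. The test class's actual `pipeline_4bit` fixture is not shown in this diff, so the setup below (model id, NF4 config, offloading) is an assumed approximation rather than the test's exact code:

import torch
from huggingface_hub import hf_hub_download
from diffusers import BitsAndBytesConfig, FluxPipeline, FluxTransformer2DModel

# Assumed setup: 4-bit NF4 quantized Flux transformer inside a FluxPipeline.
model_id = "black-forest-labs/FLUX.1-dev"
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
transformer = FluxTransformer2DModel.from_pretrained(
    model_id, subfolder="transformer", quantization_config=quant_config, torch_dtype=torch.bfloat16
)
pipeline_4bit = FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
pipeline_4bit.enable_model_cpu_offload()

# Loading the Hyper-SD LoRA into the 4-bit pipeline; the loader's shape checks
# now resolve the logical (out_features, in_features) via `_calculate_module_shape`.
pipeline_4bit.load_lora_weights(
    hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd"
)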
