Commit c92758f

remove weight check.
1 parent 8b13c1e commit c92758f

File tree: 1 file changed (+7, -6 lines)

src/diffusers/loaders/lora_pipeline.py

Lines changed: 7 additions & 6 deletions
@@ -1983,7 +1983,8 @@ def _maybe_expand_transformer_param_shape_or_error_(
                 out_features = state_dict[lora_B_weight_name].shape[0]
 
                 # Model maybe loaded with different quantization schemes which may flatten the params.
-                # `bitsandbytes`, for example, flatten the weights when using 4bit.
+                # `bitsandbytes`, for example, flatten the weights when using 4bit. 8bit bnb models
+                # preserve weight shape.
                 module_weight_shape = cls._calculate_module_shape(model=transformer, base_module=module)
 
                 # This means there's no need for an expansion in the params, so we simply skip.
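For context on the comment extended by this hunk: quantization backends can store a module's weight as a flattened buffer, so the logical (out_features, in_features) shape has to be recovered from metadata rather than read off `weight.shape`. Below is a minimal sketch of that idea; the `Params4bit`/`quant_state` handling is an assumption about how `bitsandbytes` 4bit parameters expose the original shape, not a copy of `_calculate_module_shape` itself.

import torch

def get_logical_weight_shape(weight: torch.Tensor):
    # Hypothetical helper, not part of diffusers. bitsandbytes 4bit wraps the
    # flattened data in a `Params4bit` parameter whose `quant_state` keeps the
    # pre-quantization shape; 8bit and unquantized weights keep their shape.
    if weight.__class__.__name__ == "Params4bit" and hasattr(weight, "quant_state"):
        return weight.quant_state.shape
    return weight.shape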
@@ -2120,11 +2121,11 @@ def _get_weight_shape(weight: torch.Tensor):
         if base_module is not None:
             return _get_weight_shape(base_module.weight)
         elif base_weight_param_name is not None:
-            module_path = (
-                base_weight_param_name.rsplit(".weight", 1)[0]
-                if base_weight_param_name.endswith(".weight")
-                else base_weight_param_name
-            )
+            if not base_weight_param_name.endswith(".weight"):
+                raise ValueError(
+                    f"Invalid `base_weight_param_name` passed as it does not end with '.weight' {base_weight_param_name=}."
+                )
+            module_path = base_weight_param_name.rsplit(".weight", 1)[0]
             submodule = get_submodule_by_name(model, module_path)
             return _get_weight_shape(submodule.weight)
 
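This second hunk changes behavior rather than wording: previously a `base_weight_param_name` that did not end in `.weight` was silently reused as the module path; now it is rejected. A minimal standalone sketch of the new control flow, using a hypothetical `resolve_module_path` helper rather than the actual class method:

def resolve_module_path(base_weight_param_name: str) -> str:
    # Hypothetical stand-in for the logic inside `_calculate_module_shape`.
    # Parameter names that do not point at a `.weight` tensor are rejected
    # up front instead of being used as a module path.
    if not base_weight_param_name.endswith(".weight"):
        raise ValueError(
            f"Invalid `base_weight_param_name` passed as it does not end with '.weight' {base_weight_param_name=}."
        )
    # e.g. "transformer_blocks.0.attn.to_q.weight" -> "transformer_blocks.0.attn.to_q"
    return base_weight_param_name.rsplit(".weight", 1)[0]

Failing fast here surfaces a mis-keyed state dict immediately, instead of resolving (and taking the shape of) the wrong submodule.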
