
Commit 36eb48b

Flux quantized with lora
1 parent f103993 commit 36eb48b

File tree

1 file changed (+10, -11 lines)

src/diffusers/loaders/lora_pipeline.py

Lines changed: 10 additions & 11 deletions
@@ -18,6 +18,7 @@
 import torch
 from huggingface_hub.utils import validate_hf_hub_args
 
+from ..quantizers.bitsandbytes import dequantize_bnb_weight
 from ..utils import (
     USE_PEFT_BACKEND,
     deprecate,
@@ -1905,7 +1906,6 @@ def unload_lora_weights(self, reset_to_overwritten_params=False):
 
             for name, module in transformer.named_modules():
                 if isinstance(module, torch.nn.Linear) and name in module_names:
-                    module_weight = module.weight.data
                     module_bias = module.bias.data if module.bias is not None else None
                     bias = module_bias is not None
 
@@ -1919,7 +1919,6 @@ def unload_lora_weights(self, reset_to_overwritten_params=False):
                         in_features,
                         out_features,
                         bias=bias,
-                        dtype=module_weight.dtype,
                     )
 
                     tmp_state_dict = {"weight": current_param_weight}
@@ -1970,7 +1969,11 @@ def _maybe_expand_transformer_param_shape_or_error_(
         is_peft_loaded = getattr(transformer, "peft_config", None) is not None
         for name, module in transformer.named_modules():
             if isinstance(module, torch.nn.Linear):
-                module_weight = module.weight.data
+                module_weight = (
+                    dequantize_bnb_weight(module.weight, state=module.weight.quant_state).data
+                    if module.weight.__class__.__name__ == "Params4bit"
+                    else module.weight.data
+                )
                 module_bias = module.bias.data if module.bias is not None else None
                 bias = module_bias is not None
 
@@ -1994,7 +1997,7 @@ def _maybe_expand_transformer_param_shape_or_error_(
 
                 # TODO (sayakpaul): We still need to consider if the module we're expanding is
                 # quantized and handle it accordingly if that is the case.
-                module_out_features, module_in_features = module_weight.shape
+                module_out_features, module_in_features = module_weight_shape
                 debug_message = ""
                 if in_features > module_in_features:
                     debug_message += (
@@ -2018,17 +2021,13 @@ def _maybe_expand_transformer_param_shape_or_error_(
                 parent_module = transformer.get_submodule(parent_module_name)
 
                 with torch.device("meta"):
-                    expanded_module = torch.nn.Linear(
-                        in_features, out_features, bias=bias, dtype=module_weight.dtype
-                    )
+                    expanded_module = torch.nn.Linear(in_features, out_features, bias=bias)
                 # Only weights are expanded and biases are not. This is because only the input dimensions
                 # are changed while the output dimensions remain the same. The shape of the weight tensor
                 # is (out_features, in_features), while the shape of bias tensor is (out_features,), which
                 # explains the reason why only weights are expanded.
-                new_weight = torch.zeros_like(
-                    expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
-                )
-                slices = tuple(slice(0, dim) for dim in module_weight.shape)
+                new_weight = torch.zeros_like(expanded_module.weight.data)
+                slices = tuple(slice(0, dim) for dim in module_weight_shape)
                 new_weight[slices] = module_weight
                 tmp_state_dict = {"weight": new_weight}
                 if module_bias is not None:
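
The key change above guards shape reads on bitsandbytes 4-bit layers: a Params4bit weight stores packed data, so its raw shape is not the logical (out_features, in_features) of the layer and it has to be dequantized first. A minimal standalone sketch of that guard, assuming a bitsandbytes-quantized model is already loaded (the helper name below is illustrative, not part of diffusers):

import torch
from diffusers.quantizers.bitsandbytes import dequantize_bnb_weight

def logical_weight(module: torch.nn.Linear) -> torch.Tensor:
    # Mirror of the guard in the diff: unpack 4-bit weights before reading shapes.
    if module.weight.__class__.__name__ == "Params4bit":
        # quant_state carries the block-wise metadata needed to reconstruct the dense tensor.
        return dequantize_bnb_weight(module.weight, state=module.weight.quant_state).data
    return module.weight.data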

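The expansion path keeps its slice-copy logic: the existing weight is written into the top-left corner of a larger zero-initialized weight, so only the new input columns stay zero. A self-contained sketch of that operation with made-up shapes (not taken from the commit):

import torch

old = torch.nn.Linear(64, 128, bias=False)        # original layer: in_features=64
expanded = torch.nn.Linear(96, 128, bias=False)   # expanded layer: in_features=96, same out_features

# Copy the old (out_features, in_features) block; the extra input columns remain zero.
new_weight = torch.zeros_like(expanded.weight.data)
slices = tuple(slice(0, dim) for dim in old.weight.shape)
new_weight[slices] = old.weight.data
expanded.weight.data = new_weight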
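
For context, a hedged sketch of the user-level workflow this commit is aimed at, loading a LoRA on top of a bitsandbytes 4-bit Flux transformer; the repository ids and the LoRA path are placeholders, not taken from the commit:

import torch
from diffusers import BitsAndBytesConfig, FluxPipeline, FluxTransformer2DModel

# 4-bit NF4 quantization for the transformer only.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
pipe.load_lora_weights("your-username/your-flux-lora")  # placeholder LoRA repository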