
Commit 12a837b

update
1 parent 299c6ab commit 12a837b


2 files changed: +57 −22 lines changed

src/diffusers/loaders/lora_pipeline.py

Lines changed: 55 additions & 22 deletions
@@ -23,6 +23,7 @@
     deprecate,
     get_submodule_by_name,
     is_bitsandbytes_available,
+    is_gguf_available,
     is_peft_available,
     is_peft_version,
     is_torch_version,
@@ -49,9 +50,6 @@
 )


-if is_bitsandbytes_available():
-    from ..quantizers.bitsandbytes import dequantize_bnb_weight
-
 _LOW_CPU_MEM_USAGE_DEFAULT_LORA = False
 if is_torch_version(">=", "1.9.0"):
     if (
@@ -72,6 +70,49 @@
 _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX = {"x_embedder": "in_channels"}


+def _dequantize_weight_for_expanded_lora(model, module):
+    if is_bitsandbytes_available():
+        from ..quantizers.bitsandbytes import dequantize_bnb_weight
+
+    if is_gguf_available():
+        from ..quantizers.gguf.utils import dequantize_gguf_tensor
+
+    is_bnb_4bit_quantized = module.weight.__class__.__name__ == "Params4bit"
+    is_gguf_quantized = module.weight.__class__.__name__ == "GGUFParameter"
+
+    if is_bnb_4bit_quantized and not is_bitsandbytes_available():
+        raise ValueError(
+            "The checkpoint seems to have been quantized with `bitsandbytes` (4bits). Install `bitsandbytes` to load quantized checkpoints."
+        )
+    if is_gguf_quantized and not is_gguf_available():
+        raise ValueError(
+            "The checkpoint seems to have been quantized with `gguf`. Install `gguf` to load quantized checkpoints."
+        )
+
+    weight_on_cpu = False
+    if not module.weight.is_cuda:
+        weight_on_cpu = True
+
+    if is_bnb_4bit_quantized:
+        module_weight = dequantize_bnb_weight(
+            module.weight.cuda() if weight_on_cpu else module.weight,
+            state=module.weight.quant_state,
+            dtype=model.dtype,
+        ).data
+    elif is_gguf_quantized:
+        module_weight = dequantize_gguf_tensor(
+            module.weight.cuda() if weight_on_cpu else module.weight,
+        )
+        module_weight = module_weight.to(model.dtype)
+    else:
+        module_weight = module.weight.data
+
+    if weight_on_cpu:
+        module_weight = module_weight.cpu()
+
+    return module_weight
+
+
 class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
     r"""
     Load LoRA layers into Stable Diffusion [`UNet2DConditionModel`] and
@@ -1970,26 +2011,10 @@ def _maybe_expand_transformer_param_shape_or_error_(
         overwritten_params = {}

         is_peft_loaded = getattr(transformer, "peft_config", None) is not None
+        is_quantized = hasattr(transformer, "hf_quantizer")
         for name, module in transformer.named_modules():
             if isinstance(module, torch.nn.Linear):
-                is_bnb_4bit_quantized = module.weight.__class__.__name__ == "Params4bit"
-                if is_bnb_4bit_quantized and not is_bitsandbytes_available():
-                    raise ValueError(
-                        "The checkpoint seems to have been quantized with `bitsandbytes` (4bits). Install `bitsandbytes` to load quantized checkpoints."
-                    )
-                elif is_bnb_4bit_quantized:
-                    weight_on_cpu = False
-                    if not module.weight.is_cuda:
-                        weight_on_cpu = True
-                    module_weight = dequantize_bnb_weight(
-                        module.weight.cuda() if weight_on_cpu else module.weight,
-                        state=module.weight.quant_state,
-                        dtype=transformer.dtype,
-                    ).data
-                    if weight_on_cpu:
-                        module_weight = module_weight.cpu()
-                else:
-                    module_weight = module.weight.data
+                module_weight = module.weight.data
                 module_bias = module.bias.data if module.bias is not None else None
                 bias = module_bias is not None

@@ -2034,6 +2059,9 @@ def _maybe_expand_transformer_param_shape_or_error_(
                 parent_module_name, _, current_module_name = name.rpartition(".")
                 parent_module = transformer.get_submodule(parent_module_name)

+                if is_quantized:
+                    module_weight = _dequantize_weight_for_expanded_lora(transformer, module)
+
                 with torch.device("meta"):
                     expanded_module = torch.nn.Linear(
                         in_features, out_features, bias=bias, dtype=module_weight.dtype
@@ -2134,7 +2162,12 @@ def _calculate_module_shape(
         base_weight_param_name: str = None,
     ) -> "torch.Size":
        def _get_weight_shape(weight: torch.Tensor):
-            return weight.quant_state.shape if weight.__class__.__name__ == "Params4bit" else weight.shape
+            if weight.__class__.__name__ == "Params4bit":
+                return weight.quant_state.shape
+            elif weight.__class__.__name__ == "GGUFParameter":
+                return weight.quant_shape
+            else:
+                return weight.shape

         if base_module is not None:
             return _get_weight_shape(base_module.weight)
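
For context, a rough usage sketch of the scenario this change appears to target: loading a Flux Control-style LoRA, which widens `x_embedder`, onto a GGUF-quantized transformer. Previously only bitsandbytes 4-bit weights were dequantized before the shape expansion. The checkpoint URL and repo IDs below are illustrative assumptions, not taken from this commit.

# Illustrative sketch only: the GGUF checkpoint URL and repo IDs are assumptions.
# It exercises the new path: a Control LoRA that expands x_embedder forces
# _maybe_expand_transformer_param_shape_or_error_ to dequantize the GGUF weight first.
import torch
from diffusers import FluxControlPipeline, FluxTransformer2DModel, GGUFQuantizationConfig

transformer = FluxTransformer2DModel.from_single_file(
    "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q4_0.gguf",  # assumed checkpoint
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)
pipe = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
)
# Loading the Control LoRA requires expanding x_embedder, so the packed GGUF weight
# must be dequantized before the expanded Linear is created.
pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")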

src/diffusers/quantizers/gguf/utils.py

Lines changed: 2 additions & 0 deletions
@@ -400,6 +400,8 @@ def __new__(cls, data, requires_grad=False, quant_type=None):
         data = data if data is not None else torch.empty(0)
         self = torch.Tensor._make_subclass(cls, data, requires_grad)
         self.quant_type = quant_type
+        block_size, type_size = GGML_QUANT_SIZES[quant_type]
+        self.quant_shape = _quant_shape_from_byte_shape(self.shape, type_size, block_size)

         return self
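
The new `quant_shape` attribute matters because a `GGUFParameter`'s `.shape` is the packed byte shape, not the logical tensor shape, which is what `_calculate_module_shape` needs when comparing against LoRA dimensions. Below is a minimal sketch of the conversion, assuming `_quant_shape_from_byte_shape` rescales only the last dimension (`bytes // type_size * block_size`); the helper's real implementation is not shown in this diff, and the toy `GGML_QUANT_SIZES` entry is an assumption for Q4_0.

# Minimal sketch, not the library code. Q4_0 packs block_size=32 weights into
# type_size=18 bytes (a 2-byte scale plus 16 bytes of 4-bit values).
GGML_QUANT_SIZES = {"Q4_0": (32, 18)}  # toy stand-in for the library's table

def _quant_shape_from_byte_shape(shape, type_size, block_size):
    # Each type_size-byte block in the last dimension encodes block_size weights.
    return (*shape[:-1], shape[-1] // type_size * block_size)

block_size, type_size = GGML_QUANT_SIZES["Q4_0"]
byte_shape = (3072, 1728)  # what .shape reports for a packed Q4_0 weight
print(_quant_shape_from_byte_shape(byte_shape, type_size, block_size))  # -> (3072, 3072)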
