Merged
Changes shown below are from 2 of the PR's 24 commits.

Commits (24)
- 36eb48b Flux quantized with lora (hlky, Mar 6, 2025)
- 695ad14 fix (hlky, Mar 6, 2025)
- bc912fc changes (hlky, Mar 7, 2025)
- f950380 Apply suggestions from code review (hlky, Mar 7, 2025)
- 67bc7c0 Apply style fixes (github-actions[bot], Mar 7, 2025)
- 316d52f Merge branch 'main' into flux-quantized-w-lora (hlky, Mar 7, 2025)
- 9df7c94 Merge branch 'main' into flux-quantized-w-lora (sayakpaul, Mar 17, 2025)
- ffbc7c0 Merge branch 'main' into flux-quantized-w-lora (sayakpaul, Mar 18, 2025)
- b1e752a Merge branch 'main' into flux-quantized-w-lora (sayakpaul, Mar 20, 2025)
- 2c21d34 Merge branch 'main' into flux-quantized-w-lora (sayakpaul, Mar 20, 2025)
- d39497b enable model cpu offload() (sayakpaul, Mar 20, 2025)
- 3ce35c9 Merge pull request #1 from huggingface/hlky-flux-quantized-w-lora (hlky, Mar 20, 2025)
- b504f61 Merge branch 'main' into flux-quantized-w-lora (sayakpaul, Mar 20, 2025)
- 514f1d7 Merge branch 'main' into flux-quantized-w-lora (sayakpaul, Mar 21, 2025)
- 572c5fe Merge branch 'main' into flux-quantized-w-lora (DN6, Mar 31, 2025)
- 299c6ab Update src/diffusers/loaders/lora_pipeline.py (DN6, Apr 2, 2025)
- 12a837b update (DN6, Apr 7, 2025)
- de9d3b7 Merge branch 'main' into flux-quantized-w-lora (sayakpaul, Apr 8, 2025)
- 0a71d38 Apply suggestions from code review (sayakpaul, Apr 8, 2025)
- 9c12c30 update (sayakpaul, Apr 8, 2025)
- 7cfadf6 add peft as an additional dependency for gguf (sayakpaul, Apr 8, 2025)
- eadbaac Merge branch 'main' into flux-quantized-w-lora (sayakpaul, Apr 8, 2025)
- 16098be Merge branch 'main' into flux-quantized-w-lora (sayakpaul, Apr 8, 2025)
- d980148 Merge branch 'main' into flux-quantized-w-lora (sayakpaul, Apr 8, 2025)
src/diffusers/loaders/lora_pipeline.py: 11 changes (8 additions, 3 deletions)
```diff
@@ -18,6 +18,7 @@
 import torch
 from huggingface_hub.utils import validate_hf_hub_args
 
+from ..quantizers.bitsandbytes import dequantize_bnb_weight
 from ..utils import (
     USE_PEFT_BACKEND,
     deprecate,
@@ -1970,7 +1971,11 @@ def _maybe_expand_transformer_param_shape_or_error_(
         is_peft_loaded = getattr(transformer, "peft_config", None) is not None
         for name, module in transformer.named_modules():
             if isinstance(module, torch.nn.Linear):
-                module_weight = module.weight.data
+                module_weight = (
+                    dequantize_bnb_weight(module.weight, state=module.weight.quant_state).data
+                    if module.weight.__class__.__name__ == "Params4bit"
+                    else module.weight.data
+                )
                 module_bias = module.bias.data if module.bias is not None else None
                 bias = module_bias is not None
 
@@ -1994,7 +1999,7 @@
 
                 # TODO (sayakpaul): We still need to consider if the module we're expanding is
                 # quantized and handle it accordingly if that is the case.
-                module_out_features, module_in_features = module_weight.shape
+                module_out_features, module_in_features = module_weight_shape
                 debug_message = ""
                 if in_features > module_in_features:
                     debug_message += (
@@ -2028,7 +2033,7 @@
                 new_weight = torch.zeros_like(
                     expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
                 )
-                slices = tuple(slice(0, dim) for dim in module_weight.shape)
+                slices = tuple(slice(0, dim) for dim in module_weight_shape)
                 new_weight[slices] = module_weight
                 tmp_state_dict = {"weight": new_weight}
                 if module_bias is not None:
```
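For context on the core change above: a bitsandbytes `Params4bit` parameter stores packed 4-bit data, so its `.shape` reflects the packed storage rather than the layer's logical `(out_features, in_features)`. Dequantizing first, as the diff does, recovers the real shape before deciding whether a `nn.Linear` module needs expansion. Below is a minimal sketch of that behavior; it assumes `bitsandbytes` is installed and a CUDA device is available, and the toy layer sizes are illustrative rather than taken from the PR.

```python
import bitsandbytes as bnb
import torch

from diffusers.quantizers.bitsandbytes import dequantize_bnb_weight

# A 4-bit linear layer; quantization happens when the module is moved to the GPU.
linear = bnb.nn.Linear4bit(64, 128, bias=False)
linear = linear.to("cuda")

# Packed storage shape, not the logical layer shape (e.g. torch.Size([4096, 1])).
print(linear.weight.shape)

# Dequantizing yields a regular tensor with the logical (out_features, in_features) shape.
weight = dequantize_bnb_weight(linear.weight, state=linear.weight.quant_state)
print(weight.shape)  # torch.Size([128, 64])
```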
Loading
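And an end-to-end sketch of the workflow this PR targets: loading a Control LoRA, whose input projection is wider than the base model's, on top of a 4-bit quantized Flux transformer, then offloading with `enable_model_cpu_offload()` as the commit log mentions. The model and LoRA repo IDs are illustrative assumptions, not taken from the PR itself.

```python
import torch

from diffusers import BitsAndBytesConfig, FluxControlPipeline, FluxTransformer2DModel

# Quantize the Flux transformer to 4-bit with bitsandbytes.
quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)

pipe = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
)

# Loading a Control LoRA routes through _maybe_expand_transformer_param_shape_or_error_,
# which now reads the dequantized weight shape before expanding the affected Linear modules.
pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")
pipe.enable_model_cpu_offload()
```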