Commit bc912fc ("changes")
1 parent 695ad14
2 files changed: +49 -8 lines

src/diffusers/loaders/lora_pipeline.py (11 additions, 8 deletions)

@@ -18,11 +18,11 @@
 import torch
 from huggingface_hub.utils import validate_hf_hub_args

-from ..quantizers.bitsandbytes import dequantize_bnb_weight
 from ..utils import (
     USE_PEFT_BACKEND,
     deprecate,
     get_submodule_by_name,
+    is_bitsandbytes_available,
     is_peft_available,
     is_peft_version,
     is_torch_version,
@@ -48,6 +48,9 @@
 )


+if is_bitsandbytes_available():
+    from ..quantizers.bitsandbytes import dequantize_bnb_weight
+
 _LOW_CPU_MEM_USAGE_DEFAULT_LORA = False
 if is_torch_version(">=", "1.9.0"):
     if (
@@ -1971,11 +1974,13 @@ def _maybe_expand_transformer_param_shape_or_error_(
         is_peft_loaded = getattr(transformer, "peft_config", None) is not None
         for name, module in transformer.named_modules():
             if isinstance(module, torch.nn.Linear):
-                module_weight = (
-                    dequantize_bnb_weight(module.weight, state=module.weight.quant_state).data
-                    if module.weight.__class__.__name__ == "Params4bit"
-                    else module.weight.data
-                )
+                is_quantized = module.weight.__class__.__name__ == "Params4bit"
+                if is_quantized and not is_bitsandbytes_available():
+                    raise ValueError("Install `bitsandbytes` to load quantized checkpoints.")
+                elif is_quantized:
+                    module_weight = dequantize_bnb_weight(module.weight, state=module.weight.quant_state).data
+                else:
+                    module_weight = module.weight.data
                 module_bias = module.bias.data if module.bias is not None else None
                 bias = module_bias is not None

@@ -1997,8 +2002,6 @@ def _maybe_expand_transformer_param_shape_or_error_(
                 if tuple(module_weight_shape) == (out_features, in_features):
                     continue

-                # TODO (sayakpaul): We still need to consider if the module we're expanding is
-                # quantized and handle it accordingly if that is the case.
                 module_out_features, module_in_features = module_weight_shape
                 debug_message = ""
                 if in_features > module_in_features:
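
For context, the hunks above move the `dequantize_bnb_weight` import behind an `is_bitsandbytes_available()` guard and only dequantize when the weight is actually a bitsandbytes `Params4bit`. A minimal standalone sketch of that pattern follows; the helper name `get_dense_weight` is illustrative and not part of the commit, and it assumes the public `diffusers.utils` and `diffusers.quantizers.bitsandbytes` module paths that the diff imports from.

import torch

from diffusers.utils import is_bitsandbytes_available


def get_dense_weight(module: torch.nn.Linear) -> torch.Tensor:
    # Detect a bnb 4-bit weight without importing bitsandbytes up front.
    is_quantized = module.weight.__class__.__name__ == "Params4bit"
    if is_quantized and not is_bitsandbytes_available():
        raise ValueError("Install `bitsandbytes` to load quantized checkpoints.")
    elif is_quantized:
        # Lazy import, mirroring the guarded import added in lora_pipeline.py.
        from diffusers.quantizers.bitsandbytes import dequantize_bnb_weight

        # Dequantize so the true (out_features, in_features) shape is visible.
        return dequantize_bnb_weight(module.weight, state=module.weight.quant_state).data
    return module.weight.data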

tests/quantization/bnb/test_4bit.py (38 additions, 0 deletions)

@@ -681,6 +681,44 @@ def test_lora_loading(self):
         self.assertTrue(max_diff < 1e-3)


+@require_transformers_version_greater("4.44.0")
+class SlowBnb4BitFluxWithLoraTests(Base4bitTests):
+    def setUp(self) -> None:
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        self.pipeline_4bit = DiffusionPipeline.from_pretrained(
+            "eramth/flux-4bit",
+            torch_dtype=torch.float16,
+        )
+        self.pipeline_4bit.enable_model_cpu_offload()
+
+    def tearDown(self):
+        del self.pipeline_4bit
+
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def test_lora_loading(self):
+        self.pipeline_4bit.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")
+
+        output = self.pipeline_4bit(
+            prompt=self.prompt,
+            height=256,
+            width=256,
+            max_sequence_length=64,
+            output_type="np",
+            num_inference_steps=8,
+            generator=torch.Generator().manual_seed(42),
+        ).images
+        out_slice = output[0, -3:, -3:, -1].flatten()
+        # TODO: update slice
+        expected_slice = np.array([0.5347, 0.5342, 0.5283, 0.5093, 0.4988, 0.5093, 0.5044, 0.5015, 0.4946])
+
+        max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
+        self.assertTrue(max_diff < 1e-3, msg=f"{out_slice=} != {expected_slice=}")
+
+
 @slow
 class BaseBnb4BitSerializationTests(Base4bitTests):
     def tearDown(self):
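
A rough end-user equivalent of what the new test exercises, as a sketch rather than code from the commit: it assumes a CUDA GPU with `bitsandbytes`, `peft`, and `transformers>=4.44.0` installed; the checkpoint and LoRA repos are the ones used in the test, while the prompt and output filename are illustrative.

import torch

from diffusers import DiffusionPipeline

# Pre-quantized 4-bit Flux checkpoint, as in the test above.
pipe = DiffusionPipeline.from_pretrained("eramth/flux-4bit", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()

# Loading a LoRA into the 4-bit transformer is the path this commit fixes.
pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")

image = pipe(
    "a photo of an astronaut riding a horse",
    height=256,
    width=256,
    max_sequence_length=64,
    num_inference_steps=8,
    generator=torch.Generator().manual_seed(42),
).images[0]
image.save("flux_4bit_lora.png")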
