
Commit 3ce35c9

Merge pull request #1 from huggingface/hlky-flux-quantized-w-lora
Enable `enable_model_cpu_offload()` with bnb 4-bit quantized weights
2 parents 2c21d34 + d39497b commit 3ce35c9

File tree

2 files changed: +22 −8 lines

src/diffusers/loaders/lora_pipeline.py

10 additions, 0 deletions

```diff
@@ -1978,6 +1978,16 @@ def _maybe_expand_transformer_param_shape_or_error_(
                     "The checkpoint seems to have been quantized with `bitsandbytes` (4bits). Install `bitsandbytes` to load quantized checkpoints."
                 )
             elif is_bnb_4bit_quantized:
+                weight_on_cpu = False
+                if not module.weight.is_cuda:
+                    weight_on_cpu = True
+                module_weight = dequantize_bnb_weight(
+                    module.weight.cuda() if weight_on_cpu else module.weight,
+                    state=module.weight.quant_state,
+                    dtype=transformer.dtype,
+                ).data
+                if weight_on_cpu:
+                    module_weight = module_weight.cpu()
                 module_weight = dequantize_bnb_weight(module.weight, state=module.weight.quant_state).data
             else:
                 module_weight = module.weight.data
```
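The key detail in this change: bitsandbytes dequantization kernels only run on CUDA, so once `enable_model_cpu_offload()` has parked the transformer's weights on the CPU, `dequantize_bnb_weight` would fail on them. The fix hops the packed weight onto the GPU for the duration of the dequantize and moves the result back afterwards, so the module's device placement (which the offload hooks rely on) is left untouched. A minimal standalone sketch of the same pattern, with the dequantize helper injected as a parameter because its exact import path is internal to diffusers (function and argument names here are illustrative, not the library's API):

```python
import torch


def dequantize_possibly_offloaded(module, dtype, dequantize_fn):
    """Dequantize a bnb 4-bit Linear weight even when CPU offload has
    moved it off the GPU. `dequantize_fn` stands in for diffusers'
    internal `dequantize_bnb_weight` helper."""
    weight = module.weight
    weight_on_cpu = not weight.is_cuda
    # bitsandbytes can only unpack 4-bit weights on a CUDA device, so
    # temporarily move the packed weight over if offload left it on CPU.
    dequantized = dequantize_fn(
        weight.cuda() if weight_on_cpu else weight,
        state=weight.quant_state,
        dtype=dtype,
    ).data
    # Restore the original placement so the offload hooks keep working.
    if weight_on_cpu:
        dequantized = dequantized.cpu()
    return dequantized
```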

tests/quantization/bnb/test_4bit.py

12 additions, 8 deletions

```diff
@@ -21,8 +21,15 @@
 import pytest
 import safetensors.torch
 from huggingface_hub import hf_hub_download
-
-from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel
+from PIL import Image
+
+from diffusers import (
+    BitsAndBytesConfig,
+    DiffusionPipeline,
+    FluxControlPipeline,
+    FluxTransformer2DModel,
+    SD3Transformer2DModel,
+)
 from diffusers.utils import is_accelerate_version, logging
 from diffusers.utils.testing_utils import (
     CaptureLogger,
@@ -702,10 +709,7 @@ def setUp(self) -> None:
         gc.collect()
         torch.cuda.empty_cache()
 
-        self.pipeline_4bit = DiffusionPipeline.from_pretrained(
-            "eramth/flux-4bit",
-            torch_dtype=torch.float16,
-        )
+        self.pipeline_4bit = FluxControlPipeline.from_pretrained("eramth/flux-4bit", torch_dtype=torch.float16)
         self.pipeline_4bit.enable_model_cpu_offload()
 
     def tearDown(self):
@@ -719,6 +723,7 @@ def test_lora_loading(self):
 
         output = self.pipeline_4bit(
             prompt=self.prompt,
+            control_image=Image.new(mode="RGB", size=(256, 256)),
             height=256,
             width=256,
             max_sequence_length=64,
@@ -727,8 +732,7 @@
             generator=torch.Generator().manual_seed(42),
         ).images
         out_slice = output[0, -3:, -3:, -1].flatten()
-        # TODO: update slice
-        expected_slice = np.array([0.5347, 0.5342, 0.5283, 0.5093, 0.4988, 0.5093, 0.5044, 0.5015, 0.4946])
+        expected_slice = np.array([0.1636, 0.1675, 0.1982, 0.1743, 0.1809, 0.1936, 0.1743, 0.2095, 0.2139])
 
         max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
         self.assertTrue(max_diff < 1e-3, msg=f"{out_slice=} != {expected_slice=}")
```
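End to end, the updated test now covers the combination this PR enables: a bnb 4-bit Flux Control checkpoint, model CPU offload, and LoRA inference on top. A usage sketch along the same lines (the LoRA repo id is an assumption for illustration; the test loads its own adapter):

```python
import torch
from PIL import Image

from diffusers import FluxControlPipeline

# 4-bit checkpoint used by the test above.
pipe = FluxControlPipeline.from_pretrained("eramth/flux-4bit", torch_dtype=torch.float16)
# The call this commit makes work for quantized weights + LoRA:
pipe.enable_model_cpu_offload()

# Assumed LoRA checkpoint, for illustration only.
pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")

image = pipe(
    prompt="a robot holding a paintbrush",
    control_image=Image.new(mode="RGB", size=(256, 256)),  # blank placeholder, as in the test
    height=256,
    width=256,
    max_sequence_length=64,
    generator=torch.Generator().manual_seed(42),
).images[0]
```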
