
Commit 2f05b6e

Merge branch 'main' into docs/add-pruna-to-diffusers-optimization
2 parents: e44e109 + 8e88495

File tree: 4 files changed (+91, -5 lines)


src/diffusers/loaders/lora_pipeline.py

Lines changed: 7 additions & 2 deletions
@@ -81,12 +81,17 @@ def _maybe_dequantize_weight_for_expanded_lora(model, module):
         from ..quantizers.gguf.utils import dequantize_gguf_tensor
 
     is_bnb_4bit_quantized = module.weight.__class__.__name__ == "Params4bit"
+    is_bnb_8bit_quantized = module.weight.__class__.__name__ == "Int8Params"
     is_gguf_quantized = module.weight.__class__.__name__ == "GGUFParameter"
 
     if is_bnb_4bit_quantized and not is_bitsandbytes_available():
         raise ValueError(
             "The checkpoint seems to have been quantized with `bitsandbytes` (4bits). Install `bitsandbytes` to load quantized checkpoints."
         )
+    if is_bnb_8bit_quantized and not is_bitsandbytes_available():
+        raise ValueError(
+            "The checkpoint seems to have been quantized with `bitsandbytes` (8bits). Install `bitsandbytes` to load quantized checkpoints."
+        )
     if is_gguf_quantized and not is_gguf_available():
         raise ValueError(
             "The checkpoint seems to have been quantized with `gguf`. Install `gguf` to load quantized checkpoints."
@@ -97,10 +102,10 @@ def _maybe_dequantize_weight_for_expanded_lora(model, module):
         weight_on_cpu = True
 
     device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
-    if is_bnb_4bit_quantized:
+    if is_bnb_4bit_quantized or is_bnb_8bit_quantized:
         module_weight = dequantize_bnb_weight(
             module.weight.to(device) if weight_on_cpu else module.weight,
-            state=module.weight.quant_state,
+            state=module.weight.quant_state if is_bnb_4bit_quantized else module.state,
             dtype=model.dtype,
         ).data
     elif is_gguf_quantized:
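
Context for the `state=` change: bitsandbytes keeps the quantization state in different places for its two schemes, which is what the new conditional in the second hunk encodes. A minimal sketch of that dispatch, with the class-name checks copied from the diff (the helper name `pick_bnb_quant_state` is ours, not diffusers API):

import torch

def pick_bnb_quant_state(module: torch.nn.Module):
    # 4-bit weights ("Params4bit") carry their quant state on the parameter
    # itself, while 8-bit weights ("Int8Params") keep it on the owning
    # Linear8bitLt module; hence `module.state` rather than
    # `module.weight.quant_state` in the patch.
    weight_cls = module.weight.__class__.__name__
    if weight_cls == "Params4bit":
        return module.weight.quant_state
    if weight_cls == "Int8Params":
        return module.state
    raise TypeError(f"Not a bitsandbytes-quantized module: {weight_cls}")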

src/diffusers/utils/dynamic_modules_utils.py

Lines changed: 21 additions & 3 deletions
@@ -154,12 +154,30 @@ def check_imports(filename):
     return get_relative_imports(filename)
 
 
-def get_class_in_module(class_name, module_path):
+def get_class_in_module(class_name, module_path, pretrained_model_name_or_path=None):
     """
     Import a module on the cache directory for modules and extract a class from it.
     """
     module_path = module_path.replace(os.path.sep, ".")
-    module = importlib.import_module(module_path)
+    try:
+        module = importlib.import_module(module_path)
+    except ModuleNotFoundError as e:
+        # This can happen when the repo id contains ".", which Python's import machinery interprets as a directory
+        # separator. We do a bit of monkey patching to detect and fix this case.
+        if not (
+            pretrained_model_name_or_path is not None
+            and "." in pretrained_model_name_or_path
+            and module_path.startswith("diffusers_modules")
+            and pretrained_model_name_or_path.replace("/", "--") in module_path
+        ):
+            raise e  # We can't figure this one out, just reraise the original error
+
+        corrected_path = os.path.join(HF_MODULES_CACHE, module_path.replace(".", "/")) + ".py"
+        corrected_path = corrected_path.replace(
+            pretrained_model_name_or_path.replace("/", "--").replace(".", "/"),
+            pretrained_model_name_or_path.replace("/", "--"),
+        )
+        module = importlib.machinery.SourceFileLoader(module_path, corrected_path).load_module()
 
     if class_name is None:
         return find_pipeline_class(module)
@@ -454,4 +472,4 @@ def get_class_from_dynamic_module(
         revision=revision,
         local_files_only=local_files_only,
     )
-    return get_class_in_module(class_name, final_module.replace(".py", ""))
+    return get_class_in_module(class_name, final_module.replace(".py", ""), pretrained_model_name_or_path)
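
Why the fallback works: converting every "." in the module path into a path separator also splits the dots inside the repo id, so the computed file path points at directories that do not exist. A small worked example of the repair, assuming a module path of the `diffusers_modules...` shape the guard above checks for (the cache root value, repo id, and helper name are illustrative; the replace logic is copied from the diff):

import os

HF_MODULES_CACHE = "/home/user/.cache/huggingface/modules"  # assumed cache root
repo_id = "akasharidas/ddpm-cifar10-32-dot.in.name"  # repo id used by the new test

def repair_module_path(module_path: str) -> str:
    # Naive dot-to-slash conversion splits the repo-id segment into spurious
    # directories; undo the split inside the repo-id part only.
    path = os.path.join(HF_MODULES_CACHE, module_path.replace(".", "/")) + ".py"
    return path.replace(
        repo_id.replace("/", "--").replace(".", "/"),
        repo_id.replace("/", "--"),
    )

print(repair_module_path("diffusers_modules.local.akasharidas--ddpm-cifar10-32-dot.in.name.pipeline"))
# /home/user/.cache/huggingface/modules/diffusers_modules/local/akasharidas--ddpm-cifar10-32-dot.in.name/pipeline.py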

tests/pipelines/test_pipelines.py

Lines changed: 15 additions & 0 deletions
@@ -1105,6 +1105,21 @@ def test_remote_auto_custom_pipe(self):
 
         assert images.shape == (1, 64, 64, 3)
 
+    def test_remote_custom_pipe_with_dot_in_name(self):
+        # make sure that trust remote code has to be passed
+        with self.assertRaises(ValueError):
+            pipeline = DiffusionPipeline.from_pretrained("akasharidas/ddpm-cifar10-32-dot.in.name")
+
+        pipeline = DiffusionPipeline.from_pretrained("akasharidas/ddpm-cifar10-32-dot.in.name", trust_remote_code=True)
+
+        assert pipeline.__class__.__name__ == "CustomPipeline"
+
+        pipeline = pipeline.to(torch_device)
+        images, output_str = pipeline(num_inference_steps=2, output_type="np")
+
+        assert images[0].shape == (1, 32, 32, 3)
+        assert output_str == "This is a test"
+
     def test_local_custom_pipeline_repo(self):
         local_custom_pipeline_path = get_tests_dir("fixtures/custom_pipeline")
         pipeline = DiffusionPipeline.from_pretrained(

tests/quantization/bnb/test_mixed_int8.py

Lines changed: 48 additions & 0 deletions
@@ -19,15 +19,18 @@
 import numpy as np
 import pytest
 from huggingface_hub import hf_hub_download
+from PIL import Image
 
 from diffusers import (
     BitsAndBytesConfig,
     DiffusionPipeline,
+    FluxControlPipeline,
     FluxTransformer2DModel,
     SanaTransformer2DModel,
     SD3Transformer2DModel,
     logging,
 )
+from diffusers.quantizers import PipelineQuantizationConfig
 from diffusers.utils import is_accelerate_version
 from diffusers.utils.testing_utils import (
     CaptureLogger,
@@ -39,6 +42,7 @@
     numpy_cosine_similarity_distance,
     require_accelerate,
     require_bitsandbytes_version_greater,
+    require_peft_backend,
     require_peft_version_greater,
     require_torch,
     require_torch_accelerator,
@@ -697,6 +701,50 @@ def test_lora_loading(self):
         self.assertTrue(max_diff < 1e-3)
 
 
+@require_transformers_version_greater("4.44.0")
+@require_peft_backend
+class SlowBnb4BitFluxControlWithLoraTests(Base8bitTests):
+    def setUp(self) -> None:
+        gc.collect()
+        backend_empty_cache(torch_device)
+
+        self.pipeline_8bit = FluxControlPipeline.from_pretrained(
+            "black-forest-labs/FLUX.1-dev",
+            quantization_config=PipelineQuantizationConfig(
+                quant_backend="bitsandbytes_8bit",
+                quant_kwargs={"load_in_8bit": True},
+                components_to_quantize=["transformer", "text_encoder_2"],
+            ),
+            torch_dtype=torch.float16,
+        )
+        self.pipeline_8bit.enable_model_cpu_offload()
+
+    def tearDown(self):
+        del self.pipeline_8bit
+
+        gc.collect()
+        backend_empty_cache(torch_device)
+
+    def test_lora_loading(self):
+        self.pipeline_8bit.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")
+
+        output = self.pipeline_8bit(
+            prompt=self.prompt,
+            control_image=Image.new(mode="RGB", size=(256, 256)),
+            height=256,
+            width=256,
+            max_sequence_length=64,
+            output_type="np",
+            num_inference_steps=8,
+            generator=torch.Generator().manual_seed(42),
+        ).images
+        out_slice = output[0, -3:, -3:, -1].flatten()
+        expected_slice = np.array([0.2029, 0.2136, 0.2268, 0.1921, 0.1997, 0.2185, 0.2021, 0.2183, 0.2292])
+
+        max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
+        self.assertTrue(max_diff < 1e-3, msg=f"{out_slice=} != {expected_slice=}")
+
+
 @slow
 class BaseBnb8bitSerializationTests(Base8bitTests):
     def setUp(self):
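
Outside the test harness, the scenario this new class exercises looks roughly like the following sketch; every identifier is taken from the diff above, and only the prompt string is a placeholder:

import torch
from PIL import Image
from diffusers import FluxControlPipeline
from diffusers.quantizers import PipelineQuantizationConfig

# Quantize the heavy components to 8-bit at load time, then attach a
# Control LoRA on top of the quantized transformer.
pipe = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    quantization_config=PipelineQuantizationConfig(
        quant_backend="bitsandbytes_8bit",
        quant_kwargs={"load_in_8bit": True},
        components_to_quantize=["transformer", "text_encoder_2"],
    ),
    torch_dtype=torch.float16,
)
pipe.enable_model_cpu_offload()
pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")
image = pipe(
    prompt="a robot",  # placeholder prompt
    control_image=Image.new(mode="RGB", size=(256, 256)),
    height=256,
    width=256,
    num_inference_steps=8,
).images[0]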
