
Commit e1084af ("updates")
1 parent: 0a5a860

File tree: 2 files changed, +75 −1 lines


src/diffusers/pipelines/pipeline_loading_utils.py

Lines changed: 20 additions & 1 deletion
@@ -1134,4 +1134,23 @@ def _maybe_raise_error_for_incorrect_transformers(config_dict):
 
 
 def _maybe_warn_for_wrong_component_in_quant_config(pipe_init_dict, quant_config):
-    pass
+    if quant_config is None:
+        return
+
+    actual_pipe_components = set(pipe_init_dict.keys())
+    missing = ""
+    quant_components = None
+    if getattr(quant_config, "components_to_quantize", None) is not None:
+        quant_components = set(quant_config.components_to_quantize)
+    elif getattr(quant_config, "quant_mapping", None) is not None and isinstance(quant_config.quant_mapping, dict):
+        quant_components = set(quant_config.quant_mapping.keys())
+
+    if quant_components and not quant_components.issubset(actual_pipe_components):
+        missing = quant_components - actual_pipe_components
+
+    if missing:
+        logger.warning(
+            f"The following components in the quantization config {missing} will be ignored "
+            "as they do not belong to the underlying pipeline. Acceptable values for the pipeline "
+            f"components are: {', '.join(actual_pipe_components)}."
+        )
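As a minimal standalone illustration of the set logic the new helper adds (the component names and the SimpleNamespace stand-in below are hypothetical; the real function receives a PipelineQuantizationConfig):

# Sketch of the check above, runnable on its own. SimpleNamespace stands in
# for a PipelineQuantizationConfig; the pipeline components are made up.
from types import SimpleNamespace

pipe_init_dict = {"transformer": None, "vae": None, "text_encoder": None}
quant_config = SimpleNamespace(components_to_quantize=["transformer", "foo"], quant_mapping=None)

quant_components = set(quant_config.components_to_quantize)
missing = quant_components - set(pipe_init_dict.keys())
print(missing)  # {'foo'} -- a non-empty difference is what triggers the warning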

tests/quantization/test_pipeline_level_quantization.py

Lines changed: 55 additions & 0 deletions
@@ -16,10 +16,13 @@
 import unittest
 
 import torch
+from parameterized import parameterized
 
 from diffusers import DiffusionPipeline, QuantoConfig
 from diffusers.quantizers import PipelineQuantizationConfig
+from diffusers.utils import logging
 from diffusers.utils.testing_utils import (
+    CaptureLogger,
     is_transformers_available,
     require_accelerate,
     require_bitsandbytes_version_greater,
@@ -188,3 +191,55 @@ def test_saving_loading(self):
         output_2 = loaded_pipe(**pipe_inputs, generator=torch.manual_seed(self.seed)).images
 
         self.assertTrue(torch.allclose(output_1, output_2))
+
+    @parameterized.expand(["quant_kwargs", "quant_mapping"])
+    def test_warn_invalid_component(self, method):
+        invalid_component = "foo"
+        if method == "quant_kwargs":
+            components_to_quantize = ["transformer", invalid_component]
+            quant_config = PipelineQuantizationConfig(
+                quant_backend="bitsandbytes_8bit",
+                quant_kwargs={"load_in_8bit": True},
+                components_to_quantize=components_to_quantize,
+            )
+        else:
+            quant_config = PipelineQuantizationConfig(
+                quant_mapping={
+                    "transformer": QuantoConfig("int8"),
+                    invalid_component: TranBitsAndBytesConfig(load_in_8bit=True),
+                }
+            )
+
+        logger = logging.get_logger("diffusers.pipelines.pipeline_loading_utils")
+        logger.setLevel(logging.WARNING)
+        with CaptureLogger(logger) as cap_logger:
+            _ = DiffusionPipeline.from_pretrained(
+                self.model_name,
+                quantization_config=quant_config,
+                torch_dtype=torch.bfloat16,
+            )
+        self.assertTrue(invalid_component in cap_logger.out)
+
+    @parameterized.expand(["quant_kwargs", "quant_mapping"])
+    def test_no_quantization_for_all_invalid_components(self, method):
+        invalid_component = "foo"
+        if method == "quant_kwargs":
+            components_to_quantize = [invalid_component]
+            quant_config = PipelineQuantizationConfig(
+                quant_backend="bitsandbytes_8bit",
+                quant_kwargs={"load_in_8bit": True},
+                components_to_quantize=components_to_quantize,
+            )
+        else:
+            quant_config = PipelineQuantizationConfig(
+                quant_mapping={invalid_component: TranBitsAndBytesConfig(load_in_8bit=True)}
+            )
+
+        pipe = DiffusionPipeline.from_pretrained(
+            self.model_name,
+            quantization_config=quant_config,
+            torch_dtype=torch.bfloat16,
+        )
+        for name, component in pipe.components.items():
+            if isinstance(component, torch.nn.Module):
+                self.assertTrue(not hasattr(component.config, "quantization_config"))
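For reference, the two configuration styles these parameterized tests exercise look like this in user code (a sketch based only on the arguments shown in the diff; the component name "transformer" is the only one the tests use, and an extra key such as "foo" is what the new helper flags and ignores):

# The two equivalent ways to build a pipeline-level quantization config,
# mirroring the "quant_kwargs" and "quant_mapping" test variants.
from diffusers import QuantoConfig
from diffusers.quantizers import PipelineQuantizationConfig

# "quant_kwargs" variant: one backend applied to a list of component names.
kwargs_style = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_8bit",
    quant_kwargs={"load_in_8bit": True},
    components_to_quantize=["transformer"],
)

# "quant_mapping" variant: an explicit per-component mapping of configs.
mapping_style = PipelineQuantizationConfig(
    quant_mapping={"transformer": QuantoConfig("int8")}
)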
