Skip to content

Commit 478a353

Browse files
committed
Add validation for the input `quantization_config`.
1 parent 5b6ee10 commit 478a353

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
from ..models import AutoencoderKL
4848
from ..models.attention_processor import FusedAttnProcessor2_0
4949
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin
50+
from ..quantizers import PipelineQuantizationConfig
5051
from ..quantizers.bitsandbytes.utils import _check_bnb_status
5152
from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
5253
from ..utils import (
@@ -742,6 +743,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
742743
" install accelerate\n```\n."
743744
)
744745

746+
if quantization_config is not None and not isinstance(quantization_config, PipelineQuantizationConfig):
747+
raise ValueError("`quantization_config` must be an instance of `PipelineQuantizationConfig`.")
748+
745749
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
746750
raise NotImplementedError(
747751
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"

tests/quantization/test_pipeline_level_quantization.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def test_quant_config_set_correctly_through_kwargs(self):
7474

7575
_ = pipe(self.prompt, num_inference_steps=self.num_inference_steps)
7676

77-
def test_quant_config_set_correctly_granular(self):
77+
def test_quant_config_set_correctly_through_granular(self):
7878
quant_config = PipelineQuantizationConfig(
7979
quant_mapping={
8080
"transformer": QuantoConfig(weights_dtype="int8"),
@@ -128,6 +128,21 @@ def test_validation_for_kwargs(self):
128128
"The signatures of the __init__ methods of the quantization config classes" in str(err_context.exception)
129129
)
130130

131+
def test_raises_error_for_wrong_config_class(self):
132+
quant_config = {
133+
"transformer": QuantoConfig(weights_dtype="int8"),
134+
"text_encoder_2": TranBitsAndBytesConfig(load_in_4bit=True, compute_dtype=torch.bfloat16),
135+
}
136+
with self.assertRaises(ValueError) as err_context:
137+
_ = DiffusionPipeline.from_pretrained(
138+
self.model_name,
139+
quantization_config=quant_config,
140+
torch_dtype=torch.bfloat16,
141+
)
142+
self.assertTrue(
143+
str(err_context.exception) == "`quantization_config` must be an instance of `PipelineQuantizationConfig`."
144+
)
145+
131146
def test_validation_for_mapping(self):
132147
with self.assertRaises(ValueError) as err_context:
133148
_ = PipelineQuantizationConfig(

0 commit comments

Comments
 (0)