Skip to content

Commit 909e715

Browse files
committed
allow non list components_to_quantize.
1 parent 22b229b commit 909e715

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

src/diffusers/quantizers/pipe_quant_config.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,15 @@ def __init__(
4848
self,
4949
quant_backend: str = None,
5050
quant_kwargs: Dict[str, Union[str, float, int, dict]] = None,
51-
components_to_quantize: Optional[List[str]] = None,
51+
components_to_quantize: Optional[Union[List[str], str]] = None,
5252
quant_mapping: Dict[str, Union[DiffQuantConfigMixin, "TransformersQuantConfigMixin"]] = None,
5353
):
5454
self.quant_backend = quant_backend
5555
# Initialize kwargs to be {} to set to the defaults.
5656
self.quant_kwargs = quant_kwargs or {}
57+
if components_to_quantize:
58+
if isinstance(components_to_quantize, str):
59+
components_to_quantize = [components_to_quantize]
5760
self.components_to_quantize = components_to_quantize
5861
self.quant_mapping = quant_mapping
5962
self.config_mapping = {} # book-keeping Example: `{module_name: quant_config}`

tests/quantization/test_pipeline_level_quantization.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,3 +298,19 @@ def _parse_config_string(self, config_string: str) -> tuple[str, dict]:
298298
data = json.loads(json_part)
299299

300300
return data
301+
302+
def test_single_component_to_quantize(self):
303+
component_to_quantize = "transformer"
304+
quant_config = PipelineQuantizationConfig(
305+
quant_backend="bitsandbytes_8bit",
306+
quant_kwargs={"load_in_8bit": True},
307+
components_to_quantize=component_to_quantize,
308+
)
309+
pipe = DiffusionPipeline.from_pretrained(
310+
self.model_name,
311+
quantization_config=quant_config,
312+
torch_dtype=torch.bfloat16,
313+
)
314+
for name, component in pipe.components.items():
315+
if name == component_to_quantize:
316+
self.assertTrue(hasattr(component.config, "quantization_config"))

0 commit comments

Comments
 (0)