
Commit eb7ef26

sayakpaul and stevhliu authored
[quant] allow components_to_quantize to be a non-list for single components (huggingface#12234)
* allow non list components_to_quantize.
* up
* Apply suggestions from code review
* Apply suggestions from code review

Co-authored-by: Steven Liu <[email protected]>

* [docs] components_to_quantize (huggingface#12287) init

Co-authored-by: Sayak Paul <[email protected]>

---------

Co-authored-by: Steven Liu <[email protected]>
1 parent e1b7f1f commit eb7ef26

File tree: 6 files changed (+29, -7 lines)


docs/source/en/api/pipelines/cogvideox.md (1 addition & 1 deletion)

```diff
@@ -50,7 +50,7 @@ from diffusers.utils import export_to_video
 pipeline_quant_config = PipelineQuantizationConfig(
     quant_backend="torchao",
     quant_kwargs={"quant_type": "int8wo"},
-    components_to_quantize=["transformer"]
+    components_to_quantize="transformer"
 )
 
 # fp8 layerwise weight-casting
```

docs/source/en/api/pipelines/hunyuan_video.md (3 additions & 3 deletions)

```diff
@@ -54,7 +54,7 @@ pipeline_quant_config = PipelineQuantizationConfig(
         "bnb_4bit_quant_type": "nf4",
         "bnb_4bit_compute_dtype": torch.bfloat16
     },
-    components_to_quantize=["transformer"]
+    components_to_quantize="transformer"
 )
 
 pipeline = HunyuanVideoPipeline.from_pretrained(
@@ -91,7 +91,7 @@ pipeline_quant_config = PipelineQuantizationConfig(
         "bnb_4bit_quant_type": "nf4",
         "bnb_4bit_compute_dtype": torch.bfloat16
     },
-    components_to_quantize=["transformer"]
+    components_to_quantize="transformer"
 )
 
 pipeline = HunyuanVideoPipeline.from_pretrained(
@@ -139,7 +139,7 @@ export_to_video(video, "output.mp4", fps=15)
         "bnb_4bit_quant_type": "nf4",
         "bnb_4bit_compute_dtype": torch.bfloat16
     },
-    components_to_quantize=["transformer"]
+    components_to_quantize="transformer"
 )
 
 pipeline = HunyuanVideoPipeline.from_pretrained(
```

docs/source/en/quantization/overview.md (4 additions & 1 deletion)

````diff
@@ -34,7 +34,9 @@ Initialize [`~quantizers.PipelineQuantizationConfig`] with the following paramet
 > [!TIP]
 > These `quant_kwargs` arguments are different for each backend. Refer to the [Quantization API](../api/quantization) docs to view the arguments for each backend.
 
-- `components_to_quantize` specifies which components of the pipeline to quantize. Typically, you should quantize the most compute intensive components like the transformer. The text encoder is another component to consider quantizing if a pipeline has more than one such as [`FluxPipeline`]. The example below quantizes the T5 text encoder in [`FluxPipeline`] while keeping the CLIP model intact.
+- `components_to_quantize` specifies which component(s) of the pipeline to quantize. Typically, you should quantize the most compute intensive components like the transformer. The text encoder is another component to consider quantizing if a pipeline has more than one such as [`FluxPipeline`]. The example below quantizes the T5 text encoder in [`FluxPipeline`] while keeping the CLIP model intact.
+
+  `components_to_quantize` accepts either a list for multiple models or a string for a single model.
 
 The example below loads the bitsandbytes backend with the following arguments from [`~quantizers.quantization_config.BitsAndBytesConfig`], `load_in_4bit`, `bnb_4bit_quant_type`, and `bnb_4bit_compute_dtype`.
 
@@ -62,6 +64,7 @@ pipe = DiffusionPipeline.from_pretrained(
 image = pipe("photo of a cute dog").images[0]
 ```
 
+
 ### Advanced quantization
 
 The `quant_mapping` argument provides more options for how to quantize each individual component in a pipeline, like combining different quantization backends.
````
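With this change, both spellings construct equivalent configs. Below is a minimal sketch based on the bitsandbytes example from the docs above; it assumes a diffusers install that includes this patch and that `PipelineQuantizationConfig` is importable from `diffusers.quantizers` (a config-construction fragment, not a full pipeline run):

```python
import torch
from diffusers.quantizers import PipelineQuantizationConfig

bnb_kwargs = {
    "load_in_4bit": True,
    "bnb_4bit_quant_type": "nf4",
    "bnb_4bit_compute_dtype": torch.bfloat16,
}

# A list, as before (still required when quantizing several components):
list_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_4bit",
    quant_kwargs=bnb_kwargs,
    components_to_quantize=["transformer"],
)

# A bare string, newly accepted for a single component:
str_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_4bit",
    quant_kwargs=bnb_kwargs,
    components_to_quantize="transformer",
)

# Both should normalize to the same internal list of component names.
assert list_config.components_to_quantize == str_config.components_to_quantize
```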

docs/source/en/using-diffusers/text-img2vid.md (1 addition & 1 deletion)

```diff
@@ -98,7 +98,7 @@ pipeline_quant_config = PipelineQuantizationConfig(
         "bnb_4bit_quant_type": "nf4",
         "bnb_4bit_compute_dtype": torch.bfloat16
     },
-    components_to_quantize=["transformer"]
+    components_to_quantize="transformer"
 )
 
 pipeline = HunyuanVideoPipeline.from_pretrained(
```

src/diffusers/quantizers/pipe_quant_config.py (4 additions & 1 deletion)

```diff
@@ -48,12 +48,15 @@ def __init__(
         self,
         quant_backend: str = None,
         quant_kwargs: Dict[str, Union[str, float, int, dict]] = None,
-        components_to_quantize: Optional[List[str]] = None,
+        components_to_quantize: Optional[Union[List[str], str]] = None,
         quant_mapping: Dict[str, Union[DiffQuantConfigMixin, "TransformersQuantConfigMixin"]] = None,
     ):
         self.quant_backend = quant_backend
         # Initialize kwargs to be {} to set to the defaults.
         self.quant_kwargs = quant_kwargs or {}
+        if components_to_quantize:
+            if isinstance(components_to_quantize, str):
+                components_to_quantize = [components_to_quantize]
         self.components_to_quantize = components_to_quantize
         self.quant_mapping = quant_mapping
         self.config_mapping = {}  # book-keeping Example: `{module_name: quant_config}`
```
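The normalization added in `__init__` can be sketched as a standalone helper. This is a minimal, self-contained sketch mirroring the diff above, not diffusers' actual code; the function name is hypothetical:

```python
from typing import List, Optional, Union


def normalize_components_to_quantize(
    components_to_quantize: Optional[Union[List[str], str]] = None,
) -> Optional[Union[List[str], str]]:
    """Wrap a bare component name in a list, as the __init__ change above does."""
    # A single string like "transformer" becomes ["transformer"] so the rest
    # of the config code can always iterate over a list of component names.
    # Falsy values (None, empty list) pass through unchanged.
    if components_to_quantize and isinstance(components_to_quantize, str):
        return [components_to_quantize]
    return components_to_quantize


print(normalize_components_to_quantize("transformer"))
print(normalize_components_to_quantize(["transformer", "text_encoder_2"]))
```

The check is deliberately order-sensitive: strings are wrapped before assignment, so downstream code such as `DiffusionPipeline.from_pretrained` never has to distinguish the two input forms.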

tests/quantization/test_pipeline_level_quantization.py (16 additions & 0 deletions)

```diff
@@ -299,3 +299,19 @@ def _parse_config_string(self, config_string: str) -> tuple[str, dict]:
         data = json.loads(json_part)
 
         return data
+
+    def test_single_component_to_quantize(self):
+        component_to_quantize = "transformer"
+        quant_config = PipelineQuantizationConfig(
+            quant_backend="bitsandbytes_8bit",
+            quant_kwargs={"load_in_8bit": True},
+            components_to_quantize=component_to_quantize,
+        )
+        pipe = DiffusionPipeline.from_pretrained(
+            self.model_name,
+            quantization_config=quant_config,
+            torch_dtype=torch.bfloat16,
+        )
+        for name, component in pipe.components.items():
+            if name == component_to_quantize:
+                self.assertTrue(hasattr(component.config, "quantization_config"))
```
