4 changes: 3 additions & 1 deletion docs/source/en/quantization/overview.md
@@ -34,7 +34,7 @@ Initialize [`~quantizers.PipelineQuantizationConfig`] with the following paramet
> [!TIP]
> These `quant_kwargs` arguments are different for each backend. Refer to the [Quantization API](../api/quantization) docs to view the arguments for each backend.

- `components_to_quantize` specifies which components of the pipeline to quantize. Typically, you should quantize the most compute intensive components like the transformer. The text encoder is another component to consider quantizing if a pipeline has more than one such as [`FluxPipeline`]. The example below quantizes the T5 text encoder in [`FluxPipeline`] while keeping the CLIP model intact.
- `components_to_quantize` specifies which component(s) of the pipeline to quantize. Typically, you should quantize the most compute intensive components like the transformer. The text encoder is another component to consider quantizing if a pipeline has more than one such as [`FluxPipeline`]. The example below quantizes the T5 text encoder in [`FluxPipeline`] while keeping the CLIP model intact.

The example below loads the bitsandbytes backend with the following arguments from [`~quantizers.quantization_config.BitsAndBytesConfig`], `load_in_4bit`, `bnb_4bit_quant_type`, and `bnb_4bit_compute_dtype`.

@@ -62,6 +62,8 @@ pipe = DiffusionPipeline.from_pretrained(
image = pipe("photo of a cute dog").images[0]
```

`components_to_quantize` doesn't have to be a list. A single component name can also be passed as a plain string: `components_to_quantize="transformer"`.
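
Internally, a lone string is simply wrapped in a list before use. A minimal standalone sketch of that normalization (the helper name `normalize_components` is illustrative, not part of the diffusers API):

```python
from typing import List, Optional, Union

def normalize_components(components: Optional[Union[List[str], str]]) -> Optional[List[str]]:
    """Wrap a single component name in a list; pass lists (and None) through unchanged."""
    if isinstance(components, str):
        return [components]
    return components

print(normalize_components("transformer"))  # ['transformer']
print(normalize_components(["transformer", "text_encoder_2"]))
```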

### Advanced quantization

The `quant_mapping` argument provides more options for how to quantize each individual component in a pipeline, like combining different quantization backends.
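
A configuration sketch of that pattern, assuming the bitsandbytes backends are installed for both diffusers and transformers (adapt the component names to your pipeline):

```py
import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from diffusers.quantizers import PipelineQuantizationConfig
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig

pipeline_quant_config = PipelineQuantizationConfig(
    quant_mapping={
        # diffusers-side config for the transformer, transformers-side config for the T5 text encoder
        "transformer": DiffusersBitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16),
        "text_encoder_2": TransformersBitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16),
    }
)
```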
5 changes: 4 additions & 1 deletion src/diffusers/quantizers/pipe_quant_config.py
@@ -48,12 +48,15 @@ def __init__(
        self,
        quant_backend: str = None,
        quant_kwargs: Dict[str, Union[str, float, int, dict]] = None,
        components_to_quantize: Optional[List[str]] = None,
        components_to_quantize: Optional[Union[List[str], str]] = None,
        quant_mapping: Dict[str, Union[DiffQuantConfigMixin, "TransformersQuantConfigMixin"]] = None,
    ):
        self.quant_backend = quant_backend
        # Initialize kwargs to be {} to set to the defaults.
        self.quant_kwargs = quant_kwargs or {}
        if components_to_quantize:
            if isinstance(components_to_quantize, str):
                components_to_quantize = [components_to_quantize]
        self.components_to_quantize = components_to_quantize
        self.quant_mapping = quant_mapping
        self.config_mapping = {}  # book-keeping Example: `{module_name: quant_config}`
16 changes: 16 additions & 0 deletions tests/quantization/test_pipeline_level_quantization.py
@@ -299,3 +299,19 @@ def _parse_config_string(self, config_string: str) -> tuple[str, dict]:
        data = json.loads(json_part)

        return data

    def test_single_component_to_quantize(self):
        component_to_quantize = "transformer"
        quant_config = PipelineQuantizationConfig(
            quant_backend="bitsandbytes_8bit",
            quant_kwargs={"load_in_8bit": True},
            components_to_quantize=component_to_quantize,
        )
        pipe = DiffusionPipeline.from_pretrained(
            self.model_name,
            quantization_config=quant_config,
            torch_dtype=torch.bfloat16,
        )
        for name, component in pipe.components.items():
            if name == component_to_quantize:
                self.assertTrue(hasattr(component.config, "quantization_config"))