
Commit 1d46b21

feat: draft pipeline-level quant config.
1 parent 9a1810f

File tree: 3 files changed, +37 -0 lines

  src/diffusers/pipelines/pipeline_loading_utils.py
  src/diffusers/pipelines/pipeline_utils.py
  src/diffusers/quantizers/__init__.py

src/diffusers/pipelines/pipeline_loading_utils.py

Lines changed: 28 additions & 0 deletions
@@ -667,9 +667,12 @@ def load_sub_model(
     use_safetensors: bool,
     dduf_entries: Optional[Dict[str, DDUFEntry]],
     provider_options: Any,
+    quantization_config: Optional[Any] = None,
 ):
     """Helper method to load the module `name` from `library_name` and `class_name`"""
 
+    from ..quantizers import PipelineQuantizationConfig
+
     # retrieve class candidates
 
     class_obj, class_candidates = get_class_obj_and_candidates(
@@ -761,6 +764,12 @@ def load_sub_model(
     else:
         loading_kwargs["low_cpu_mem_usage"] = False
 
+    if quantization_config is not None and isinstance(quantization_config, PipelineQuantizationConfig):
+        exclude_modules = quantization_config.exclude_modules or []
+        if name not in exclude_modules:
+            model_quant_config = _resolve_quant_config(quantization_config, is_diffusers=is_diffusers_model)
+            loading_kwargs["quantization_config"] = model_quant_config
+
     # check if the module is in a subdirectory
     if dduf_entries:
         loading_kwargs["dduf_entries"] = dduf_entries
@@ -1070,3 +1079,22 @@ def _maybe_raise_error_for_incorrect_transformers(config_dict):
             break
     if has_transformers_component and not is_transformers_version(">", "4.47.1"):
         raise ValueError("Please upgrade your `transformers` installation to the latest version to use DDUF.")
+
+
+def _resolve_quant_config(quant_config, is_diffusers=True):
+    if is_diffusers:
+        from ..quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING
+    else:
+        from transformers.quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING
+
+    quant_backend = quant_config.quant_backend
+    if quant_backend not in AUTO_QUANTIZATION_CONFIG_MAPPING:
+        raise ValueError(
+            f"Provided {quant_backend=} was not found in the supported quantizers. Available ones are: {AUTO_QUANTIZATION_CONFIG_MAPPING.keys()}."
+        )
+
+    quant_config_cls = AUTO_QUANTIZATION_CONFIG_MAPPING[quant_backend]
+
+    quant_kwargs = quant_config.quant_kwargs
+    quant_config = quant_config_cls(**quant_kwargs)
+    return quant_config
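For context on how the new gating behaves at load time: each component name is checked against exclude_modules, and only non-excluded components get a resolved quantization_config added to their loading_kwargs. A minimal sketch of that check, using hypothetical component names (the actual names depend on the pipeline being loaded):

# Sketch only: component names here are hypothetical, not taken from the commit.
pipeline_components = ["transformer", "text_encoder", "vae"]
exclude_modules = ["vae"]

for name in pipeline_components:
    if name not in exclude_modules:
        # load_sub_model would call _resolve_quant_config(...) here and pass the
        # resulting backend config via loading_kwargs["quantization_config"].
        print(f"{name}: loaded with quantization")
    else:
        print(f"{name}: loaded in full precision")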

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 2 additions & 0 deletions
@@ -702,6 +702,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         use_safetensors = kwargs.pop("use_safetensors", None)
         use_onnx = kwargs.pop("use_onnx", None)
         load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
+        quantization_config = kwargs.pop("quantization_config", None)
 
         if not isinstance(torch_dtype, torch.dtype):
             torch_dtype = torch.float32
@@ -973,6 +974,7 @@ def load_module(name, value):
                 use_safetensors=use_safetensors,
                 dduf_entries=dduf_entries,
                 provider_options=provider_options,
+                quantization_config=quantization_config,
             )
             logger.info(
                 f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."

src/diffusers/quantizers/__init__.py

Lines changed: 7 additions & 0 deletions
@@ -14,3 +14,10 @@
 
 from .auto import DiffusersAutoQuantizer
 from .base import DiffusersQuantizer
+
+
+class PipelineQuantizationConfig:
+    def __init__(self, quant_backend, quant_kwargs, exclude_modules):
+        self.quant_backend = quant_backend
+        self.quant_kwargs = quant_kwargs
+        self.exclude_modules = exclude_modules
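As drafted, PipelineQuantizationConfig is a plain holder: all three arguments are required (no defaults), and nothing is validated at construction time; the backend name is only checked later by _resolve_quant_config. A small construction sketch, again with a hypothetical backend key:

from diffusers.quantizers import PipelineQuantizationConfig

cfg = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_4bit",    # hypothetical; validated only when a model is loaded
    quant_kwargs={"load_in_4bit": True},  # forwarded as **kwargs to the backend's config class
    exclude_modules=[],                   # must be passed explicitly even when nothing is excluded
)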
