
Commit 4d3dede

sayakpaul and SunMarc committed
feat: implement pipeline-level quantization config
Co-authored-by: SunMarc <[email protected]>
1 parent c94d85a commit 4d3dede

File tree

3 files changed: +177 −0 lines changed
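
In short: `DiffusionPipeline.from_pretrained` gains a `quantization_config` kwarg that is threaded into `load_sub_model`, and the new `PipelineQuantizationConfig` class resolves a quantization config per pipeline component, either from a single global backend or from a per-module mapping. A minimal usage sketch of the global-backend path (the model id, dtype, and module names are illustrative assumptions, not part of this commit):

import torch
from diffusers import DiffusionPipeline
from diffusers.quantizers import PipelineQuantizationConfig

# One backend + kwargs applied to the listed modules; everything else loads unquantized.
pipeline_quant_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_4bit",
    quant_kwargs={"load_in_4bit": True, "bnb_4bit_compute_dtype": torch.bfloat16},
    modules_to_quantize=["transformer", "text_encoder_2"],
)

pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # illustrative model id
    quantization_config=pipeline_quant_config,
    torch_dtype=torch.bfloat16,
)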

src/diffusers/pipelines/pipeline_loading_utils.py

Lines changed: 13 additions & 0 deletions
@@ -667,8 +667,10 @@ def load_sub_model(
     use_safetensors: bool,
     dduf_entries: Optional[Dict[str, DDUFEntry]],
     provider_options: Any,
+    quantization_config: Optional[Any] = None,
 ):
     """Helper method to load the module `name` from `library_name` and `class_name`"""
+    from ..quantizers import PipelineQuantizationConfig

     # retrieve class candidates

@@ -761,6 +763,17 @@ def load_sub_model(
     else:
         loading_kwargs["low_cpu_mem_usage"] = False

+    if (
+        quantization_config is not None
+        and isinstance(quantization_config, PipelineQuantizationConfig)
+        and issubclass(class_obj, torch.nn.Module)
+    ):
+        model_quant_config = quantization_config._resolve_quant_config(
+            is_diffusers=is_diffusers_model, module_name=name
+        )
+        if model_quant_config is not None:
+            loading_kwargs["quantization_config"] = model_quant_config
+
     # check if the module is in a subdirectory
     if dduf_entries:
         loading_kwargs["dduf_entries"] = dduf_entries
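
The guard above only attaches a `quantization_config` to components that are `torch.nn.Module` subclasses, and only when `_resolve_quant_config` returns a config for the component's name. A small sketch of that resolution, assuming the `bitsandbytes_4bit` backend and its dependencies are installed (module names are illustrative):

from diffusers.quantizers import PipelineQuantizationConfig

config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_4bit",
    quant_kwargs={"load_in_4bit": True},
    modules_to_quantize=["transformer"],
)

# Listed module: resolves to a diffusers BitsAndBytesConfig built from `quant_kwargs`.
print(config._resolve_quant_config(is_diffusers=True, module_name="transformer"))
# Unlisted module: resolves to None, so `load_sub_model` leaves it unquantized.
print(config._resolve_quant_config(is_diffusers=False, module_name="text_encoder"))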

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 2 additions & 0 deletions
@@ -702,6 +702,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
     use_safetensors = kwargs.pop("use_safetensors", None)
     use_onnx = kwargs.pop("use_onnx", None)
     load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
+    quantization_config = kwargs.pop("quantization_config", None)

     if not isinstance(torch_dtype, torch.dtype):
         torch_dtype = torch.float32

@@ -973,6 +974,7 @@ def load_module(name, value):
                 use_safetensors=use_safetensors,
                 dduf_entries=dduf_entries,
                 provider_options=provider_options,
+                quantization_config=quantization_config,
             )
             logger.info(
                 f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."

src/diffusers/quantizers/__init__.py

Lines changed: 162 additions & 0 deletions
@@ -12,5 +12,167 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import inspect
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+
+from ..utils import is_transformers_available, logging
 from .auto import DiffusersAutoQuantizer
 from .base import DiffusersQuantizer
+
+
+if TYPE_CHECKING:
+    from .quantization_config import QuantizationConfigMixin as DiffQuantConfigMixin
+
+try:
+    from transformers.utils.quantization_config import QuantizationConfigMixin as TransformersQuantConfigMixin
+except ImportError:
+
+    class TransformersQuantConfigMixin:
+        pass
+
+
+logger = logging.get_logger(__name__)
+
+
+class PipelineQuantizationConfig:
+    """Configures quantization for the individual models of a pipeline, either through a single
+    global `quant_backend` (optionally restricted to `modules_to_quantize`) or through a granular,
+    per-module `quant_mapping`."""
+
+    def __init__(
+        self,
+        quant_backend: Optional[str] = None,
+        quant_kwargs: Optional[Dict[str, Union[str, float, int]]] = None,
+        modules_to_quantize: Optional[List[str]] = None,
+        quant_mapping: Optional[Dict[str, Union["DiffQuantConfigMixin", "TransformersQuantConfigMixin"]]] = None,
+    ):
+        self.quant_backend = quant_backend
+        # Initialize kwargs to be {} to set to the defaults.
+        self.quant_kwargs = quant_kwargs or {}
+        self.modules_to_quantize = modules_to_quantize
+        self.quant_mapping = quant_mapping
+
+        self.post_init()
+
+    def post_init(self):
+        self.is_granular = self.quant_mapping is not None
+
+        self._validate_init_args()
+
+    def _validate_init_args(self):
+        if self.quant_backend and self.quant_mapping:
+            raise ValueError("Both `quant_backend` and `quant_mapping` cannot be specified at the same time.")
+
+        if not self.quant_mapping and not self.quant_backend:
+            raise ValueError("Must provide a `quant_backend` when not providing a `quant_mapping`.")
+
+        if not self.quant_kwargs and not self.quant_mapping:
+            raise ValueError("Both `quant_kwargs` and `quant_mapping` cannot be None.")
+
+        if self.quant_backend is not None:
+            self._validate_init_kwargs_in_backends()
+
+        if self.quant_mapping is not None:
+            self._validate_quant_mapping_args()
+
+    def _validate_init_kwargs_in_backends(self):
+        quant_backend = self.quant_backend
+
+        self._check_backend_availability(quant_backend)
+
+        quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()
+
+        if quant_config_mapping_transformers is not None:
+            init_kwargs_transformers = inspect.signature(quant_config_mapping_transformers[quant_backend].__init__)
+            init_kwargs_transformers = {name for name in init_kwargs_transformers.parameters if name != "self"}
+        else:
+            init_kwargs_transformers = None
+
+        init_kwargs_diffusers = inspect.signature(quant_config_mapping_diffusers[quant_backend].__init__)
+        init_kwargs_diffusers = {name for name in init_kwargs_diffusers.parameters if name != "self"}
+
+        # Only compare when transformers is available; otherwise there is nothing to match against.
+        if init_kwargs_transformers is not None and init_kwargs_transformers != init_kwargs_diffusers:
+            raise ValueError(
+                "The signatures of the __init__ methods of the quantization config classes in `diffusers` and `transformers` don't match. "
+                f"Please provide a `quant_mapping` instead, in the {self.__class__.__name__} class."
+            )
+
+    def _validate_quant_mapping_args(self):
+        quant_mapping = self.quant_mapping
+        quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()
+
+        available_configs_transformers = (
+            list(quant_config_mapping_transformers.values()) if quant_config_mapping_transformers else None
+        )
+        available_configs_diffusers = list(quant_config_mapping_diffusers.values())
+
+        # A config is valid if it belongs to either library's known config classes.
+        for module_name, config in quant_mapping.items():
+            if any(isinstance(config, available) for available in available_configs_diffusers):
+                continue
+            if available_configs_transformers is not None and any(
+                isinstance(config, available) for available in available_configs_transformers
+            ):
+                continue
+            raise ValueError(
+                f"Provided config for {module_name=} could not be found. "
+                f"Available ones for `diffusers` are: {available_configs_diffusers}. "
+                f"Available ones for `transformers` are: {available_configs_transformers}."
+            )
+
+    def _check_backend_availability(self, quant_backend: str):
+        quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()
+
+        available_backends_transformers = (
+            list(quant_config_mapping_transformers.keys()) if quant_config_mapping_transformers else None
+        )
+        available_backends_diffusers = list(quant_config_mapping_diffusers.keys())
+
+        if (
+            available_backends_transformers and quant_backend not in available_backends_transformers
+        ) or quant_backend not in quant_config_mapping_diffusers:
+            error_message = f"Provided quant_backend={quant_backend} was not found."
+            if available_backends_transformers:
+                error_message += f"\nAvailable ones (transformers): {available_backends_transformers}."
+            error_message += f"\nAvailable ones (diffusers): {available_backends_diffusers}."
+            raise ValueError(error_message)
+
+    def _resolve_quant_config(self, is_diffusers: bool = True, module_name: Optional[str] = None):
+        quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()
+
+        quant_mapping = self.quant_mapping
+        modules_to_quantize = self.modules_to_quantize
+
+        # Granular case: the user provided a per-module config.
+        if self.is_granular and module_name in quant_mapping:
+            logger.debug(f"Initializing quantization config class for {module_name}.")
+            config = quant_mapping[module_name]
+            return config
+
+        # Global config case.
+        else:
+            should_quantize = False
+            # Only quantize the modules requested for.
+            if modules_to_quantize and module_name in modules_to_quantize:
+                should_quantize = True
+            # No specification for `modules_to_quantize` means all modules should be quantized.
+            elif not self.is_granular and not modules_to_quantize:
+                should_quantize = True
+
+            if should_quantize:
+                logger.debug(f"Initializing quantization config class for {module_name}.")
+                mapping_to_use = quant_config_mapping_diffusers if is_diffusers else quant_config_mapping_transformers
+                quant_config_cls = mapping_to_use[self.quant_backend]
+                quant_kwargs = self.quant_kwargs
+                return quant_config_cls(**quant_kwargs)
+
+        # Fallback: no applicable configuration found.
+        return None
+
+    def _get_quant_config_list(self):
+        if is_transformers_available():
+            from transformers.quantizers.auto import (
+                AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_transformers,
+            )
+        else:
+            quant_config_mapping_transformers = None
+
+        from ..quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_diffusers
+
+        return quant_config_mapping_transformers, quant_config_mapping_diffusers
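
A quick sketch of the validation rules encoded in `_validate_init_args` (the config values are only illustrative):

from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from diffusers.quantizers import PipelineQuantizationConfig

# `quant_backend` and `quant_mapping` are mutually exclusive -> ValueError.
try:
    PipelineQuantizationConfig(
        quant_backend="bitsandbytes_4bit",
        quant_kwargs={"load_in_4bit": True},
        quant_mapping={"transformer": DiffusersBitsAndBytesConfig(load_in_4bit=True)},
    )
except ValueError as e:
    print(e)

# A global backend without `quant_kwargs` is also rejected.
try:
    PipelineQuantizationConfig(quant_backend="bitsandbytes_4bit")
except ValueError as e:
    print(e)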
