|  | 
| 12 | 12 | # See the License for the specific language governing permissions and | 
| 13 | 13 | # limitations under the License. | 
| 14 | 14 | 
 | 
| 15 |  | -import inspect | 
| 16 |  | -from typing import Dict, List, Optional, Union | 
| 17 | 15 | 
 | 
| 18 |  | -from ..utils import is_transformers_available, logging | 
| 19 | 16 | from .auto import DiffusersAutoQuantizer | 
| 20 | 17 | from .base import DiffusersQuantizer | 
| 21 |  | -from .quantization_config import QuantizationConfigMixin as DiffQuantConfigMixin | 
| 22 |  | - | 
| 23 |  | - | 
| 24 |  | -try: | 
| 25 |  | -    from transformers.utils.quantization_config import QuantizationConfigMixin as TransformersQuantConfigMixin | 
| 26 |  | -except ImportError: | 
| 27 |  | - | 
| 28 |  | -    class TransformersQuantConfigMixin: | 
| 29 |  | -        pass | 
| 30 |  | - | 
| 31 |  | - | 
| 32 |  | -logger = logging.get_logger(__name__) | 
| 33 |  | - | 
| 34 |  | - | 
| 35 |  | -class PipelineQuantizationConfig: | 
| 36 |  | -    """ | 
| 37 |  | -    Configuration class to be used when applying quantization on-the-fly to [`~DiffusionPipeline.from_pretrained`]. | 
| 38 |  | -
 | 
| 39 |  | -    Args: | 
| 40 |  | -        quant_backend (`str`): Quantization backend to be used. When using this option, we assume that the backend | 
| 41 |  | -            is available to both `diffusers` and `transformers`. | 
| 42 |  | -        quant_kwargs (`dict`): Params to initialize the quantization backend class. | 
| 43 |  | -        components_to_quantize (`list`): Components of a pipeline to be quantized. | 
| 44 |  | -        quant_mapping (`dict`): Mapping defining the quantization specs to be used for the pipeline | 
| 45 |  | -            components. When using this argument, users are not expected to provide `quant_backend`, `quant_kawargs`, | 
| 46 |  | -            and `components_to_quantize`. | 
| 47 |  | -    """ | 
| 48 |  | - | 
| 49 |  | -    def __init__( | 
| 50 |  | -        self, | 
| 51 |  | -        quant_backend: str = None, | 
| 52 |  | -        quant_kwargs: Dict[str, Union[str, float, int, dict]] = None, | 
| 53 |  | -        components_to_quantize: Optional[List[str]] = None, | 
| 54 |  | -        quant_mapping: Dict[str, Union[DiffQuantConfigMixin, "TransformersQuantConfigMixin"]] = None, | 
| 55 |  | -    ): | 
| 56 |  | -        self.quant_backend = quant_backend | 
| 57 |  | -        # Initialize kwargs to be {} to set to the defaults. | 
| 58 |  | -        self.quant_kwargs = quant_kwargs or {} | 
| 59 |  | -        self.components_to_quantize = components_to_quantize | 
| 60 |  | -        self.quant_mapping = quant_mapping | 
| 61 |  | - | 
| 62 |  | -        self.post_init() | 
| 63 |  | - | 
| 64 |  | -    def post_init(self): | 
| 65 |  | -        quant_mapping = self.quant_mapping | 
| 66 |  | -        self.is_granular = True if quant_mapping is not None else False | 
| 67 |  | - | 
| 68 |  | -        self._validate_init_args() | 
| 69 |  | - | 
| 70 |  | -    def _validate_init_args(self): | 
| 71 |  | -        if self.quant_backend and self.quant_mapping: | 
| 72 |  | -            raise ValueError("Both `quant_backend` and `quant_mapping` cannot be specified at the same time.") | 
| 73 |  | - | 
| 74 |  | -        if not self.quant_mapping and not self.quant_backend: | 
| 75 |  | -            raise ValueError("Must provide a `quant_backend` when not providing a `quant_mapping`.") | 
| 76 |  | - | 
| 77 |  | -        if not self.quant_kwargs and not self.quant_mapping: | 
| 78 |  | -            raise ValueError("Both `quant_kwargs` and `quant_mapping` cannot be None.") | 
| 79 |  | - | 
| 80 |  | -        if self.quant_backend is not None: | 
| 81 |  | -            self._validate_init_kwargs_in_backends() | 
| 82 |  | - | 
| 83 |  | -        if self.quant_mapping is not None: | 
| 84 |  | -            self._validate_quant_mapping_args() | 
| 85 |  | - | 
| 86 |  | -    def _validate_init_kwargs_in_backends(self): | 
| 87 |  | -        quant_backend = self.quant_backend | 
| 88 |  | - | 
| 89 |  | -        self._check_backend_availability(quant_backend) | 
| 90 |  | - | 
| 91 |  | -        quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list() | 
| 92 |  | - | 
| 93 |  | -        if quant_config_mapping_transformers is not None: | 
| 94 |  | -            init_kwargs_transformers = inspect.signature(quant_config_mapping_transformers[quant_backend].__init__) | 
| 95 |  | -            init_kwargs_transformers = {name for name in init_kwargs_transformers.parameters if name != "self"} | 
| 96 |  | -        else: | 
| 97 |  | -            init_kwargs_transformers = None | 
| 98 |  | - | 
| 99 |  | -        init_kwargs_diffusers = inspect.signature(quant_config_mapping_diffusers[quant_backend].__init__) | 
| 100 |  | -        init_kwargs_diffusers = {name for name in init_kwargs_diffusers.parameters if name != "self"} | 
| 101 |  | - | 
| 102 |  | -        if init_kwargs_transformers != init_kwargs_diffusers: | 
| 103 |  | -            raise ValueError( | 
| 104 |  | -                "The signatures of the __init__ methods of the quantization config classes in `diffusers` and `transformers` don't match. " | 
| 105 |  | -                f"Please provide a `quant_mapping` instead, in the {self.__class__.__name__} class. Refer to [the docs](https://huggingface.co/docs/diffusers/main/en/quantization/overview#pipeline-level-quantization) to learn more about how " | 
| 106 |  | -                "this mapping would look like." | 
| 107 |  | -            ) | 
| 108 |  | - | 
| 109 |  | -    def _validate_quant_mapping_args(self): | 
| 110 |  | -        quant_mapping = self.quant_mapping | 
| 111 |  | -        transformers_map, diffusers_map = self._get_quant_config_list() | 
| 112 |  | - | 
| 113 |  | -        available_transformers = list(transformers_map.values()) if transformers_map else None | 
| 114 |  | -        available_diffusers = list(diffusers_map.values()) | 
| 115 |  | - | 
| 116 |  | -        for module_name, config in quant_mapping.items(): | 
| 117 |  | -            if any(isinstance(config, cfg) for cfg in available_diffusers): | 
| 118 |  | -                continue | 
| 119 |  | - | 
| 120 |  | -            if available_transformers and any(isinstance(config, cfg) for cfg in available_transformers): | 
| 121 |  | -                continue | 
| 122 |  | - | 
| 123 |  | -            if available_transformers: | 
| 124 |  | -                raise ValueError( | 
| 125 |  | -                    f"Provided config for module_name={module_name} could not be found. " | 
| 126 |  | -                    f"Available diffusers configs: {available_diffusers}; " | 
| 127 |  | -                    f"Available transformers configs: {available_transformers}." | 
| 128 |  | -                ) | 
| 129 |  | -            else: | 
| 130 |  | -                raise ValueError( | 
| 131 |  | -                    f"Provided config for module_name={module_name} could not be found. " | 
| 132 |  | -                    f"Available diffusers configs: {available_diffusers}." | 
| 133 |  | -                ) | 
| 134 |  | - | 
| 135 |  | -    def _check_backend_availability(self, quant_backend: str): | 
| 136 |  | -        quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list() | 
| 137 |  | - | 
| 138 |  | -        available_backends_transformers = ( | 
| 139 |  | -            list(quant_config_mapping_transformers.keys()) if quant_config_mapping_transformers else None | 
| 140 |  | -        ) | 
| 141 |  | -        available_backends_diffusers = list(quant_config_mapping_diffusers.keys()) | 
| 142 |  | - | 
| 143 |  | -        if ( | 
| 144 |  | -            available_backends_transformers and quant_backend not in available_backends_transformers | 
| 145 |  | -        ) or quant_backend not in quant_config_mapping_diffusers: | 
| 146 |  | -            error_message = f"Provided quant_backend={quant_backend} was not found." | 
| 147 |  | -            if available_backends_transformers: | 
| 148 |  | -                error_message += f"\nAvailable ones (transformers): {available_backends_transformers}." | 
| 149 |  | -            error_message += f"\nAvailable ones (diffusers): {available_backends_diffusers}." | 
| 150 |  | -            raise ValueError(error_message) | 
| 151 |  | - | 
| 152 |  | -    def _resolve_quant_config(self, is_diffusers: bool = True, module_name: str = None): | 
| 153 |  | -        quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list() | 
| 154 |  | - | 
| 155 |  | -        quant_mapping = self.quant_mapping | 
| 156 |  | -        components_to_quantize = self.components_to_quantize | 
| 157 |  | - | 
| 158 |  | -        # Granular case | 
| 159 |  | -        if self.is_granular and module_name in quant_mapping: | 
| 160 |  | -            logger.debug(f"Initializing quantization config class for {module_name}.") | 
| 161 |  | -            config = quant_mapping[module_name] | 
| 162 |  | -            return config | 
| 163 |  | - | 
| 164 |  | -        # Global config case | 
| 165 |  | -        else: | 
| 166 |  | -            should_quantize = False | 
| 167 |  | -            # Only quantize the modules requested for. | 
| 168 |  | -            if components_to_quantize and module_name in components_to_quantize: | 
| 169 |  | -                should_quantize = True | 
| 170 |  | -            # No specification for `components_to_quantize` means all modules should be quantized. | 
| 171 |  | -            elif not self.is_granular and not components_to_quantize: | 
| 172 |  | -                should_quantize = True | 
| 173 |  | - | 
| 174 |  | -            if should_quantize: | 
| 175 |  | -                logger.debug(f"Initializing quantization config class for {module_name}.") | 
| 176 |  | -                mapping_to_use = quant_config_mapping_diffusers if is_diffusers else quant_config_mapping_transformers | 
| 177 |  | -                quant_config_cls = mapping_to_use[self.quant_backend] | 
| 178 |  | -                quant_kwargs = self.quant_kwargs | 
| 179 |  | -                return quant_config_cls(**quant_kwargs) | 
| 180 |  | - | 
| 181 |  | -        # Fallback: no applicable configuration found. | 
| 182 |  | -        return None | 
| 183 |  | - | 
| 184 |  | -    def _get_quant_config_list(self): | 
| 185 |  | -        if is_transformers_available(): | 
| 186 |  | -            from transformers.quantizers.auto import ( | 
| 187 |  | -                AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_transformers, | 
| 188 |  | -            ) | 
| 189 |  | -        else: | 
| 190 |  | -            quant_config_mapping_transformers = None | 
| 191 |  | - | 
| 192 |  | -        from ..quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_diffusers | 
| 193 |  | - | 
| 194 |  | -        return quant_config_mapping_transformers, quant_config_mapping_diffusers | 
|  | 18 | +from .pipe_quant_config import PipelineQuantizationConfig | 
0 commit comments