|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
| 15 | +import inspect |
| 16 | +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union |
| 17 | + |
| 18 | +from ..utils import is_transformers_available, logging |
15 | 19 | from .auto import DiffusersAutoQuantizer |
16 | 20 | from .base import DiffusersQuantizer |
| 21 | + |
| 22 | + |
| 23 | +if TYPE_CHECKING: |
| 24 | + from .quantization_config import QuantizationConfigMixin as DiffQuantConfigMixin |
| 25 | + |
| 26 | + try: |
| 27 | + from transformers.utils.quantization_config import QuantizationConfigMixin as TransformersQuantConfigMixin |
| 28 | + except ImportError: |
| 29 | + |
| 30 | + class TransformersQuantConfigMixin: |
| 31 | + pass |
| 32 | + |
| 33 | + |
| 34 | +logger = logging.get_logger(__name__) |
| 35 | + |
| 36 | + |
| 37 | +class PipelineQuantizationConfig: |
| 38 | + """TODO""" |
| 39 | + |
| 40 | + def __init__( |
| 41 | + self, |
| 42 | + quant_backend: str = None, |
| 43 | + quant_kwargs: Dict[str, Union[str, float, int]] = None, |
| 44 | + modules_to_quantize: Optional[List[str]] = None, |
| 45 | + quant_mapping: Dict[str, Union[DiffQuantConfigMixin, "TransformersQuantConfigMixin"]] = None, |
| 46 | + ): |
| 47 | + self.quant_backend = quant_backend |
| 48 | + # Initialize kwargs to be {} to set to the defaults. |
| 49 | + self.quant_kwargs = quant_kwargs or {} |
| 50 | + self.modules_to_quantize = modules_to_quantize |
| 51 | + self.quant_mapping = quant_mapping |
| 52 | + |
| 53 | + self.post_init() |
| 54 | + |
| 55 | + def post_init(self): |
| 56 | + quant_mapping = self.quant_mapping |
| 57 | + self.is_granular = True if quant_mapping is not None else False |
| 58 | + |
| 59 | + self._validate_init_args() |
| 60 | + |
| 61 | + def _validate_init_args(self): |
| 62 | + if self.quant_backend and self.quant_mapping: |
| 63 | + raise ValueError("Both `quant_backend` and `quant_mapping` cannot be specified at the same time.") |
| 64 | + |
| 65 | + if not self.quant_mapping and not self.quant_backend: |
| 66 | + raise ValueError("Must provide a `quant_backend` when not providing a `quant_mapping`.") |
| 67 | + |
| 68 | + if not self.quant_kwargs and not self.quant_mapping: |
| 69 | + raise ValueError("Both `quant_kwargs` and `quant_mapping` cannot be None.") |
| 70 | + |
| 71 | + if self.quant_backend is not None: |
| 72 | + self._validate_init_kwargs_in_backends() |
| 73 | + |
| 74 | + if self.quant_mapping is not None: |
| 75 | + self._validate_quant_mapping_args() |
| 76 | + |
| 77 | + def _validate_init_kwargs_in_backends(self): |
| 78 | + quant_backend = self.quant_backend |
| 79 | + |
| 80 | + self._check_backend_availability(quant_backend) |
| 81 | + |
| 82 | + quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list() |
| 83 | + |
| 84 | + if quant_config_mapping_transformers is not None: |
| 85 | + init_kwargs_transformers = inspect.signature(quant_config_mapping_transformers[quant_backend].__init__) |
| 86 | + init_kwargs_transformers = {name for name in init_kwargs_transformers.parameters if name != "self"} |
| 87 | + else: |
| 88 | + init_kwargs_transformers = None |
| 89 | + |
| 90 | + init_kwargs_diffusers = inspect.signature(quant_config_mapping_diffusers[quant_backend].__init__) |
| 91 | + init_kwargs_diffusers = {name for name in init_kwargs_diffusers.parameters if name != "self"} |
| 92 | + |
| 93 | + if init_kwargs_transformers != init_kwargs_diffusers: |
| 94 | + raise ValueError( |
| 95 | + "The signatures of the __init__ methods of the quantization config classes in `diffusers` and `transformers` don't match. " |
| 96 | + f"Please provide a `quant_mapping` instead, in the {self.__class__.__name__} class." |
| 97 | + ) |
| 98 | + |
| 99 | + def _validate_quant_mapping_args(self): |
| 100 | + quant_mapping = self.quant_mapping |
| 101 | + quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list() |
| 102 | + |
| 103 | + available_configs_transformers = ( |
| 104 | + list(quant_config_mapping_transformers.values()) if quant_config_mapping_transformers else None |
| 105 | + ) |
| 106 | + available_configs_diffusers = list(quant_config_mapping_diffusers.values()) |
| 107 | + |
| 108 | + for module_name, config in quant_mapping.items(): |
| 109 | + msg = "" |
| 110 | + if not (any(isinstance(config, available) for available in available_configs_diffusers)): |
| 111 | + msg = f"Provided config for {module_name=} could not be found. Available ones for `diffusers` are: {available_configs_diffusers}.)" |
| 112 | + elif available_configs_transformers is not None and not ( |
| 113 | + any(isinstance(config, available) for available in available_configs_transformers) |
| 114 | + ): |
| 115 | + msg = f"Provided config for {module_name=} could not be found. Available ones for `transformers` are: {available_configs_transformers}.)" |
| 116 | + if msg: |
| 117 | + raise ValueError(msg) |
| 118 | + |
| 119 | + def _check_backend_availability(self, quant_backend: str): |
| 120 | + quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list() |
| 121 | + |
| 122 | + available_backends_transformers = ( |
| 123 | + list(quant_config_mapping_transformers.keys()) if quant_config_mapping_transformers else None |
| 124 | + ) |
| 125 | + available_backends_diffusers = list(quant_config_mapping_diffusers.keys()) |
| 126 | + |
| 127 | + if ( |
| 128 | + available_backends_transformers and quant_backend not in available_backends_transformers |
| 129 | + ) or quant_backend not in quant_config_mapping_diffusers: |
| 130 | + error_message = f"Provided quant_backend={quant_backend} was not found." |
| 131 | + if available_backends_transformers: |
| 132 | + error_message += f"\nAvailable ones (transformers): {available_backends_transformers}." |
| 133 | + error_message += f"\nAvailable ones (diffusers): {available_backends_diffusers}." |
| 134 | + raise ValueError(error_message) |
| 135 | + |
| 136 | + def _resolve_quant_config(self, is_diffusers: bool = True, module_name: str = None): |
| 137 | + quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list() |
| 138 | + |
| 139 | + quant_mapping = self.quant_mapping |
| 140 | + modules_to_quantize = self.modules_to_quantize |
| 141 | + |
| 142 | + # Granular case |
| 143 | + if self.is_granular and module_name in quant_mapping: |
| 144 | + logger.debug(f"Initializing quantization config class for {module_name}.") |
| 145 | + config = quant_mapping[module_name] |
| 146 | + return config |
| 147 | + |
| 148 | + # Global config case |
| 149 | + else: |
| 150 | + should_quantize = False |
| 151 | + # Only quantize the modules requested for. |
| 152 | + if modules_to_quantize and module_name in modules_to_quantize: |
| 153 | + should_quantize = True |
| 154 | + # No specification for `modules_to_quantize` means all modules should be quantized. |
| 155 | + elif not self.is_granular and not modules_to_quantize: |
| 156 | + should_quantize = True |
| 157 | + |
| 158 | + if should_quantize: |
| 159 | + logger.debug(f"Initializing quantization config class for {module_name}.") |
| 160 | + mapping_to_use = quant_config_mapping_diffusers if is_diffusers else quant_config_mapping_transformers |
| 161 | + quant_config_cls = mapping_to_use[self.quant_backend] |
| 162 | + quant_kwargs = self.quant_kwargs |
| 163 | + return quant_config_cls(**quant_kwargs) |
| 164 | + |
| 165 | + # Fallback: no applicable configuration found. |
| 166 | + return None |
| 167 | + |
| 168 | + def _get_quant_config_list(self): |
| 169 | + if is_transformers_available(): |
| 170 | + from transformers.quantizers.auto import ( |
| 171 | + AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_transformers, |
| 172 | + ) |
| 173 | + else: |
| 174 | + quant_config_mapping_transformers = None |
| 175 | + |
| 176 | + from ..quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_diffusers |
| 177 | + |
| 178 | + return quant_config_mapping_transformers, quant_config_mapping_diffusers |
0 commit comments