|
7 | 7 | from typing import Dict |
8 | 8 | from typing import Optional |
9 | 9 | from typing import Type |
10 | | -from typing import TypeVar |
| 10 | +from typing import TypeAlias |
11 | 11 |
|
12 | 12 | from torch import nn |
13 | 13 |
|
|
52 | 52 |
|
53 | 53 | from .quant_blocks import * |
54 | 54 |
|
# Alias of ExtendedInjector used in the BaseQuantizer type hints. Without it,
# Pylance raises "Variable not allowed in type expression" on every such hint.
QuantInjector: TypeAlias = ExtendedInjector  # type: ignore
| 57 | + |
| 58 | + |
class BaseQuantizer:
    """Base class describing a set of optional quantizer overrides.

    Every class attribute is a named quantizer slot that defaults to ``None``.
    A subclass assigns a quantizer to the slots it wants to override; slots
    left as ``None`` are ignored by :meth:`override_quantizers_dict`.
    """

    weight_quant: ClassVar[Optional[QuantInjector]] = None
    linear_input_quant: ClassVar[Optional[QuantInjector]] = None
    input_quant: ClassVar[Optional[QuantInjector]] = None
    q_scaled_quant: ClassVar[Optional[QuantInjector]] = None
    k_transposed_quant: ClassVar[Optional[QuantInjector]] = None
    v_quant: ClassVar[Optional[QuantInjector]] = None
    attn_output_weights_quant: ClassVar[Optional[QuantInjector]] = None

    @classmethod
    def override_quantizers_dict(
        cls: Type["BaseQuantizer"],
        quantizers_dict: Dict[str, Optional[QuantInjector]],
    ) -> Dict[str, Optional[QuantInjector]]:
        """Override entries of ``quantizers_dict`` with the quantizers set on ``cls``.

        For each key of ``quantizers_dict``, the class attribute with the same
        name is looked up; when it exists and is not ``None`` it replaces the
        dictionary entry. All other entries are left unchanged.

        Args:
            quantizers_dict: Mapping from quantizer name to quantizer (or
                ``None``). It is modified in place.

        Returns:
            The same ``quantizers_dict`` object, after the overrides.
        """
        for key in quantizers_dict:
            # Default to None so that keys without a matching class attribute
            # are skipped instead of raising AttributeError.
            if (value := getattr(cls, key, None)) is not None:
                quantizers_dict[key] = value
        return quantizers_dict
| 78 | + |
| 79 | + |
# Registry for custom quantizers; it is parameterized to hold BaseQuantizer
# subclasses (Type[BaseQuantizer]).
CUSTOM_QUANTIZERS_REGISTRY = Registry[Type[BaseQuantizer]](registry_name="CustomQuantizersRegistry")
| 82 | + |
55 | 83 |
|
56 | 84 | class DynamicActProxyMixin(ExtendedInjector): |
57 | 85 | proxy_class = DynamicActQuantProxyFromInjector |
@@ -226,31 +254,3 @@ class FP8e4m3FNUZDynamicActPerRowFloat(Fp8e4m3FNUZActPerTensorFloat): |
226 | 254 |
|
# Combines MSESymmetricScale with Fp8e4m3WeightPerChannelFloat through
# inheritance only; the class adds no members of its own.
class Fp8e4m3WeightPerChannelFloatMSE(MSESymmetricScale, Fp8e4m3WeightPerChannelFloat):
    pass
229 | | - |
230 | | - |
# TODO: Subject to change
class BaseQuantizer:
    """Collection of optional quantizer slots, all unset (``None``) by default."""

    weight_quant: ClassVar[Optional[ExtendedInjector]] = None  # type: ignore
    linear_input_quant: ClassVar[Optional[ExtendedInjector]] = None  # type: ignore
    input_quant: ClassVar[Optional[ExtendedInjector]] = None  # type: ignore
    q_scaled_quant: ClassVar[Optional[ExtendedInjector]] = None  # type: ignore
    k_transposed_quant: ClassVar[Optional[ExtendedInjector]] = None  # type: ignore
    v_quant: ClassVar[Optional[ExtendedInjector]] = None  # type: ignore
    attn_output_weights_quant: ClassVar[Optional[ExtendedInjector]] = None  # type: ignore

    @classmethod
    def override_quantizers_dict(
            cls: "BaseQuantizer",
            quantizers_dict: Dict[str, Optional[ExtendedInjector]]):  # type: ignore
        """Replace entries of ``quantizers_dict`` with the quantizers set on ``cls``.

        The dictionary is updated in place and also returned. Keys that have no
        matching class attribute, or whose attribute is ``None``, keep their
        current value.
        """
        for name in quantizers_dict:
            # A missing attribute and an attribute set to None are both skipped.
            override = getattr(cls, name, None)
            if override is None:
                continue
            quantizers_dict[name] = override
        return quantizers_dict
249 | | - |
250 | | - |
# Registry parameterized to hold BaseQuantizer subclasses (Type[BaseQuantizer]).
CUSTOM_QUANTIZERS_REGISTRY = Registry[Type[BaseQuantizer]](registry_name="CustomQuantizersRegistry")
252 | | - |
253 | | - |
# Example BaseQuantizer subclass, registered in CUSTOM_QUANTIZERS_REGISTRY under
# the name "custom_quant". It sets only weight_quant (to Int8WeightPerTensorFloat).
@CUSTOM_QUANTIZERS_REGISTRY.register("custom_quant")
class CustomQuantizerExample(BaseQuantizer):
    weight_quant: ClassVar[Optional[ExtendedInjector]] = Int8WeightPerTensorFloat  # type: ignore
0 commit comments