|
60 | 60 |
|
61 | 61 | bias_qspec: Optional[QuantizationSpec] = None |
62 | 62 |
|
| 63 | +_default_qconfig = QuantizationConfig( |
| 64 | + act_qspec, |
| 65 | + act_qspec, |
| 66 | + wgt_qspec, |
| 67 | + None, |
| 68 | +) |
| 69 | + |
63 | 70 |
|
64 | 71 | class CadenceAtenQuantizer(Quantizer): |
65 | 72 | def __init__( |
@@ -140,31 +147,39 @@ def get_supported_operators(cls) -> List[OperatorConfig]: |
140 | 147 | return [] |
141 | 148 |
|
142 | 149 |
|
| 150 | +def get_cadence_default_quantizer_list_with_config( |
| 151 | + quantization_config: QuantizationConfig, |
| 152 | +) -> List[Quantizer]: |
| 153 | + return [ |
| 154 | + CadenceAtenQuantizer(AddmmPattern(), quantization_config), |
| 155 | + CadenceAtenQuantizer(BmmPattern(), quantization_config), |
| 156 | + CadenceAtenQuantizer(Conv1dPattern(), quantization_config), |
| 157 | + CadenceAtenQuantizer(Conv2dPattern(), quantization_config), |
| 158 | + CadenceAtenQuantizer(LayerNormPattern(), quantization_config), |
| 159 | + CadenceAtenQuantizer(LinearPattern(), quantization_config), |
| 160 | + CadenceAtenQuantizer(MatmulPattern(), quantization_config), |
| 161 | + CadenceAtenQuantizer(ReluPattern0(), quantization_config), |
| 162 | + CadenceAtenQuantizer(ReluPattern1(), quantization_config), |
| 163 | + ] |
| 164 | + |
| 165 | + |
143 | 166 | class CadenceQuantizer(ComposableQuantizer): |
144 | | - def __init__( |
145 | | - self, quantization_config: Optional[QuantizationConfig] = None |
146 | | - ) -> None: |
147 | | - static_qconfig = ( |
148 | | - QuantizationConfig( |
149 | | - act_qspec, |
150 | | - act_qspec, |
151 | | - wgt_qspec, |
152 | | - None, |
153 | | - ) |
154 | | - if not quantization_config |
155 | | - else quantization_config |
156 | | - ) |
| 167 | + """ |
| 168 | + Generic CadenceQuantizer. Although it can be used directly, it is typically a base |
| 169 | + class for explicitly defined quantizers (like CadenceDefaultQuantizer). |
| 170 | + """ |
157 | 171 |
|
158 | | - super().__init__( |
159 | | - [ |
160 | | - CadenceAtenQuantizer(AddmmPattern(), static_qconfig), |
161 | | - CadenceAtenQuantizer(BmmPattern(), static_qconfig), |
162 | | - CadenceAtenQuantizer(Conv1dPattern(), static_qconfig), |
163 | | - CadenceAtenQuantizer(Conv2dPattern(), static_qconfig), |
164 | | - CadenceAtenQuantizer(LayerNormPattern(), static_qconfig), |
165 | | - CadenceAtenQuantizer(LinearPattern(), static_qconfig), |
166 | | - CadenceAtenQuantizer(MatmulPattern(), static_qconfig), |
167 | | - CadenceAtenQuantizer(ReluPattern0(), static_qconfig), |
168 | | - CadenceAtenQuantizer(ReluPattern1(), static_qconfig), |
169 | | - ] |
170 | | - ) |
| 172 | + def __init__(self, quantizers: List[Quantizer]) -> None: |
| 173 | + super().__init__(quantizers) |
| 174 | + |
| 175 | + |
| 176 | +class CadenceDefaultQuantizer(CadenceQuantizer): |
| 177 | + """ |
| 178 | + Default quantizer for Cadence backend. |
| 179 | + """ |
| 180 | + |
| 181 | + def __init__(self, qconfig: Optional[QuantizationConfig] = None) -> None: |
| 182 | + if qconfig is None: |
| 183 | + qconfig = _default_qconfig |
| 184 | + quantizers = get_cadence_default_quantizer_list_with_config(qconfig) |
| 185 | + super().__init__(quantizers) |
0 commit comments