|
18 | 18 | except: |
19 | 19 | has_quantized_ops = False |
20 | 20 |
|
| 21 | +from executorch.backends.xnnpack.partition.config.xnnpack_config import ( |
| 22 | + ConfigPrecisionType, |
| 23 | +) |
| 24 | +from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner |
21 | 25 | from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( |
22 | 26 | get_symmetric_quantization_config, |
23 | 27 | ) |
|
26 | 30 | ) |
27 | 31 | from executorch.backends.xnnpack.test.test_xnnpack_utils import randomize_bn |
28 | 32 | from executorch.backends.xnnpack.test.tester import Quantize, Tester |
29 | | - |
| 33 | +from executorch.backends.xnnpack.test.tester.tester import ( |
| 34 | + Partition, |
| 35 | + ToEdgeTransformAndLower, |
| 36 | +) |
30 | 37 | from executorch.exir.dialects._ops import ops as exir_ops |
31 | 38 |
|
32 | 39 |
|
@@ -223,6 +230,61 @@ def _test( |
223 | 230 | .run_method_and_compare_outputs(qtol=1) |
224 | 231 | ) |
225 | 232 |
|
| 233 | + def _test_dq_conv2d( |
| 234 | + self, |
| 235 | + m: torch.nn.Module, |
| 236 | + inputs, |
| 237 | + dynamic_shapes, |
| 238 | + atol=5e-02, |
| 239 | + ): |
| 240 | + quant_config = get_symmetric_quantization_config( |
| 241 | + is_per_channel=True, |
| 242 | + is_dynamic=True, |
| 243 | + act_qmin=-128, |
| 244 | + act_qmax=127, |
| 245 | + weight_qmin=-128, |
| 246 | + weight_qmax=127, |
| 247 | + ) |
| 248 | + |
| 249 | + DynamicallyQuantizedPartitioner = XnnpackPartitioner( |
| 250 | + config_precisions=ConfigPrecisionType.DYNAMIC_QUANT, |
| 251 | + per_op_mode=False, |
| 252 | + ) |
| 253 | + |
| 254 | + tester = Tester(m, inputs, dynamic_shapes=dynamic_shapes) |
| 255 | + tester = tester.quantize(Quantize(quantization_config=quant_config)) |
| 256 | + |
| 257 | + # Print after quantization |
| 258 | + tester.stages["quantize"] = tester.stages[tester.cur] |
| 259 | + print("\n----------Annotated Graph:") |
| 260 | + print(tester.stages["quantize"].graph_module.code) |
| 261 | + |
| 262 | + exported = tester.export() |
| 263 | + |
| 264 | + # Print after exporting |
| 265 | + tester.stages["export"] = exported.stages[exported.cur] |
| 266 | + print("\n----------Exported Graph:") |
| 267 | + print(tester.stages["export"].graph_module.code) |
| 268 | + |
| 269 | + # Check for choose_qparams |
| 270 | + tester.check(["torch.ops.quantized_decomposed.choose_qparams"]) |
| 271 | + |
| 272 | + tester.to_edge_transform_and_lower( |
| 273 | + ToEdgeTransformAndLower([DynamicallyQuantizedPartitioner]) |
| 274 | + ) |
| 275 | + |
| 276 | + # Print after lower and partition |
| 277 | + print("\n----------Lowered Graph:") |
| 278 | + print(tester.stages[tester.cur].graph_module.code) |
| 279 | + |
| 280 | + tester.check(["executorch_exir_dialects_edge__ops_aten_convolution_default"]) |
| 281 | + tester.check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) |
| 282 | + tester.check_not(["executorch_exir_dialects_edge__ops_aten_conv2d_default"]) |
| 283 | + |
| 284 | + tester.to_executorch() |
| 285 | + tester.serialize() |
| 286 | + tester.run_method_and_compare_outputs(atol=atol) |
| 287 | + |
226 | 288 | def test_fp16_conv2d(self) -> None: |
227 | 289 | for transpose in (True, False): |
228 | 290 | for has_bias in (True, False): |
@@ -699,3 +761,25 @@ def forward(self, x): |
699 | 761 | .serialize() |
700 | 762 | .run_method_and_compare_outputs(qtol=1) |
701 | 763 | ) |
| 764 | + |
| 765 | + def test_dq_conv2d(self) -> None: |
| 766 | + class SimpleConv2d(torch.nn.Module): |
| 767 | + def __init__(self): |
| 768 | + super().__init__() |
| 769 | + self.conv = torch.nn.Conv2d(1, 2, 3) |
| 770 | + self.conv.weight.requires_grad = False |
| 771 | + self.conv.bias.requires_grad = False |
| 772 | + |
| 773 | + def forward(self, x): |
| 774 | + return self.conv(x) |
| 775 | + |
| 776 | + def get_inputs(self): |
| 777 | + return (torch.randn(1, 1, 8, 8),) |
| 778 | + |
| 779 | + model = SimpleConv2d() |
| 780 | + self._test_dq_conv2d( |
| 781 | + model, |
| 782 | + model.get_inputs(), |
| 783 | + dynamic_shapes=None, |
| 784 | + atol=5e-2, |
| 785 | + ) |
0 commit comments