@@ -9,6 +9,7 @@
 from typing import cast, List, Optional, Tuple
 
 import torch
+from executorch.backends.transforms import get_shape
 from executorch.backends.xnnpack.operators.quant_params import QuantParams
 from executorch.backends.xnnpack.partition.config.xnnpack_config import (
     ConfigPrecisionType,
@@ -358,18 +359,35 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
             why(node, "Only support 1D + 2D Conv")
             return False  # Only support 1D + 2D Conv
 
-        precision = self._detect_precision(node)
-        if precision == ConfigPrecisionType.DYNAMIC_QUANT and len(conv_stride) != 2:
-            why(node, "Only support 2D Conv for dynamic quantization")
-            return False
-
         kernel_node = get_input_node(node, 1)
+        kernel_shape = get_shape(kernel_node)
         weight_quant_params = QuantParams.from_weights(kernel_node, ep)
-
-        is_transpose = node.args[6]
         groups = cast(int, node.args[8])
+        is_transpose = node.args[6]
+
+        if is_transpose:
+            group_input_channels = int(kernel_shape[0] / groups)
+            group_output_channels = kernel_shape[1]
+        else:
+            group_input_channels = kernel_shape[1]
+            group_output_channels = int(kernel_shape[0] / groups)
+
+        is_depthwise = (
+            group_input_channels == 1
+            and group_output_channels % group_input_channels == 0
+        )
+
+        # XNNPACK does not support dynamic quantization convs that are not 2D or are depthwise
+        if self._detect_precision(node) == ConfigPrecisionType.DYNAMIC_QUANT and (
+            len(conv_stride) != 2 or is_depthwise
+        ):
+            why(
+                node,
+                "XNNPACK only supports standard 2D convolutions for dynamic quantization",
+            )
+            return False
 
-        # XNNPack does not support non-zero output padding in transposed
+        # XNNPACK does not support non-zero output padding in transposed
         # convolutions.
         if is_transpose and any(
            out_pad != 0 for out_pad in cast(List[int], node.args[7])
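For context, here is a minimal standalone sketch of the new constraint logic; the function name and example shapes are illustrative, not part of the commit. It assumes the aten.convolution argument order (node.args[3] is the stride, args[6] the transposed flag, args[8] the group count) and the usual PyTorch weight layouts: (out_channels, in_channels / groups, *kernel) for regular convolutions and (in_channels, out_channels / groups, *kernel) for transposed ones.

    from typing import List, Sequence

    def dynamic_quant_conv_supported(
        kernel_shape: Sequence[int],
        conv_stride: List[int],
        groups: int,
        is_transpose: bool,
    ) -> bool:
        """Mirrors the partitioner check: XNNPACK's dynamically quantized
        convolution path handles only standard (non-depthwise) 2D convs."""
        if is_transpose:
            # Transposed conv weights: (in_channels, out_channels / groups, ...)
            group_input_channels = kernel_shape[0] // groups
            group_output_channels = kernel_shape[1]
        else:
            # Regular conv weights: (out_channels, in_channels / groups, ...)
            group_input_channels = kernel_shape[1]
            group_output_channels = kernel_shape[0] // groups
        # Depthwise convolution: one input channel per group.
        is_depthwise = (
            group_input_channels == 1
            and group_output_channels % group_input_channels == 0
        )
        return len(conv_stride) == 2 and not is_depthwise

    # Depthwise 2D conv, e.g. weight (32, 1, 3, 3) with groups=32: rejected.
    assert not dynamic_quant_conv_supported((32, 1, 3, 3), [1, 1], 32, False)
    # Standard 2D conv, weight (64, 32, 3, 3) with groups=1: accepted.
    assert dynamic_quant_conv_supported((64, 32, 3, 3), [1, 1], 1, False)
    # 1D conv (single-element stride): rejected regardless of grouping.
    assert not dynamic_quant_conv_supported((64, 32, 3), [1], 1, False)

One quirk worth noting: once group_input_channels == 1, the modulo clause is always true, so the first comparison alone decides the depthwise case; the second clause is kept above only to mirror the commit verbatim.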
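The transposed-convolution constraint that follows (only its comment changes in this commit) reads the output padding from node.args[7] of the same schema and rejects any non-zero entry. Roughly, with hypothetical values:

    # Hypothetical values: output_padding comes from node.args[7].
    output_padding = [1, 0]
    is_transpose = True
    # Any non-zero output padding on a transposed conv is unsupported by XNNPACK.
    assert is_transpose and any(p != 0 for p in output_padding)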