Skip to content

Commit eaba819

Browse files
committed
Add depthwise conv checks for dynamic quant
1 parent 7c53454 commit eaba819

File tree

2 files changed

+43
-9
lines changed

2 files changed

+43
-9
lines changed

backends/xnnpack/partition/config/gemm_configs.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing import cast, List, Optional, Tuple
1010

1111
import torch
12+
from executorch.backends.transforms import get_shape
1213
from executorch.backends.xnnpack.operators.quant_params import QuantParams
1314
from executorch.backends.xnnpack.partition.config.xnnpack_config import (
1415
ConfigPrecisionType,
@@ -358,18 +359,35 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
358359
why(node, "Only support 1D + 2D Conv")
359360
return False # Only support 1D + 2D Conv
360361

361-
precision = self._detect_precision(node)
362-
if precision == ConfigPrecisionType.DYNAMIC_QUANT and len(conv_stride) != 2:
363-
why(node, "Only support 2D Conv for dynamic quantization")
364-
return False
365-
366362
kernel_node = get_input_node(node, 1)
363+
kernel_shape = get_shape(kernel_node)
367364
weight_quant_params = QuantParams.from_weights(kernel_node, ep)
368-
369-
is_transpose = node.args[6]
370365
groups = cast(int, node.args[8])
366+
is_transpose = node.args[6]
367+
368+
if is_transpose:
369+
group_input_channels = int(kernel_shape[0] / groups)
370+
group_output_channels = kernel_shape[1]
371+
else:
372+
group_input_channels = kernel_shape[1]
373+
group_output_channels = int(kernel_shape[0] / groups)
374+
375+
is_depthwise = (
376+
group_input_channels == 1
377+
and group_output_channels % group_input_channels == 0
378+
)
379+
380+
# XNNPACK does not support dynamic quantization convs that are not 2D or are depthwise
381+
if self._detect_precision(node) == ConfigPrecisionType.DYNAMIC_QUANT and (
382+
len(conv_stride) != 2 or is_depthwise
383+
):
384+
why(
385+
node,
386+
"XNNPACK only supports standard 2D convolutions for dynamic quantization",
387+
)
388+
return False
371389

372-
# XNNPack does not support non-zero output padding in transposed
390+
# XNNPACK does not support non-zero output padding in transposed
373391
# convolutions.
374392
if is_transpose and any(
375393
out_pad != 0 for out_pad in cast(List[int], node.args[7])

backends/xnnpack/quantizer/xnnpack_quantizer_utils.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,7 @@ def _do_annotate_conv(
323323
assert isinstance(weight, Node)
324324
input_qspec_map[weight] = get_weight_qspec(quantization_config)
325325

326-
# Only annotate dynamically quantized conv if it's 2D
326+
# Only annotate dynamically quantized conv if it's 2D and not depthwise
327327
if (
328328
quantization_config
329329
and quantization_config.input_activation
@@ -336,6 +336,22 @@ def _do_annotate_conv(
336336
if weight_shape is not None and len(weight_shape) != 4:
337337
continue
338338

339+
# Default to 1 since groups is not available in the node
340+
groups = 1
341+
if is_conv_transpose:
342+
group_input_channels = int(weight_shape[0] / groups)
343+
group_output_channels = weight_shape[1]
344+
else:
345+
group_input_channels = weight_shape[1]
346+
group_output_channels = int(weight_shape[0] / groups)
347+
348+
# Skip if depthwise
349+
if (
350+
group_input_channels == 1
351+
and group_output_channels % group_input_channels == 0
352+
):
353+
continue
354+
339355
# adding weight node to the partition as well
340356
partition = [conv_node, conv_node.args[1]]
341357

0 commit comments

Comments (0)