Skip to content

Commit e336df6

Browse files
committed
Move depthwise conv check to helper function in utils
1 parent 5a01127 commit e336df6

File tree

3 files changed

+36
-28
lines changed

3 files changed

+36
-28
lines changed

backends/xnnpack/partition/config/gemm_configs.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
)
2929
from executorch.backends.xnnpack.utils.utils import (
3030
get_input_node,
31+
is_depthwise_conv,
3132
is_getitem,
3233
is_node,
3334
is_param_node,
@@ -365,21 +366,10 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
365366
groups = cast(int, node.args[8])
366367
is_transpose = node.args[6]
367368

368-
if is_transpose:
369-
group_input_channels = int(kernel_shape[0] / groups)
370-
group_output_channels = kernel_shape[1]
371-
else:
372-
group_input_channels = kernel_shape[1]
373-
group_output_channels = int(kernel_shape[0] / groups)
374-
375-
is_depthwise = (
376-
group_input_channels == 1
377-
and group_output_channels % group_input_channels == 0
378-
)
379-
380369
# XNNPACK does not support dynamic quantization convs that are not 2D or are depthwise
381370
if self._detect_precision(node) == ConfigPrecisionType.DYNAMIC_QUANT and (
382-
len(conv_stride) != 2 or is_depthwise
371+
len(conv_stride) != 2
372+
or is_depthwise_conv(kernel_shape, groups, is_transpose)
383373
):
384374
why(
385375
node,

backends/xnnpack/quantizer/xnnpack_quantizer_utils.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import torch
88
import torch.nn.functional as F
9+
from executorch.backends.xnnpack.utils.utils import is_depthwise_conv
910
from torch._subclasses import FakeTensor
1011
from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
1112
from torch.ao.quantization.pt2e.export_utils import _WrapperModule
@@ -29,7 +30,6 @@
2930
)
3031
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
3132

32-
3333
__all__ = [
3434
"OperatorConfig",
3535
"OperatorPatternType",
@@ -336,20 +336,8 @@ def _do_annotate_conv(
336336
if weight_shape is not None and len(weight_shape) != 4:
337337
continue
338338

339-
# Default to 1 since groups is not available in the node
340-
groups = 1
341-
if is_conv_transpose:
342-
group_input_channels = int(weight_shape[0] / groups)
343-
group_output_channels = weight_shape[1]
344-
else:
345-
group_input_channels = weight_shape[1]
346-
group_output_channels = int(weight_shape[0] / groups)
347-
348-
# Skip if depthwise
349-
if (
350-
group_input_channels == 1
351-
and group_output_channels % group_input_channels == 0
352-
):
339+
# Skip if depthwise (default to groups=1 since it's not an arg)
340+
if is_depthwise_conv(weight_shape, 1, is_conv_transpose):
353341
continue
354342

355343
# adding weight node to the partition as well

backends/xnnpack/utils/utils.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,3 +158,33 @@ def get_source_fn(node: torch.fx.Node) -> Optional[torch.fx.Node]:
158158
return None
159159
source_fn = source_fn_st[-1]
160160
return source_fn[1]
161+
162+
163+
def is_depthwise_conv(
    kernel_shape: Tuple[int, ...], groups: int = 1, is_transpose: bool = False
) -> bool:
    """
    Return True if a convolution with the given weight shape is depthwise.

    A convolution is depthwise when each group convolves exactly one input
    channel, i.e. group_input_channels == 1 (equivalently, groups equals the
    number of input channels).

    Weight layouts:
        standard conv:   (out_channels, in_channels_per_group, height, width)
        transposed conv: (in_channels, out_channels_per_group, height, width)

    Args:
        kernel_shape: shape of the convolution weight tensor.
        groups: the conv's ``groups`` argument; defaults to 1 when the caller
            has no groups information available.
        is_transpose: True if the weight belongs to a transposed convolution.

    Returns:
        True for a depthwise convolution, False otherwise (including malformed
        inputs: fewer than 2 kernel dims or a non-positive group count).
    """
    # Need at least (dim0, dim1) to read channel info; groups must be positive
    # both for validity and to avoid division by zero below.
    if len(kernel_shape) < 2 or groups < 1:
        return False

    if is_transpose:
        # Transposed conv stores total input channels in dim 0, so divide by
        # groups. Use floor division: int(a / b) goes through a float and can
        # round incorrectly for very large dims; // is exact (and identical
        # for the positive values seen here).
        group_input_channels = kernel_shape[0] // groups
    else:
        group_input_channels = kernel_shape[1]

    # Depthwise <=> each group sees exactly one input channel. The previous
    # extra condition `group_output_channels % group_input_channels == 0` was
    # vacuously true once group_input_channels == 1 (x % 1 == 0 for any int),
    # so it is dropped without changing behavior.
    return group_input_channels == 1

0 commit comments

Comments
 (0)