Skip to content

Commit 899d7e5

Browse files
Arm backend: Move is_consumer_node_depthwise_conv2d to the pass using it (#13859)
The function is_consumer_node_depthwise_conv2d is only used by annotate_channels_last_dim_order_pass and can therefore be moved closer to where it is used. Signed-off-by: Sebastian Larsson <[email protected]>
1 parent 71a7806 commit 899d7e5

File tree

2 files changed

+14
-17
lines changed

2 files changed

+14
-17
lines changed

backends/arm/_passes/to_tosa_memory_format_pass.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
get_first_fake_tensor,
1313
is_param_node,
1414
)
15-
from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d
1615
from executorch.exir import ExportedProgram
1716
from executorch.exir.dialects._ops import ops as exir_ops
1817
from executorch.exir.pass_base import ExportPass, PassResult
@@ -43,6 +42,19 @@ def __init__(self, exported_program: ExportedProgram) -> None:
4342
self.exported_program = exported_program
4443
super().__init__()
4544

45+
@staticmethod
46+
def _is_consumer_node_depthwise_conv2d(node: torch.fx.Node):
47+
consumer_node = list(node.users)[0]
48+
if consumer_node.target == exir_ops.edge.aten.convolution.default:
49+
consumer_node_inputs = consumer_node.all_input_nodes
50+
groups = consumer_node.args[-1]
51+
in_channels = consumer_node_inputs[0].meta["val"].shape[1]
52+
out_channels = consumer_node_inputs[1].meta["val"].shape[0]
53+
if (in_channels == groups) and (out_channels % in_channels) == 0:
54+
return True
55+
56+
return False
57+
4658
def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):
4759
"""
4860
returns True for w in the following sequence;
@@ -53,7 +65,7 @@ def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):
5365
consumer_node = list(node.users)[0]
5466
if self.is_weight_node_for_depthwise_conv2d(consumer_node):
5567
return True
56-
if is_consumer_node_depthwise_conv2d(node):
68+
if self._is_consumer_node_depthwise_conv2d(node):
5769
# Check that node is the weight-argument and not input or bias
5870
return consumer_node.args[1] == node
5971

backends/arm/tosa_utils.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,7 @@
1616
import torch
1717

1818
from executorch.backends.arm.tosa_mapping import extract_tensor_meta
19-
2019
from executorch.backends.arm.tosa_specification import TosaSpecification
21-
from executorch.exir.dialects._ops import ops as exir_ops
2220

2321
from torch._subclasses.fake_tensor import FakeTensor
2422
from torch.fx import Node
@@ -155,19 +153,6 @@ def build_reshape_tosa_1_0(
155153
)
156154

157155

158-
def is_consumer_node_depthwise_conv2d(node: Node):
159-
consumer_node = list(node.users)[0]
160-
if consumer_node.target == exir_ops.edge.aten.convolution.default:
161-
consumer_node_inputs = consumer_node.all_input_nodes
162-
groups = consumer_node.args[-1]
163-
in_channels = consumer_node_inputs[0].meta["val"].shape[1]
164-
out_channels = consumer_node_inputs[1].meta["val"].shape[0]
165-
if (in_channels == groups) and (out_channels % in_channels) == 0:
166-
return True
167-
168-
return False
169-
170-
171156
def tosa_shape(shape, dim_order):
172157
reordered = tuple([shape[dim] for dim in dim_order])
173158
# Dynamic shapes in executorch are represented with torch.SymInt objects in the shapes,

0 commit comments

Comments
 (0)