From d5b863077b3c0ccb13261276d5d65cc80623b011 Mon Sep 17 00:00:00 2001 From: Oscar Andersson Date: Mon, 19 May 2025 14:01:28 +0200 Subject: [PATCH] Arm backend: Add backend (DW) Conv2d dialect Adds backend dialect operators for: - Conv2D - Depthwise Conv2D Also removes AddBiasPass and moves that logic to new RewriteConv2dPass. Signed-off-by: Oscar Andersson Change-Id: I1185f51ded5e931ca042f445934cd21e20fdf469 --- backends/arm/_passes/__init__.py | 2 +- backends/arm/_passes/add_bias_pass.py | 76 ------ backends/arm/_passes/arm_pass_manager.py | 10 +- backends/arm/_passes/conv1d_unsqueeze_pass.py | 7 +- backends/arm/_passes/decompose_cumsum_pass.py | 4 +- backends/arm/_passes/rewrite_conv2d_pass.py | 237 ++++++++++++++++ .../arm/_passes/size_adjust_input_pass.py | 3 +- .../arm/_passes/to_tosa_memory_format_pass.py | 34 --- backends/arm/operators/__init__.py | 3 +- .../{op_conv2d.py => op_tosa_conv2d.py} | 141 ++-------- .../arm/operators/op_tosa_depthwise_conv2d.py | 31 +++ .../arm/test/misc/test_tosa_dialect_conv2d.py | 250 +++++++++++++++++ .../test/misc/test_tosa_dialect_dw_conv2d.py | 257 ++++++++++++++++++ backends/arm/tosa/dialect/__init__.py | 2 + backends/arm/tosa/dialect/ops/conv2d.py | 118 ++++++++ .../arm/tosa/dialect/ops/depthwise_conv2d.py | 65 +++++ 16 files changed, 995 insertions(+), 245 deletions(-) delete mode 100644 backends/arm/_passes/add_bias_pass.py create mode 100644 backends/arm/_passes/rewrite_conv2d_pass.py rename backends/arm/operators/{op_conv2d.py => op_tosa_conv2d.py} (58%) create mode 100644 backends/arm/operators/op_tosa_depthwise_conv2d.py create mode 100644 backends/arm/test/misc/test_tosa_dialect_conv2d.py create mode 100644 backends/arm/test/misc/test_tosa_dialect_dw_conv2d.py create mode 100644 backends/arm/tosa/dialect/ops/conv2d.py create mode 100644 backends/arm/tosa/dialect/ops/depthwise_conv2d.py diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index b1337c38a58..de9a793b9aa 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -6,7 +6,6 @@ from . import arm_pass_utils # noqa from .arm_pass import ArmPass # noqa # usort: skip -from .add_bias_pass import AddBiasPass # noqa from .annotate_decomposed_matmul import AnnotateDecomposedMatmulPass # noqa from .annotate_output_dim_order_pass import AnnotateOutputDimOrderPass # noqa from .broadcast_args_pass import BroadcastArgsPass # noqa @@ -92,6 +91,7 @@ ReplaceScalarWithTensorArgPassTOSABI, ReplaceScalarWithTensorArgPassTOSAMI, ) +from .rewrite_conv2d_pass import RewriteConv2dPass # noqa from .rewrite_matmul import RewriteMatmulPass # noqa from .rewrite_upsample import RewriteUpsamplePass # noqa from .scalars_to_attribute_pass import ScalarsToAttributePass # noqa diff --git a/backends/arm/_passes/add_bias_pass.py b/backends/arm/_passes/add_bias_pass.py deleted file mode 100644 index 2114d56ef5b..00000000000 --- a/backends/arm/_passes/add_bias_pass.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright 2025 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -from typing import Set, Type - -import torch -from executorch.backends.arm._passes import ArmPass -from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor -from executorch.backends.arm.tosa.mapping import TosaSpecialDtype -from executorch.backends.transforms.utils import create_constant_placeholder -from executorch.exir import ExportedProgram - -from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass, PassResult -from torch.export.graph_signature import InputKind - - -class AddBiasPass(ArmPass): - """TOSA requires convolution nodes to have a bias input. - This pass adds a bias input to convolution nodes that do not have one. - The bias is set to zero. - """ - - _passes_required_after: Set[Type[ExportPass]] = set() - - targeted_ops = (exir_ops.edge.aten.convolution.default,) - - def __init__(self, exported_program: ExportedProgram) -> None: - super().__init__() - self.exported_program = exported_program - - def call(self, graph_module): - modified = False - for node in graph_module.graph.nodes: - if node.op != "call_function": - continue - if node.target not in self.targeted_ops: - continue - - if len(node.all_input_nodes) < 3: - modified = True - # bias is missing - weight_node = node.all_input_nodes[1] - output_channels = get_first_fake_tensor(weight_node).shape[0] - # add a node containging zeros - # if quantized, use int32, otherwise use float32 - if ( - "output_qparams" in node.meta - and len(node.meta["output_qparams"]) > 0 - ): - bias_data = torch.zeros(size=(output_channels,), dtype=torch.int32) - else: - bias_data = torch.zeros( - size=(output_channels,), dtype=torch.float32 - ) - - with graph_module.graph.inserting_after(weight_node): - bias_node = create_constant_placeholder( - self.exported_program, - graph=graph_module.graph, - kind=InputKind.PARAMETER, - data=bias_data, - persistent_buffer=True, - name=f"{node.name}_bias", - ) - if node.args[0].meta["val"].dtype == torch.int16: - bias_node.meta[TosaSpecialDtype.meta_key()] = ( - TosaSpecialDtype.INT48 - ) - node.update_arg(2, bias_node) - - if modified: - graph_module = super().call(graph_module).graph_module - return PassResult(graph_module, modified) diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index d0d3aae148f..b579d910752 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -12,7 +12,6 @@ import executorch.backends.arm.tosa.dialect # noqa: unused from executorch.backends.arm._passes import ( - AddBiasPass, AnnotateDecomposedMatmulPass, AnnotateOutputDimOrderPass, BroadcastArgsPass, @@ -93,6 +92,7 @@ ReplaceScalarWithTensorArgPassTOSABI, ReplaceScalarWithTensorArgPassTOSAMI, RetraceFoldedDtypesPass, + RewriteConv2dPass, RewriteMatmulPass, RewriteUpsamplePass, ScalarsToAttributePass, @@ -207,13 +207,13 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule: self.add_pass(InsertTableOpsPass(exported_program)) # If we have a conv2d with int16 activation split up into a convolution # and an addition, to work-around the lack of support for int48 in torch - # needs to happen before AddBiasPass, but after the table ops are inserted + # needs to happen before RewriteConv2dPass, but after the table ops are inserted # to be able to validate that conv2d has right dtype arguments. 
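+        # RewriteConv2dPass subsumes the old AddBiasPass: it also inserts the zero bias.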
self.add_pass(DecomposeConv2dWithInt16ActivationPass()) - self.add_pass(RewriteUpsamplePass()) - self.add_pass(AddBiasPass(exported_program)) + self.add_pass(RewriteConv2dPass(exported_program)) self.add_pass(RewriteMatmulPass()) + self.add_pass(RewriteUpsamplePass()) self.add_pass(FuseEqualPlaceholdersPass(exported_program)) self.add_pass(ToTosaMemoryFormatPass(exported_program)) self.add_pass(RemoveNoopPass()) @@ -297,9 +297,9 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule: self.add_pass(FuseViewCopyTransform()) self.add_pass(FuseConstantArgsPass(exported_program)) + self.add_pass(RewriteConv2dPass(exported_program)) self.add_pass(CastInt64BuffersToInt32Pass(exported_program)) self.add_pass(RewriteUpsamplePass()) - self.add_pass(AddBiasPass(exported_program)) self.add_pass(InsertTableOpsPass(exported_program)) self.add_pass(RewriteMatmulPass()) self.add_pass(FuseEqualPlaceholdersPass(exported_program)) diff --git a/backends/arm/_passes/conv1d_unsqueeze_pass.py b/backends/arm/_passes/conv1d_unsqueeze_pass.py index 7784c850278..a368f1b65ed 100644 --- a/backends/arm/_passes/conv1d_unsqueeze_pass.py +++ b/backends/arm/_passes/conv1d_unsqueeze_pass.py @@ -10,7 +10,7 @@ from executorch.backends.arm._passes import ArmPass -from executorch.backends.arm._passes.add_bias_pass import AddBiasPass +from executorch.backends.arm._passes.rewrite_conv2d_pass import RewriteConv2dPass from executorch.backends.arm._passes.size_adjust_input_pass import SizeAdjustInputPass from executorch.exir.dialects._ops import ops as exir_ops @@ -28,7 +28,10 @@ class Conv1dUnsqueezePass(ArmPass): 3) squeeze the output back down to 3d. """ - _passes_required_after: Set[Type[ExportPass]] = {AddBiasPass, SizeAdjustInputPass} + _passes_required_after: Set[Type[ExportPass]] = { + RewriteConv2dPass, + SizeAdjustInputPass, + } def call_operator(self, op, args, kwargs, meta): if op != exir_ops.edge.aten.convolution.default: diff --git a/backends/arm/_passes/decompose_cumsum_pass.py b/backends/arm/_passes/decompose_cumsum_pass.py index 2111c654817..7066fdb16eb 100644 --- a/backends/arm/_passes/decompose_cumsum_pass.py +++ b/backends/arm/_passes/decompose_cumsum_pass.py @@ -8,9 +8,9 @@ import torch from executorch.backends.arm._passes import ArmPass -from executorch.backends.arm._passes.add_bias_pass import AddBiasPass from executorch.backends.arm._passes.arm_pass_utils import create_node from executorch.backends.arm._passes.quant_args import QuantArgs +from executorch.backends.arm._passes.rewrite_conv2d_pass import RewriteConv2dPass from executorch.backends.transforms.utils import create_constant_placeholder from executorch.exir import ExportedProgram @@ -42,7 +42,7 @@ class DecomposeCumsumPass(ArmPass): And the convolution is applied over dimension H. """ - _passes_required_after: Set[Type[ExportPass]] = {AddBiasPass} + _passes_required_after: Set[Type[ExportPass]] = {RewriteConv2dPass} def __init__(self, exported_program: ExportedProgram) -> None: super().__init__() diff --git a/backends/arm/_passes/rewrite_conv2d_pass.py b/backends/arm/_passes/rewrite_conv2d_pass.py new file mode 100644 index 00000000000..8b4f43c35c7 --- /dev/null +++ b/backends/arm/_passes/rewrite_conv2d_pass.py @@ -0,0 +1,237 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
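+
+"""Rewrite ``aten.convolution`` into TOSA backend dialect convolutions.
+
+Regular convolutions are rewritten to ``tosa.CONV2D`` and depthwise
+convolutions to ``tosa.DEPTHWISE_CONV2D``. The pass also inserts a zero bias
+where one is missing (folding in the old ``AddBiasPass`` logic) and adjusts
+padding so that the output size is an integer, as TOSA requires.
+"""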
+
+
+from typing import Set, Type
+
+import torch
+from executorch.backends.arm._passes import ArmPass
+
+from executorch.backends.arm._passes.arm_pass_utils import (
+    create_node,
+    get_first_fake_tensor,
+    get_param_tensor,
+    is_buffer,
+    is_param,
+)
+from executorch.backends.arm.constants import HWCM_ORDER, NHWC_INVERSE_ORDER
+from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
+from executorch.backends.transforms.utils import create_constant_placeholder
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+from torch.export.graph_signature import InputKind
+
+
+class RewriteConv2dPass(ArmPass):
+    """Rewrites aten.convolution to tosa.CONV2D or tosa.DEPTHWISE_CONV2D."""
+
+    def __init__(self, exported_program: torch.export.ExportedProgram):
+        super().__init__()
+        self.exported_program = exported_program
+
+    _passes_required_after: Set[Type[ExportPass]] = set()
+
+    def _adjust_pad_if_needed(
+        self, input_len: int, input_weight: int, stride: int, pad: int, dilation: int
+    ) -> int:
+        """Adjust padding to satisfy TOSA's integer output-size requirement.
+
+        Torch ``Conv2d`` does not require the result of
+        ``(input + 2 * pad - dilation * (weight - 1) - 1) / stride`` to be an
+        integer, but TOSA does. This helper reduces the provided padding so
+        that the expression becomes divisible by ``stride``.
+
+        Args:
+            input_len (int): Spatial input size along the dimension (H or W).
+            input_weight (int): Kernel size along the same dimension.
+            stride (int): Stride along the same dimension.
+            pad (int): Padding value to adjust (bottom or right after duplication).
+            dilation (int): Dilation along the same dimension.
+
+        Returns:
+            int: Adjusted padding value that yields an integer output size.
+
+        Raises:
+            RuntimeError: If the required adjustment exceeds the provided
+                padding, which should be handled by the ``SizeAdjustInputPass``
+                pass instead.
+
+        """
+        mod_remainder = (
+            input_len + 2 * pad - dilation * (input_weight - 1) - 1
+        ) % stride
+
+        # No need to adjust
+        if mod_remainder == 0:
+            return pad
+
+        if mod_remainder > pad:
+            raise RuntimeError(
+                "This case should be handled by the SizeAdjustInputPass; is it enabled?"
+            )
+        return pad - mod_remainder
+
+    def _is_depthwise_conv2d(self, node: torch.fx.Node) -> bool:
+        if (
+            node.op != "call_function"
+            or node.target != exir_ops.edge.aten.convolution.default
+        ):
+            return False
+        groups = node.args[-1]
+        in_channels = get_first_fake_tensor(node.all_input_nodes[0]).shape[1]
+        out_channels = get_first_fake_tensor(node.all_input_nodes[1]).shape[0]
+        return (in_channels == groups) and (out_channels % in_channels) == 0
+
+    def _reshape_weights(self, weight_node: torch.fx.Node, in_channels: int) -> None:
+        """Reshape the weights for depthwise convolution such that, when
+        serialized to TOSA, the weights are in the format
+        [H, W, in_channels, m_length], where m_length is the number of output
+        channels per input channel.
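+
+        For example (illustrative shapes): a torch weight of shape
+        (16, 1, 3, 3) with in_channels=8 (so m_length=2) is stored in the
+        node meta as (3, 2, 3, 8), i.e. (H, M, W, C); the NHWC permute
+        applied at serialization then yields the TOSA (H, W, C, M) layout
+        (3, 3, 8, 2).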
+ """ + weight_tensor = get_param_tensor(self.exported_program, weight_node) # type: ignore[arg-type] + if weight_tensor is None: + raise RuntimeError( + f"Weight node {weight_node.name} is not a parameter or buffer" + ) + reshaped_weight_tensor = ( + weight_tensor.permute(HWCM_ORDER) + .reshape( + weight_tensor.shape[2], + weight_tensor.shape[3], + in_channels, + weight_tensor.shape[0] // in_channels, + ) + .permute(NHWC_INVERSE_ORDER) + ) + + if is_buffer(self.exported_program, weight_node): + param_name = self.exported_program.graph_signature.inputs_to_buffers[ + weight_node.name + ] + elif is_param(self.exported_program, weight_node): + param_name = self.exported_program.graph_signature.inputs_to_parameters[ + weight_node.name + ] + else: + raise RuntimeError( + f"Weight node {weight_node.name} is neither a parameter nor a buffer" + ) + self.exported_program.state_dict[param_name] = reshaped_weight_tensor + weight_node.meta["val"] = weight_node.meta["val"].reshape( + weight_tensor.shape[2], + weight_tensor.shape[0] // in_channels, + weight_tensor.shape[3], + in_channels, + ) + + def _add_bias( + self, + graph_module: torch.fx.GraphModule, + node: torch.fx.Node, + weight_node: torch.fx.Node, + ) -> torch.fx.Node: + output_channels = get_first_fake_tensor(node).shape[1] + # add a node containging zeros if quantized, use int32, otherwise use float32 + if "output_qparams" in node.meta and len(node.meta["output_qparams"]) > 0: + bias_data = torch.zeros(size=(output_channels,), dtype=torch.int32) + else: + bias_data = torch.zeros(size=(output_channels,), dtype=torch.float32) + + with graph_module.graph.inserting_after(weight_node): + bias_node = create_constant_placeholder( + self.exported_program, + graph=graph_module.graph, + kind=InputKind.PARAMETER, + data=bias_data, + persistent_buffer=True, + name=f"{node.name}_bias", + ) + if node.all_input_nodes[0].meta["val"].dtype == torch.int16: + bias_node.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.INT48 + node.update_arg(2, bias_node) + return bias_node + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + modified = False + for node in graph_module.graph.nodes: + if ( + node.op != "call_function" + or node.target != exir_ops.edge.aten.convolution.default + ): + continue + + modified = True + + ( + x, + weight, + bias, + stride, + pad, + dilation, + transposed, + output_pad, + group, + ) = node.args + + pad = [val for val in pad for _ in (0, 1)] + input_shape = get_first_fake_tensor(x).shape + weight_shape = get_first_fake_tensor(weight).shape + # Adjust the pad value if needed to meet the + # strict convolution output shape calculation. 
+ pad[1] = self._adjust_pad_if_needed( + input_shape[2], + weight_shape[2], + stride[0], + pad[1], + dilation[0], + ) + pad[3] = self._adjust_pad_if_needed( + input_shape[3], + weight_shape[3], + stride[1], + pad[3], + dilation[1], + ) + + if bias is None: + bias = self._add_bias(graph_module, node, weight) + + if self._is_depthwise_conv2d(node): + target_op = exir_ops.backend.tosa.DEPTHWISE_CONV2D.default + self._reshape_weights(weight, input_shape[1]) + else: + target_op = exir_ops.backend.tosa.CONV2D.default + + conv2d_args = ( + x, + weight, + bias, + stride, + pad, + dilation, + transposed, + output_pad, + group, + ) + + with graph_module.graph.inserting_after(node): + tosa_op = create_node( + graph=graph_module.graph, + op_target=target_op, + args=conv2d_args, + from_node=node, + ) + + node.replace_all_uses_with(tosa_op) + graph_module.graph.erase_node(node) + + if modified: + graph_module.recompile() + graph_module = super().call(graph_module).graph_module + return PassResult(graph_module, modified) diff --git a/backends/arm/_passes/size_adjust_input_pass.py b/backends/arm/_passes/size_adjust_input_pass.py index c82bcab947c..d0cc164ba30 100644 --- a/backends/arm/_passes/size_adjust_input_pass.py +++ b/backends/arm/_passes/size_adjust_input_pass.py @@ -10,6 +10,7 @@ import torch.fx from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import create_node +from executorch.backends.arm._passes.rewrite_conv2d_pass import RewriteConv2dPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -186,7 +187,7 @@ class SizeAdjustInputPass(ArmPass): input. """ - _passes_required_after: Set[Type[ExportPass]] = set() + _passes_required_after: Set[Type[ExportPass]] = {RewriteConv2dPass} def call(self, graph_module: torch.fx.GraphModule) -> PassResult: graph = graph_module.graph diff --git a/backends/arm/_passes/to_tosa_memory_format_pass.py b/backends/arm/_passes/to_tosa_memory_format_pass.py index 75646ce4379..3783f782610 100644 --- a/backends/arm/_passes/to_tosa_memory_format_pass.py +++ b/backends/arm/_passes/to_tosa_memory_format_pass.py @@ -20,7 +20,6 @@ is_param_node, ) from executorch.backends.arm.constants import ( - HWCM_ORDER, NCHW_ORDER, NHWC_INVERSE_ORDER, NHWC_ORDER, @@ -59,35 +58,6 @@ def __init__(self, exported_program: ExportedProgram) -> None: super().__init__() self.exported_program = exported_program - @staticmethod - def _is_consumer_node_depthwise_conv2d(node: torch.fx.Node): - consumer_node = list(node.users)[0] - if consumer_node.target == exir_ops.edge.aten.convolution.default: - consumer_node_inputs = consumer_node.all_input_nodes - groups = consumer_node.args[-1] - in_channels = consumer_node_inputs[0].meta["val"].shape[1] - out_channels = consumer_node_inputs[1].meta["val"].shape[0] - if (in_channels == groups) and (out_channels % in_channels) == 0: - return True - - return False - - def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node): - """ - returns True for w in the following sequence; - w -> depthwise_conv2d -> ... 
- """ - if node.op == "placeholder": - # node is an input, weight or bias node - consumer_node = list(node.users)[0] - if self.is_weight_node_for_depthwise_conv2d(consumer_node): - return True - if self._is_consumer_node_depthwise_conv2d(node): - # Check that node is the weight-argument and not input or bias - return consumer_node.args[1] == node - - return False - @staticmethod def memory_format_differs(shape): """Returns true if the shape will have a different memory layout in (N)NCHW and (N)NHWC format""" @@ -333,10 +303,6 @@ def call(self, graph_module: torch.fx.GraphModule): dim_order = node_data.dim_order() elif node_data.dim() == 4: dim_order = NHWC_ORDER - if self.is_weight_node_for_depthwise_conv2d(node): - # The weights of TOSA DEPTHWISE_CONV2D have shape (H, W, C, M) which corresponds to - # dim_order = (2, 3, 0, 1) (https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d). - dim_order = HWCM_ORDER elif node_data.dim() == 5: dim_order = NNHWC_ORDER elif node_data.dim() == 6: diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index 9278d25959f..e7812630f91 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -18,7 +18,6 @@ op_ceil, op_clamp, op_constant_pad_nd, - op_conv2d, op_cos, op_eq, op_erf, @@ -50,6 +49,8 @@ op_sum, op_tanh, op_to_dim_order_copy, + op_tosa_conv2d, + op_tosa_depthwise_conv2d, op_tosa_matmul, op_tosa_rescale, op_tosa_resize, diff --git a/backends/arm/operators/op_conv2d.py b/backends/arm/operators/op_tosa_conv2d.py similarity index 58% rename from backends/arm/operators/op_conv2d.py rename to backends/arm/operators/op_tosa_conv2d.py index 933e353387b..ec61995dc1b 100644 --- a/backends/arm/operators/op_conv2d.py +++ b/backends/arm/operators/op_tosa_conv2d.py @@ -31,13 +31,9 @@ @register_node_visitor class Conv2dVisitor(NodeVisitor): - """Provide a visitor that lowers ``aten.convolution`` to TOSA. + """Provide a visitor that serializes TOSA ``CONV2D``.""" - Map to ``CONV2D`` or ``DEPTHWISE_CONV2D`` as appropriate. - - """ - - target = "aten.convolution.default" + target = "tosa.CONV2D.default" tosa_specs = [ TosaSpecification.create_from_string("TOSA-1.0+INT"), @@ -47,46 +43,13 @@ class Conv2dVisitor(NodeVisitor): def __init__(self, *args): super().__init__(*args) - def adjust_pad_if_needed( - self, input_size: int, input_weight: int, stride: int, pad: int, dilation: int - ) -> int: - """Adjust padding to satisfy TOSA's integer output-size requirement. - - Torch ``Conv2d`` does not require the result of - ``(input + 2 * pad - dilation * (weight - 1) - 1) / stride`` to be an - integer, but TOSA does. This helper reduces the provided padding so - that the expression becomes divisible by ``stride``. - - Args: - input_size (int): Spatial input size along the dimension (H or W). - input_weight (int): Kernel size along the same dimension. - stride (int): Stride along the same dimension. - pad (int): Padding value to adjust (bottom or right after duplication). - dilation (int): Dilation along the same dimension. - - Returns: - int: Adjusted padding value that yields an integer output size. - - Raises: - RuntimeError: If the required adjustment exceeds the provided - padding, which should be handled by the ``SizeAdjustInputPass`` - pass instead. 
- - """ - mod_remainder = ( - input_size + 2 * pad - dilation * (input_weight - 1) - 1 - ) % stride + def _get_tosa_op(self): + import serializer.tosa_serializer as ts # type: ignore - # No need to adjust - if mod_remainder == 0: - return pad + return ts.TosaOp.Op().CONV2D - if mod_remainder > pad: - raise RuntimeError( - "This case should be handled by the SizeAdjustInputPass pass, " - "is it enabled?" - ) - return pad - mod_remainder + def _get_attr_func(self, attr): + return attr.Conv2dAttribute def define_node( self, @@ -129,28 +92,10 @@ def define_node( ) # Get the attributes of convolution. - attr = ts.TosaSerializerAttribute() - pad_attr = [val for val in pad.special for _ in (0, 1)] + pad_attr = pad.special stride_attr = stride.special dilation_attr = dilation.special - # Adjust the pad value if needed to meet the - # strict convolution output shape calculation. - pad_attr[1] = self.adjust_pad_if_needed( - input.shape[2], - weight.shape[2], - stride_attr[0], - pad_attr[1], - dilation_attr[0], - ) - pad_attr[3] = self.adjust_pad_if_needed( - input.shape[3], - weight.shape[3], - stride_attr[1], - pad_attr[3], - dilation_attr[1], - ) - input_zp = 0 if inputs[0].dtype in (ts.DType.INT8, ts.DType.INT16): # int8 and int16 input requires quantization information @@ -191,66 +136,16 @@ def define_node( name=f"{conv2d_output_name}_weight_zp", ) - # Given input.shape is (N, Ci, H, W), and weight.shape is (Co, Ci/G, H, W) - in_channels = input.shape[1] - out_channels = weight.shape[0] - if (in_channels == group.number) and (out_channels % in_channels) == 0: - """Depthwise convolution case.""" - # Reshape torch shape format of weight tensor to tosa required format. - # https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d - m_length = int(out_channels / in_channels) - weight_post_shape = [ - weight.shape[2], - weight.shape[3], - in_channels, - m_length, - ] - - weight_reshaped = tosa_graph.addIntermediate( - weight_post_shape, - weight.dtype, - ) - shape = tosa_graph.addConst( - [len(weight_post_shape)], - ts.DType.SHAPE, - weight_post_shape, - name=weight_reshaped.name + "_shape", - ) - - reshape_attr = ts.TosaSerializerAttribute() - reshape_attr.ReshapeAttribute() - self._serialize_operator( - node, - tosa_graph, - ts.TosaOp.Op().RESHAPE, - [weight.name, shape.name], - [weight_reshaped.name], - reshape_attr, - ) - - attr = ts.TosaSerializerAttribute() - tosa_op = ts.TosaOp.Op().DEPTHWISE_CONV2D - weight_name = weight_reshaped.name + tosa_op = self._get_tosa_op() - attr.DepthwiseConv2dAttribute( - pad=pad_attr, - stride=stride_attr, - dilation=dilation_attr, - local_bound=False, - acc_type=acc_type, - ) - else: - """Regular convolution case.""" - tosa_op = ts.TosaOp.Op().CONV2D - weight_name = weight.name - - attr.Conv2dAttribute( - pad=pad_attr, - stride=stride_attr, - dilation=dilation_attr, - local_bound=False, - acc_type=acc_type, - ) + attr = ts.TosaSerializerAttribute() + self._get_attr_func(attr)( + pad=pad_attr, + stride=stride_attr, + dilation=dilation_attr, + local_bound=False, + acc_type=acc_type, + ) self._serialize_operator( node, @@ -258,7 +153,7 @@ def define_node( tosa_op, [ input.name, - weight_name, + weight.name, bias.name, f"{conv2d_output_name}_input_zp", f"{conv2d_output_name}_weight_zp", diff --git a/backends/arm/operators/op_tosa_depthwise_conv2d.py b/backends/arm/operators/op_tosa_depthwise_conv2d.py new file mode 100644 index 00000000000..ef4da3845fe --- /dev/null +++ b/backends/arm/operators/op_tosa_depthwise_conv2d.py @@ -0,0 +1,31 @@ +# Copyright 2025 
Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe +from executorch.backends.arm.operators.node_visitor import register_node_visitor +from executorch.backends.arm.operators.op_tosa_conv2d import Conv2dVisitor +from executorch.backends.arm.tosa import TosaSpecification + + +@register_node_visitor +class DepthwiseConv2dVisitor(Conv2dVisitor): + """Provide a visitor that serializes TOSA ``DEPTHWISE_CONV2D``.""" + + target = "tosa.DEPTHWISE_CONV2D.default" + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-1.0+INT"), + TosaSpecification.create_from_string("TOSA-1.0+FP"), + ] + + def _get_tosa_op(self): + import serializer.tosa_serializer as ts # type: ignore + + return ts.TosaOp.Op().DEPTHWISE_CONV2D + + def _get_attr_func(self, attr): + return attr.DepthwiseConv2dAttribute + + # Inheriting the define_node method from Conv2dVisitor diff --git a/backends/arm/test/misc/test_tosa_dialect_conv2d.py b/backends/arm/test/misc/test_tosa_dialect_conv2d.py new file mode 100644 index 00000000000..867578a4ff5 --- /dev/null +++ b/backends/arm/test/misc/test_tosa_dialect_conv2d.py @@ -0,0 +1,250 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import executorch.backends.arm.tosa.dialect # noqa: unused +import pytest +import torch +from executorch.backends.arm.tosa.dialect.lib import TosaValueError +from executorch.backends.arm.tosa.specification import ( + TosaLoweringContext, + TosaSpecification, +) +from executorch.exir.dialects._ops import ops as exir_ops +from torch._subclasses.fake_tensor import FakeTensorMode + + +def test_conv2d_tosa_INT(): + sample_inputs = [ + ( + ( + torch.randint(-128, 127, (1, 8, 20, 20), dtype=torch.int8), + torch.randint(-127, 127, (8, 2, 5, 5), dtype=torch.int8), + torch.randint(-(2**31), 2**31, (8,), dtype=torch.int32), + [1, 1], + [2, 2, 2, 2], + [1, 1], + False, + [0, 0], + 4, + ), + (1, 8, 20, 20), + torch.int8, + ), + ( + ( + torch.randint(-128, 127, (1, 8, 20, 20), dtype=torch.int8), + torch.randint(-127, 127, (4, 2, 5, 5), dtype=torch.int8), + None, + [2, 2], + [2, 2, 2, 2], + [1, 1], + False, + [0, 0], + 4, + ), + (1, 4, 10, 10), + torch.int8, + ), + ] + + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.0+INT") + ), FakeTensorMode() as mode: + for sample_input, expected_output_shape, expected_output_type in sample_inputs: + output = exir_ops.backend.tosa.CONV2D.default( + *tuple( + [ + mode.from_tensor(i) if isinstance(i, torch.Tensor) else i + for i in sample_input + ] + ) + ) + assert ( + output.dtype == expected_output_type + ), f"Expected output dtype {expected_output_type} but got {output.dtype}" + assert ( + tuple(output.shape) == expected_output_shape + ), f"Expected output shape {expected_output_shape} but got {tuple(output.shape)}" + + +def test_conv2d_invalid_tosa_INT(): + tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+INT") + sample_inputs = [ + ( + ( + torch.randn((1, 8, 20, 20), dtype=torch.float32), + torch.randn((8, 2, 5, 5), dtype=torch.float32), + torch.randn((8,), dtype=torch.float32), + [1, 1], + [2, 2, 2, 2], + [1, 1], + False, + [0, 0], + 4, + ), + TosaValueError, + f"doesn't support {torch.float32} but found input type {torch.float32}", + ), + ( + ( + torch.randint(-128, 127, (1, 8, 20, 20), dtype=torch.int8), + 
torch.randn((8, 2, 5, 5), dtype=torch.float32), + torch.randn((8,), dtype=torch.float32), + [1, 1], + [2, 2, 2, 2], + [1, 1], + False, + [0, 0], + 4, + ), + TosaValueError, + f"only supports {torch.int8} weights for {torch.int8} input but found {torch.float32}", + ), + ( + ( + torch.randint(-128, 127, (1, 8, 20, 20), dtype=torch.int8), + torch.randint(-127, 127, (8, 2, 5, 5), dtype=torch.int8), + torch.randn((8,), dtype=torch.float32), + [1, 1], + [2, 2, 2, 2], + [1, 1], + False, + [0, 0], + 4, + ), + TosaValueError, + f"only supports {torch.int32} bias for {torch.int8} input but found {torch.float32}", + ), + ] + + with TosaLoweringContext(tosa_spec), FakeTensorMode() as mode: + for sample_input, expected_error, expected_error_str in sample_inputs: + with pytest.raises(expected_error, match=expected_error_str): + exir_ops.backend.tosa.CONV2D.default( + *tuple( + [ + mode.from_tensor(i) if isinstance(i, torch.Tensor) else i + for i in sample_input + ] + ) + ) + + +def test_conv2d_tosa_FP(): + sample_inputs = [ + ( + ( + torch.randn((1, 8, 20, 20), dtype=torch.float32), + torch.randn((8, 2, 5, 5), dtype=torch.float32), + torch.randn((8,), dtype=torch.float32), + [1, 1], + [2, 2, 2, 2], + [1, 1], + False, + [0, 0], + 4, + ), + (1, 8, 20, 20), + torch.float32, + ), + ( + ( + torch.randn((1, 8, 20, 20), dtype=torch.float32), + torch.randn((4, 2, 5, 5), dtype=torch.float32), + None, + [2, 2], + [2, 2, 2, 2], + [1, 1], + False, + [0, 0], + 4, + ), + (1, 4, 10, 10), + torch.float32, + ), + ] + + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.0+FP") + ), FakeTensorMode() as mode: + for sample_input, expected_output_shape, expected_output_type in sample_inputs: + output = exir_ops.backend.tosa.CONV2D.default( + *tuple( + [ + mode.from_tensor(i) if isinstance(i, torch.Tensor) else i + for i in sample_input + ] + ) + ) + assert ( + output.dtype == expected_output_type + ), f"Expected output dtype {expected_output_type} but got {output.dtype}" + assert ( + tuple(output.shape) == expected_output_shape + ), f"Expected output shape {expected_output_shape} but got {tuple(output.shape)}" + + +def test_conv2d_invalid_tosa_FP(): + tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+FP") + sample_inputs = [ + ( + ( + torch.randint(-127, 127, (1, 8, 20, 20), dtype=torch.int8), + torch.randn((8, 2, 5, 5), dtype=torch.float32), + torch.randn((8,), dtype=torch.float32), + [1, 1], + [2, 2, 2, 2], + [1, 1], + False, + [0, 0], + 4, + ), + TosaValueError, + f"doesn't support {torch.int8} but found input type {torch.int8}", + ), + ( + ( + torch.randn((1, 8, 20, 20), dtype=torch.float32), + torch.randn((8, 2, 5, 5), dtype=torch.float16), + torch.randn((8,), dtype=torch.float32), + [1, 1], + [2, 2, 2, 2], + [1, 1], + False, + [0, 0], + 4, + ), + TosaValueError, + f"requires weights {torch.float16} to be of the same type as input {torch.float32}", + ), + ( + ( + torch.randn((1, 8, 20, 20), dtype=torch.float32), + torch.randn((8, 2, 5, 5), dtype=torch.float32), + torch.randn((8,), dtype=torch.float16), + [1, 1], + [2, 2, 2, 2], + [1, 1], + False, + [0, 0], + 4, + ), + TosaValueError, + f"requires bias {torch.float16} to be of the same type as input {torch.float32}", + ), + ] + + with TosaLoweringContext(tosa_spec), FakeTensorMode() as mode: + for sample_input, expected_error, expected_error_str in sample_inputs: + with pytest.raises(expected_error, match=expected_error_str): + exir_ops.backend.tosa.CONV2D.default( + *tuple( + [ + mode.from_tensor(i) if isinstance(i, torch.Tensor) else i + 
for i in sample_input + ] + ) + ) diff --git a/backends/arm/test/misc/test_tosa_dialect_dw_conv2d.py b/backends/arm/test/misc/test_tosa_dialect_dw_conv2d.py new file mode 100644 index 00000000000..8d9224d90fe --- /dev/null +++ b/backends/arm/test/misc/test_tosa_dialect_dw_conv2d.py @@ -0,0 +1,257 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import executorch.backends.arm.tosa.dialect # noqa: unused +import pytest +import torch +from executorch.backends.arm.tosa.dialect.lib import TosaValueError +from executorch.backends.arm.tosa.specification import ( + TosaLoweringContext, + TosaSpecification, +) +from executorch.exir.dialects._ops import ops as exir_ops +from torch._subclasses.fake_tensor import FakeTensorMode + + +def test_depthwise_conv2d_tosa_INT(): + sample_inputs = [ + ( + ( + torch.randint(-128, 127, (1, 8, 20, 20), dtype=torch.int8), + # weight shape is [H, m_length, W, in_channels], where m_length = out_channels // in_channels + torch.randint(-127, 127, (5, 2, 5, 8), dtype=torch.int8), + torch.randint(-(2**31), 2**31, (16,), dtype=torch.int32), + [1, 1], + [2, 2, 2, 2], + [1, 1], + False, + [0, 0], + 8, + ), + (1, 16, 20, 20), + torch.int8, + ), + ( + ( + torch.randint(-128, 127, (1, 8, 20, 20), dtype=torch.int8), + # weight shape is [H, m_length, W, in_channels], where m_length = out_channels // in_channels + torch.randint(-127, 127, (5, 4, 5, 8), dtype=torch.int8), + None, + [2, 2], + [2, 2, 2, 2], + [1, 1], + False, + [0, 0], + 8, + ), + (1, 32, 10, 10), + torch.int8, + ), + ] + + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.0+INT") + ), FakeTensorMode() as mode: + for sample_input, expected_output_shape, expected_output_type in sample_inputs: + output = exir_ops.backend.tosa.DEPTHWISE_CONV2D.default( + *tuple( + [ + mode.from_tensor(i) if isinstance(i, torch.Tensor) else i + for i in sample_input + ] + ) + ) + assert ( + output.dtype == expected_output_type + ), f"Expected output dtype {expected_output_type} but got {output.dtype}" + assert ( + tuple(output.shape) == expected_output_shape + ), f"Expected output shape {expected_output_shape} but got {tuple(output.shape)}" + + +def test_depthwise_conv2d_invalid_tosa_INT(): + tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+INT") + sample_inputs = [ + ( + ( + torch.randn((1, 8, 20, 20), dtype=torch.float32), + # weight shape is [H, m_length, W, in_channels], where m_length = out_channels // in_channels + torch.randn((5, 2, 5, 8), dtype=torch.float32), + torch.randn((16,), dtype=torch.float32), + [1, 1], + [2, 2, 2, 2], + [1, 1], + False, + [0, 0], + 8, + ), + TosaValueError, + f"doesn't support {torch.float32} but found input type {torch.float32}", + ), + ( + ( + torch.randint(-128, 127, (1, 8, 20, 20), dtype=torch.int8), + # weight shape is [H, m_length, W, in_channels], where m_length = out_channels // in_channels + torch.randn((5, 2, 5, 8), dtype=torch.float32), + torch.randn((16,), dtype=torch.float32), + [1, 1], + [2, 2, 2, 2], + [1, 1], + False, + [0, 0], + 8, + ), + TosaValueError, + f"only supports {torch.int8} weights for {torch.int8} input but found {torch.float32}", + ), + ( + ( + torch.randint(-128, 127, (1, 8, 20, 20), dtype=torch.int8), + torch.randint(-127, 127, (5, 2, 5, 8), dtype=torch.int8), + torch.randn((16,), dtype=torch.float32), + [1, 1], + [2, 2, 2, 2], + [1, 1], + False, + [0, 0], + 8, + ), + TosaValueError, + f"only 
supports {torch.int32} bias for {torch.int8} input but found {torch.float32}",
+        ),
+    ]
+
+    with TosaLoweringContext(tosa_spec), FakeTensorMode() as mode:
+        for sample_input, expected_error, expected_error_str in sample_inputs:
+            with pytest.raises(expected_error, match=expected_error_str):
+                exir_ops.backend.tosa.DEPTHWISE_CONV2D.default(
+                    *tuple(
+                        [
+                            mode.from_tensor(i) if isinstance(i, torch.Tensor) else i
+                            for i in sample_input
+                        ]
+                    )
+                )
+
+
+def test_depthwise_conv2d_tosa_FP():
+    sample_inputs = [
+        (
+            (
+                torch.randn((1, 8, 20, 20), dtype=torch.float32),
+                # weight shape is [H, m_length, W, in_channels], where m_length = out_channels // in_channels
+                torch.randn((5, 2, 5, 8), dtype=torch.float32),
+                torch.randn((16,), dtype=torch.float32),
+                [1, 1],
+                [2, 2, 2, 2],
+                [1, 1],
+                False,
+                [0, 0],
+                8,
+            ),
+            (1, 16, 20, 20),
+            torch.float32,
+        ),
+        (
+            (
+                torch.randn((1, 8, 20, 20), dtype=torch.float32),
+                # weight shape is [H, m_length, W, in_channels], where m_length = out_channels // in_channels
+                torch.randn((5, 4, 5, 8), dtype=torch.float32),
+                None,
+                [2, 2],
+                [2, 2, 2, 2],
+                [1, 1],
+                False,
+                [0, 0],
+                8,
+            ),
+            (1, 32, 10, 10),
+            torch.float32,
+        ),
+    ]
+
+    with TosaLoweringContext(
+        TosaSpecification.create_from_string("TOSA-1.0+FP")
+    ), FakeTensorMode() as mode:
+        for sample_input, expected_output_shape, expected_output_type in sample_inputs:
+            output = exir_ops.backend.tosa.DEPTHWISE_CONV2D.default(
+                *tuple(
+                    [
+                        mode.from_tensor(i) if isinstance(i, torch.Tensor) else i
+                        for i in sample_input
+                    ]
+                )
+            )
+            assert (
+                output.dtype == expected_output_type
+            ), f"Expected output dtype {expected_output_type} but got {output.dtype}"
+            assert (
+                tuple(output.shape) == expected_output_shape
+            ), f"Expected output shape {expected_output_shape} but got {tuple(output.shape)}"
+
+
+def test_depthwise_conv2d_invalid_tosa_FP():
+    tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+FP")
+
+    sample_inputs = [
+        (
+            (
+                torch.randint(-127, 127, (1, 8, 20, 20), dtype=torch.int8),
+                torch.randn((5, 2, 5, 8), dtype=torch.float32),
+                torch.randn((16,), dtype=torch.float32),
+                [1, 1],
+                [2, 2, 2, 2],
+                [1, 1],
+                False,
+                [0, 0],
+                8,
+            ),
+            TosaValueError,
+            f"doesn't support {torch.int8} but found input type {torch.int8}",
+        ),
+        (
+            (
+                torch.randn((1, 8, 20, 20), dtype=torch.float32),
+                torch.randn((5, 2, 5, 8), dtype=torch.float16),
+                torch.randn((16,), dtype=torch.float32),
+                [1, 1],
+                [2, 2, 2, 2],
+                [1, 1],
+                False,
+                [0, 0],
+                8,
+            ),
+            TosaValueError,
+            f"requires weights {torch.float16} to be of the same type as input {torch.float32}",
+        ),
+        (
+            (
+                torch.randn((1, 8, 20, 20), dtype=torch.float32),
+                torch.randn((5, 2, 5, 8), dtype=torch.float32),
+                torch.randn((16,), dtype=torch.float16),
+                [1, 1],
+                [2, 2, 2, 2],
+                [1, 1],
+                False,
+                [0, 0],
+                8,
+            ),
+            TosaValueError,
+            f"requires bias {torch.float16} to be of the same type as input {torch.float32}",
+        ),
+    ]
+
+    with TosaLoweringContext(tosa_spec), FakeTensorMode() as mode:
+        for sample_input, expected_error, expected_error_str in sample_inputs:
+            with pytest.raises(expected_error, match=expected_error_str):
+                exir_ops.backend.tosa.DEPTHWISE_CONV2D.default(
+                    *tuple(
+                        [
+                            mode.from_tensor(i) if isinstance(i, torch.Tensor) else i
+                            for i in sample_input
+                        ]
+                    )
+                )
diff --git a/backends/arm/tosa/dialect/__init__.py b/backends/arm/tosa/dialect/__init__.py
index 897de70279f..adb5064454b 100644
--- a/backends/arm/tosa/dialect/__init__.py
+++ b/backends/arm/tosa/dialect/__init__.py
@@ -4,6 +4,8 @@
 # LICENSE file in the root directory of this source tree.
 
 from executorch.backends.arm.tosa.dialect.ops import (  # noqa F401
+    conv2d,
+    depthwise_conv2d,
     matmul,
     rescale,
     resize,
diff --git a/backends/arm/tosa/dialect/ops/conv2d.py b/backends/arm/tosa/dialect/ops/conv2d.py
new file mode 100644
index 00000000000..052c1111615
--- /dev/null
+++ b/backends/arm/tosa/dialect/ops/conv2d.py
@@ -0,0 +1,118 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Optional
+
+import torch
+from executorch.backends.arm.tosa.dialect.lib import TosaValueError
+from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op
+
+from executorch.backends.arm.tosa.specification import (
+    get_context_spec,
+    TosaSpecification,
+)
+from executorch.exir.dialects._ops import ops as exir_ops
+
+
+def validate_conv2d_args_dtypes(
+    tosa_spec: TosaSpecification,
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    bias: Optional[torch.Tensor] = None,
+    op: str = "CONV2D",
+) -> torch.dtype:
+    output_dtype = None
+    supported_int_types = (torch.int8, torch.int16)
+    supported_float_types = (
+        torch.float16,
+        torch.float32,
+    )
+    if x.dtype in supported_int_types:
+        if not tosa_spec.support_integer():
+            raise TosaValueError(
+                f"TOSA spec {tosa_spec} doesn't support {x.dtype} but found input type {x.dtype}",
+                op=op,
+            )
+        if weight.dtype not in (torch.int8,):
+            raise TosaValueError(
+                f"TOSA spec {tosa_spec} only supports {torch.int8} weights for {x.dtype} input but found {weight.dtype}",
+                op=op,
+            )
+        if bias is not None and bias.dtype not in (torch.int32,):
+            raise TosaValueError(
+                f"TOSA spec {tosa_spec} only supports {torch.int32} bias for {x.dtype} input but found {bias.dtype}",
+                op=op,
+            )
+        # TODO update to int32 for int8 inputs
+        output_dtype = torch.int8 if x.dtype == torch.int8 else torch.int16
+
+    elif x.dtype in supported_float_types:
+        if not tosa_spec.support_float():
+            raise TosaValueError(
+                f"TOSA spec {tosa_spec} doesn't support {x.dtype} but found input type {x.dtype}",
+                op=op,
+            )
+        if weight.dtype != x.dtype:
+            raise TosaValueError(
+                f"TOSA spec {tosa_spec} requires weights {weight.dtype} to be of the same type as input {x.dtype}",
+                op=op,
+            )
+        if bias is not None and bias.dtype != x.dtype:
+            raise TosaValueError(
+                f"TOSA spec {tosa_spec} requires bias {bias.dtype} to be of the same type as input {x.dtype}",
+                op=op,
+            )
+        output_dtype = x.dtype
+    else:
+        raise TosaValueError(
+            f"Unsupported input dtype {x.dtype}, supported types are {supported_int_types + supported_float_types}",
+            op=op,
+        )
+    return output_dtype
+
+
+@register_fake_tosa_op(
+    "CONV2D(Tensor input, "
+    "Tensor weight, "
+    "Tensor bias, "
+    "int[2] stride, "
+    "int[4] pad, "
+    "int[2] dilation, "
+    "bool transposed, "
+    "int[2] output_padding, "
+    "int groups) -> Tensor",  # schema
+    (
+        TosaSpecification.create_from_string("TOSA-1.0+FP"),
+        TosaSpecification.create_from_string("TOSA-1.0+INT"),
+    ),  # target TOSA specifications
+)
+def CONV2D(
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    stride: list[int],
+    pad: list[int],
+    dilation: list[int],
+    transposed: bool,
+    output_padding: list[int],
+    groups: int,
+) -> torch.Tensor:
+    tosa_spec = get_context_spec()
+
+    output_dtype = validate_conv2d_args_dtypes(tosa_spec, x, weight, bias, op="CONV2D")
+
+    torch_pad = [pad[0], pad[2]]
+    aten_fake_tensor = exir_ops.edge.aten.convolution.default(
+        x,
+        weight,
+        bias,
+        stride,
+        torch_pad,
+        dilation,
+        transposed,
+        output_padding,
+        groups,
+    )
+    return aten_fake_tensor.to(dtype=output_dtype)
diff --git a/backends/arm/tosa/dialect/ops/depthwise_conv2d.py b/backends/arm/tosa/dialect/ops/depthwise_conv2d.py
new file mode 100644
index 00000000000..c234a2e84a8
--- /dev/null
+++ b/backends/arm/tosa/dialect/ops/depthwise_conv2d.py
@@ -0,0 +1,65 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.backends.arm.tosa.dialect.ops.conv2d import validate_conv2d_args_dtypes
+from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op
+
+from executorch.backends.arm.tosa.specification import (
+    get_context_spec,
+    TosaSpecification,
+)
+from executorch.exir.dialects._ops import ops as exir_ops
+
+
+@register_fake_tosa_op(
+    "DEPTHWISE_CONV2D(Tensor input, "
+    "Tensor weight, "
+    "Tensor bias, "
+    "int[2] stride, "
+    "int[4] pad, "
+    "int[2] dilation, "
+    "bool transposed, "
+    "int[2] output_padding, "
+    "int groups) -> Tensor",  # schema
+    (
+        TosaSpecification.create_from_string("TOSA-1.0+FP"),
+        TosaSpecification.create_from_string("TOSA-1.0+INT"),
+    ),  # target TOSA specifications
+)
+def DEPTHWISE_CONV2D(
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    stride: list[int],
+    pad: list[int],
+    dilation: list[int],
+    transposed: bool,
+    output_padding: list[int],
+    groups: int,
+) -> torch.Tensor:
+    tosa_spec = get_context_spec()
+
+    output_dtype = validate_conv2d_args_dtypes(
+        tosa_spec, x, weight, bias, op="DEPTHWISE_CONV2D"
+    )
+
+    torch_pad = [pad[0], pad[2]]
+    # Only shapes matter for the fake tensor, so fold the (H, M, W, C) node-meta
+    # layout back into torch's (O, I/G, H, W) layout for aten shape inference.
+    H, W = weight.shape[0], weight.shape[2]
+    in_channels_group = x.shape[1] // groups
+    out_channels = weight.shape[1] * x.shape[1]
+    torch_weight = weight.reshape(out_channels, in_channels_group, H, W)
+    aten_fake_tensor = exir_ops.edge.aten.convolution.default(
+        x,
+        torch_weight,
+        bias,
+        stride,
+        torch_pad,
+        dilation,
+        transposed,
+        output_padding,
+        groups,
+    )
+    return aten_fake_tensor.to(dtype=output_dtype)
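--
Note (editor's sketch, not part of the patch): the depthwise weight layout
round-trip performed by RewriteConv2dPass._reshape_weights can be
sanity-checked in isolation. This assumes HWCM_ORDER == (2, 3, 0, 1) and the
NHWC permute == (0, 2, 3, 1), matching the constants used above.

    import torch

    in_channels, m_length = 8, 2
    # torch depthwise weight: (out_channels, in_channels/groups, H, W)
    w = torch.randn(in_channels * m_length, 1, 3, 3)
    # HWCM permute, reshape, then inverse-NHWC permute as in _reshape_weights
    hwcm = w.permute(2, 3, 0, 1).reshape(3, 3, in_channels, m_length)
    meta_val = hwcm.permute(0, 3, 1, 2)  # node meta layout (H, M, W, C)
    assert tuple(meta_val.shape) == (3, 2, 3, 8)
    # The NHWC permute applied at serialization yields TOSA's (H, W, C, M)
    assert tuple(meta_val.permute(0, 2, 3, 1).shape) == (3, 3, 8, 2)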