diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index dd092968764..557ace5668d 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -138,6 +138,7 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool: operator.getitem, exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + exir_ops.edge.aten.constant_pad_nd.default, ] return supported diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index f57ba092bc4..f6d5c27ee41 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -12,6 +12,7 @@ op_bmm, op_cat, op_clamp, + op_constant_pad_nd, op_conv2d, op_eq, op_exp, diff --git a/backends/arm/operators/op_constant_pad_nd.py b/backends/arm/operators/op_constant_pad_nd.py new file mode 100644 index 00000000000..73f6d2751c5 --- /dev/null +++ b/backends/arm/operators/op_constant_pad_nd.py @@ -0,0 +1,74 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +from typing import List + +import serializer.tosa_serializer as ts +import torch + +from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + get_input_qparams, +) +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.tosa_mapping import TosaArg +from serializer.tosa_serializer import TosaOp + + +@register_node_visitor +class ConstantPadNDVisitor(NodeVisitor): + + target = "aten.constant_pad_nd.default" + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: ts.TosaSerializer, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + + if inputs[0].dtype == ts.DType.INT8: + input_qparams = get_input_qparams(node) + qargs = input_qparams[0] + pad_const_qs = qargs.quantize_value(inputs[2].number).item() + pad_const_fp = 0.0 + else: + pad_const_fp = inputs[2].number + pad_const_qs = 0 + + rank = len(output.shape) + # Each dim needs 2 padding values. For example, to pad the last dimension, the pad has the form + # (padding_left, padding_right); to pad the last two dimensions, the pad has the form + # (padding_left, padding_right, padding_top, padding_bottom), and so on. For PyTorch NCHW format, the padding + # values are in the reverse order. So, firstly we need to reverse the input padding parameters. + input_pad = sum( + [ + [inputs[1].special[i], inputs[1].special[i + 1]] + for i in range(0, len(inputs[1].special), 2) + ][::-1], + [], + ) + # Then, add dummy zeros to make sure that both input_pad and output_pad has the same size. + input_pad = [0] * (rank * 2 - len(inputs[1].special)) + input_pad + # For PyTorch NCHW format, dim order is [0,...,rank-1] + input_dim_order = list(range(rank)) + output_pad = [0] * rank * 2 + + # Map input padding parameters into output padding parameters. TOSA is NHWC format. + for input_dim_idx, input_dim in enumerate(input_dim_order): + output_dim_idx = output.dim_order.index(input_dim) + output_pad[output_dim_idx * 2 : (output_dim_idx + 1) * 2] = input_pad[ + input_dim_idx * 2 : (input_dim_idx + 1) * 2 + ] + + attr = ts.TosaSerializerAttribute() + attr.PadAttribute(tosa_graph.builder, output_pad, pad_const_qs, pad_const_fp) + + tosa_graph.addOperator(TosaOp.Op().PAD, [inputs[0].name], [output.name], attr) diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py index f6f6221510f..f1cef971782 100644 --- a/backends/arm/quantizer/quantization_annotator.py +++ b/backends/arm/quantizer/quantization_annotator.py @@ -172,6 +172,7 @@ def _match_pattern( torch.ops.aten.chunk.default, torch.ops.aten.contiguous.default, torch.ops.aten.upsample_nearest2d.vec, + torch.ops.aten.pad.default, ] # Operators that can inherit the quantization specs from its parent node @@ -216,6 +217,7 @@ def any_or_hardtanh_min_zero(n: Node): torch.ops.aten.conv1d.default, torch.ops.aten.conv2d.default, torch.ops.aten.linear.default, + torch.ops.aten.conv2d.padding, ], [torch.ops.aten.relu.default, torch.ops.aten.hardtanh.default], ], @@ -225,6 +227,7 @@ def any_or_hardtanh_min_zero(n: Node): torch.ops.aten.conv1d.default, torch.ops.aten.conv2d.default, torch.ops.aten.linear.default, + torch.ops.aten.conv2d.padding, ): quant_properties.quant_inputs = [ _QuantProperty(0, input_act_qspec), @@ -237,6 +240,7 @@ def any_or_hardtanh_min_zero(n: Node): torch.ops.aten.conv1d.default, torch.ops.aten.conv2d.default, torch.ops.aten.linear.default, + torch.ops.aten.conv2d.padding, ): quant_properties.quant_inputs = [ _QuantProperty(0, input_act_qspec), diff --git a/backends/arm/quantizer/quantization_config.py b/backends/arm/quantizer/quantization_config.py index 394995201e4..65435ac7c63 100644 --- a/backends/arm/quantizer/quantization_config.py +++ b/backends/arm/quantizer/quantization_config.py @@ -78,6 +78,7 @@ def _derive_qparams_fn( torch.ops.aten.conv1d.default, torch.ops.aten.conv2d.default, torch.ops.aten.linear.default, + torch.ops.aten.conv2d.padding, ]: input_act = node.args[0] weight = node.args[1] diff --git a/backends/arm/test/ops/test_constant_pad_nd.py b/backends/arm/test/ops/test_constant_pad_nd.py new file mode 100644 index 00000000000..3dfd28640eb --- /dev/null +++ b/backends/arm/test/ops/test_constant_pad_nd.py @@ -0,0 +1,144 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# +# Test the pad_constant_nd op which pads the input tensor at specific dimension(s). +# +import unittest +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.arm_tester import ArmTester +from parameterized import parameterized + +test_data_suite = [ + ("4dim_last1dim", torch.rand(1, 1, 16, 16), (1, 1, 0, 0, 0, 0, 0, 0), 1), + ("4dim_last2dim", torch.rand(1, 1, 16, 16), (1, 0, 1, 0, 0, 0, 0, 0), 2), + ("4dim_last3dim", torch.rand(1, 1, 16, 16), (1, 1, 0, 2, 0, 2, 0, 0), 3), + ("4dim_last4dim", torch.rand(1, 1, 16, 16), (1, 0, 1, 1, 0, 2, 0, 2), 4), + ("3dim_last1dim", torch.rand(1, 1, 16), (1, 1, 0, 0, 0, 0), 1), + ("3dim_last2dim", torch.rand(1, 1, 16), (1, 0, 1, 1, 0, 0), 2), + ("3dim_last3dim", torch.rand(1, 1, 16), (1, 0, 1, 0, 1, 1), 3), + ("2dim_last1dim", torch.rand(1, 1, 16), (1, 1, 0, 0), 1), + ("2dim_last2dim", torch.rand(1, 1, 16), (1, 0, 1, 1), 2), +] + + +class TestConstantPadND(unittest.TestCase): + """Tests pad.""" + + class ConstantPadND(torch.nn.Module): + def __init__(self, pad: Tuple, value: float | None = None): + super().__init__() + self.dim = len(pad) // 2 + self.value = value + in_channels = 1 + # Only apply conv2d when the input dim = 4. + if self.dim == 4: + in_channels += pad[-3] + pad[-4] + + self.conv2d = nn.Conv2d( + in_channels=in_channels, + out_channels=3, + kernel_size=3, + bias=True, + stride=(2, 2), + padding=0, + ) + + in_channels = 3 + in_channels += pad[-3] + pad[-4] + self.conv2d_1 = nn.Conv2d( + in_channels=in_channels, + out_channels=3, + kernel_size=3, + bias=True, + padding="same", + ) + + nonzero_idx = len(pad) + for i in range(0, len(pad), 2): + if pad[i] + pad[i + 1] == 0: + nonzero_idx = i + break + self.pad = pad[:nonzero_idx] + self.relu = nn.ReLU() + self.sigmoid = nn.Sigmoid() + + def forward(self, x: torch.Tensor): + x = F.pad(x, pad=self.pad, mode="constant", value=self.value) + if self.dim == 4: + x = self.conv2d(x) + x = self.relu(x) + + x = F.pad(x, pad=self.pad, mode="constant", value=self.value) + if self.dim == 4: + x = self.conv2d_1(x) + x = self.sigmoid(x) + return x + + def _test_constant_pad_nd_tosa_MI_pipeline( + self, module: torch.nn.Module, test_data: Tuple[torch.Tensor] + ): + ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), + ) + .export() + .check_count({"torch.ops.aten.pad.default": 2}) + .to_edge() + .partition() + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_compare_outputs(inputs=test_data) + ) + + def _test_constant_pad_nd_tosa_BI_pipeline( + self, module: torch.nn.Module, test_data: Tuple[torch.Tensor] + ): + ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), + ) + .quantize() + .export() + .check_count({"torch.ops.aten.pad.default": 2}) + .to_edge() + .partition() + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_compare_outputs(inputs=test_data, qtol=1) + ) + + @parameterized.expand(test_data_suite) + def test_constant_pad_nd_tosa_MI( + self, + test_name: str, + test_data: torch.Tensor, + padding: Tuple, + value: float | None = None, + ): + self._test_constant_pad_nd_tosa_MI_pipeline( + self.ConstantPadND(padding, value), (test_data,) + ) + + @parameterized.expand(test_data_suite) + def test_constant_pad_nd_tosa_BI( + self, + test_name: str, + test_data: torch.Tensor, + padding: Tuple, + value: float | None = None, + ): + self._test_constant_pad_nd_tosa_BI_pipeline( + self.ConstantPadND(padding, value), (test_data,) + )