@@ -138,6 +138,7 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool:
operator.getitem,
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
exir_ops.edge.aten.constant_pad_nd.default,
]

return supported
1 change: 1 addition & 0 deletions backends/arm/operators/__init__.py
@@ -12,6 +12,7 @@
op_bmm,
op_cat,
op_clamp,
op_constant_pad_nd,
op_conv2d,
op_eq,
op_exp,
74 changes: 74 additions & 0 deletions backends/arm/operators/op_constant_pad_nd.py
@@ -0,0 +1,74 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import List

import serializer.tosa_serializer as ts
import torch

from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
get_input_qparams,
)
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from serializer.tosa_serializer import TosaOp


@register_node_visitor
class ConstantPadNDVisitor(NodeVisitor):

target = "aten.constant_pad_nd.default"

def define_node(
self,
node: torch.fx.Node,
tosa_graph: ts.TosaSerializer,
inputs: List[TosaArg],
output: TosaArg,
) -> None:

if inputs[0].dtype == ts.DType.INT8:
input_qparams = get_input_qparams(node)
qargs = input_qparams[0]
pad_const_qs = qargs.quantize_value(inputs[2].number).item()
pad_const_fp = 0.0
else:
pad_const_fp = inputs[2].number
pad_const_qs = 0

rank = len(output.shape)
# Each dim needs 2 padding values. For example, padding only the last dimension takes the form
# (padding_left, padding_right); padding the last two dimensions takes the form
# (padding_left, padding_right, padding_top, padding_bottom), and so on. PyTorch lists these pairs
# starting from the last dimension, so we first reverse the input padding parameters.
input_pad = sum(
[
[inputs[1].special[i], inputs[1].special[i + 1]]
for i in range(0, len(inputs[1].special), 2)
][::-1],
[],
)
# Then, prepend dummy zeros so that input_pad and output_pad have the same size.
input_pad = [0] * (rank * 2 - len(inputs[1].special)) + input_pad
# For PyTorch NCHW format, dim order is [0,...,rank-1]
input_dim_order = list(range(rank))
output_pad = [0] * rank * 2

# Map input padding parameters into output padding parameters. TOSA is NHWC format.
for input_dim_idx, input_dim in enumerate(input_dim_order):
output_dim_idx = output.dim_order.index(input_dim)
output_pad[output_dim_idx * 2 : (output_dim_idx + 1) * 2] = input_pad[
input_dim_idx * 2 : (input_dim_idx + 1) * 2
]

attr = ts.TosaSerializerAttribute()
attr.PadAttribute(tosa_graph.builder, output_pad, pad_const_qs, pad_const_fp)

tosa_graph.addOperator(TosaOp.Op().PAD, [inputs[0].name], [output.name], attr)
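For reference, a minimal standalone sketch (not part of the diff) of the pad-reordering done in define_node above: PyTorch's constant pad lists (before, after) pairs starting from the last dimension, while the TOSA PAD attribute expects one pair per dimension in the output tensor's dim order (e.g. NHWC). The helper name reorder_pad and the example values below are illustrative assumptions, not code from this PR.

def reorder_pad(pad, rank, dim_order):
    # PyTorch gives (before, after) pairs starting from the last dimension,
    # so reverse the pairs to get them in dim 0..rank-1 order.
    pairs = [list(pad[i : i + 2]) for i in range(0, len(pad), 2)][::-1]
    # Dimensions not covered by `pad` get zero padding on both sides.
    pairs = [[0, 0]] * (rank - len(pairs)) + pairs
    # Place each pair at that dimension's position in dim_order (e.g. NHWC).
    out = [0] * (rank * 2)
    for dim, (before, after) in enumerate(pairs):
        idx = list(dim_order).index(dim)
        out[idx * 2 : idx * 2 + 2] = [before, after]
    return out

# NCHW pad (1, 2, 3, 4) pads W by (1, 2) and H by (3, 4); with NHWC dim order
# (0, 2, 3, 1) this yields [0, 0, 3, 4, 1, 2, 0, 0].
assert reorder_pad((1, 2, 3, 4), rank=4, dim_order=(0, 2, 3, 1)) == [0, 0, 3, 4, 1, 2, 0, 0]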
4 changes: 4 additions & 0 deletions backends/arm/quantizer/quantization_annotator.py
@@ -172,6 +172,7 @@ def _match_pattern(
torch.ops.aten.chunk.default,
torch.ops.aten.contiguous.default,
torch.ops.aten.upsample_nearest2d.vec,
torch.ops.aten.pad.default,
]

# Operators that can inherit the quantization specs from its parent node
@@ -216,6 +217,7 @@ def any_or_hardtanh_min_zero(n: Node):
torch.ops.aten.conv1d.default,
torch.ops.aten.conv2d.default,
torch.ops.aten.linear.default,
torch.ops.aten.conv2d.padding,
],
[torch.ops.aten.relu.default, torch.ops.aten.hardtanh.default],
],
@@ -225,6 +227,7 @@ def any_or_hardtanh_min_zero(n: Node):
torch.ops.aten.conv1d.default,
torch.ops.aten.conv2d.default,
torch.ops.aten.linear.default,
torch.ops.aten.conv2d.padding,
):
quant_properties.quant_inputs = [
_QuantProperty(0, input_act_qspec),
@@ -237,6 +240,7 @@ def any_or_hardtanh_min_zero(n: Node):
torch.ops.aten.conv1d.default,
torch.ops.aten.conv2d.default,
torch.ops.aten.linear.default,
torch.ops.aten.conv2d.padding,
):
quant_properties.quant_inputs = [
_QuantProperty(0, input_act_qspec),
1 change: 1 addition & 0 deletions backends/arm/quantizer/quantization_config.py
@@ -78,6 +78,7 @@ def _derive_qparams_fn(
torch.ops.aten.conv1d.default,
torch.ops.aten.conv2d.default,
torch.ops.aten.linear.default,
torch.ops.aten.conv2d.padding,
]:
input_act = node.args[0]
weight = node.args[1]
144 changes: 144 additions & 0 deletions backends/arm/test/ops/test_constant_pad_nd.py
@@ -0,0 +1,144 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

#
# Test the constant_pad_nd op which pads the input tensor at specific dimension(s).
#
import unittest
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from parameterized import parameterized

test_data_suite = [
("4dim_last1dim", torch.rand(1, 1, 16, 16), (1, 1, 0, 0, 0, 0, 0, 0), 1),
("4dim_last2dim", torch.rand(1, 1, 16, 16), (1, 0, 1, 0, 0, 0, 0, 0), 2),
("4dim_last3dim", torch.rand(1, 1, 16, 16), (1, 1, 0, 2, 0, 2, 0, 0), 3),
("4dim_last4dim", torch.rand(1, 1, 16, 16), (1, 0, 1, 1, 0, 2, 0, 2), 4),
("3dim_last1dim", torch.rand(1, 1, 16), (1, 1, 0, 0, 0, 0), 1),
("3dim_last2dim", torch.rand(1, 1, 16), (1, 0, 1, 1, 0, 0), 2),
("3dim_last3dim", torch.rand(1, 1, 16), (1, 0, 1, 0, 1, 1), 3),
("2dim_last1dim", torch.rand(1, 1, 16), (1, 1, 0, 0), 1),
("2dim_last2dim", torch.rand(1, 1, 16), (1, 0, 1, 1), 2),
]


class TestConstantPadND(unittest.TestCase):
"""Tests pad."""

class ConstantPadND(torch.nn.Module):
def __init__(self, pad: Tuple, value: float | None = None):
super().__init__()
self.dim = len(pad) // 2
self.value = value
in_channels = 1
# Only apply conv2d when the input dim = 4.
if self.dim == 4:
in_channels += pad[-3] + pad[-4]

self.conv2d = nn.Conv2d(
in_channels=in_channels,
out_channels=3,
kernel_size=3,
bias=True,
stride=(2, 2),
padding=0,
)

in_channels = 3
in_channels += pad[-3] + pad[-4]
self.conv2d_1 = nn.Conv2d(
in_channels=in_channels,
out_channels=3,
kernel_size=3,
bias=True,
padding="same",
)

nonzero_idx = len(pad)
for i in range(0, len(pad), 2):
if pad[i] + pad[i + 1] == 0:
nonzero_idx = i
break
self.pad = pad[:nonzero_idx]
self.relu = nn.ReLU()
self.sigmoid = nn.Sigmoid()

def forward(self, x: torch.Tensor):
x = F.pad(x, pad=self.pad, mode="constant", value=self.value)
if self.dim == 4:
x = self.conv2d(x)
x = self.relu(x)

x = F.pad(x, pad=self.pad, mode="constant", value=self.value)
if self.dim == 4:
x = self.conv2d_1(x)
x = self.sigmoid(x)
return x

def _test_constant_pad_nd_tosa_MI_pipeline(
self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
):
(
ArmTester(
module,
example_inputs=test_data,
compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"),
)
.export()
.check_count({"torch.ops.aten.pad.default": 2})
.to_edge()
.partition()
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.to_executorch()
.run_method_and_compare_outputs(inputs=test_data)
)

def _test_constant_pad_nd_tosa_BI_pipeline(
self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
):
(
ArmTester(
module,
example_inputs=test_data,
compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"),
)
.quantize()
.export()
.check_count({"torch.ops.aten.pad.default": 2})
.to_edge()
.partition()
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.to_executorch()
.run_method_and_compare_outputs(inputs=test_data, qtol=1)
)

@parameterized.expand(test_data_suite)
def test_constant_pad_nd_tosa_MI(
self,
test_name: str,
test_data: torch.Tensor,
padding: Tuple,
value: float | None = None,
):
self._test_constant_pad_nd_tosa_MI_pipeline(
self.ConstantPadND(padding, value), (test_data,)
)

@parameterized.expand(test_data_suite)
def test_constant_pad_nd_tosa_BI(
self,
test_name: str,
test_data: torch.Tensor,
padding: Tuple,
value: float | None = None,
):
self._test_constant_pad_nd_tosa_BI_pipeline(
self.ConstantPadND(padding, value), (test_data,)
)