Commit 0a06a5e

Arm backend: Support for constant_pad for TOSA 1.0

Add support for PAD in TOSA 1.0. Also rewrites the old constant_pad tests.

Signed-off-by: Oscar Andersson <[email protected]>
Change-Id: I653e27b68e89e6d30a57e027aa14c1d1732ad272

Parent: 9f3a16d

File tree: 3 files changed (+248, −129 lines)


backends/arm/operators/op_constant_pad_nd.py
Lines changed: 79 additions & 5 deletions
@@ -5,12 +5,10 @@
 
 # pyre-unsafe
 
-from typing import List
+from typing import Any, List
 
 import torch
 
-import tosa_tools.v0_80.serializer.tosa_serializer as ts
-
 from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
     get_input_qparams,
 )
@@ -19,20 +17,27 @@
     register_node_visitor,
 )
 from executorch.backends.arm.tosa_mapping import TosaArg
+from executorch.backends.arm.tosa_specification import TosaSpecification
 
 
 @register_node_visitor
-class ConstantPadNDVisitor(NodeVisitor):
+class ConstantPadNDVisitor_0_80(NodeVisitor):
 
     target = "aten.constant_pad_nd.default"
 
+    tosa_specs = [
+        TosaSpecification.create_from_string("TOSA-0.80+BI"),
+        TosaSpecification.create_from_string("TOSA-0.80+MI"),
+    ]
+
     def define_node(
         self,
         node: torch.fx.Node,
-        tosa_graph: ts.TosaSerializer,
+        tosa_graph: Any,
         inputs: List[TosaArg],
         output: TosaArg,
     ) -> None:
+        import tosa_tools.v0_80.serializer.tosa_serializer as ts
 
         if inputs[0].dtype == ts.DType.INT8:
             input_qparams = get_input_qparams(node)
@@ -74,3 +79,72 @@ def define_node(
         tosa_graph.addOperator(
             ts.TosaOp.Op().PAD, [inputs[0].name], [output.name], attr
         )
+
+
+@register_node_visitor
+class ConstantPadNDVisitor(NodeVisitor):
+
+    target = "aten.constant_pad_nd.default"
+
+    tosa_specs = [
+        TosaSpecification.create_from_string("TOSA-1.0+INT"),
+        TosaSpecification.create_from_string("TOSA-1.0+FP"),
+    ]
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        tosa_graph: Any,
+        inputs: List[TosaArg],
+        output: TosaArg,
+    ) -> None:
+
+        import serializer.tosa_serializer as ts  # type: ignore
+
+        if inputs[0].dtype == ts.DType.INT8:
+            input_qparams = get_input_qparams(node)
+            qargs = input_qparams[0]
+            pad_const_val = qargs.quantize_value(inputs[2].number).item()
+            pad_const_dtype = ts.DType.INT8
+        else:
+            pad_const_val = inputs[2].number
+            pad_const_dtype = inputs[0].dtype
+
+        rank = len(output.shape)
+        # Each dim needs 2 padding values. For example, to pad the last dimension, the pad has the form
+        # (padding_left, padding_right); to pad the last two dimensions, the pad has the form
+        # (padding_left, padding_right, padding_top, padding_bottom), and so on. For PyTorch NCHW format,
+        # the padding values are in the reverse order, so first reverse the input padding parameters.
+        input_pad = sum(
+            [
+                [inputs[1].special[i], inputs[1].special[i + 1]]
+                for i in range(0, len(inputs[1].special), 2)
+            ][::-1],
+            [],
+        )
+        # Then add dummy zeros so that input_pad and output_pad have the same size.
+        input_pad = [0] * (rank * 2 - len(inputs[1].special)) + input_pad
+        # For PyTorch NCHW format, dim order is [0, ..., rank-1].
+        input_dim_order = list(range(rank))
+        output_pad = [0] * rank * 2
+
+        # Map input padding parameters onto output padding parameters. TOSA is NHWC format.
+        for input_dim_idx, input_dim in enumerate(input_dim_order):
+            output_dim_idx = output.dim_order.index(input_dim)
+            output_pad[output_dim_idx * 2 : (output_dim_idx + 1) * 2] = input_pad[
+                input_dim_idx * 2 : (input_dim_idx + 1) * 2
+            ]
+
+        padding = tosa_graph.addConst(
+            shape=[len(output_pad)], dtype=ts.DType.SHAPE, vals=output_pad
+        )
+
+        pad_const = tosa_graph.addConst(
+            shape=[1], dtype=pad_const_dtype, vals=[pad_const_val]
+        )
+
+        tosa_graph.addOperator(
+            ts.TosaOp.Op().PAD,
+            [inputs[0].name, padding.name, pad_const.name],
+            [output.name],
+        )
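The behavioural difference visible above: the TOSA 0.80 visitor passes the padding as an operator attribute, while the TOSA 1.0 visitor feeds PAD two extra graph constants (a SHAPE tensor of pad amounts and a rank-1 pad value). The subtle part is reordering PyTorch's pad list, which starts at the last dimension, into per-dimension (before, after) pairs in the output's dim order. Below is a minimal standalone sketch of that transformation; the `torch_pad_to_tosa_pad` helper name and the NHWC `dim_order` value `(0, 2, 3, 1)` are illustrative assumptions, not part of the commit.

```python
def torch_pad_to_tosa_pad(torch_pad, rank, dim_order):
    """Mirror of the reordering done in ConstantPadNDVisitor.define_node.

    torch_pad: PyTorch constant_pad_nd padding, (before, after) pairs
               starting from the LAST dimension.
    dim_order: permutation giving the output layout, e.g. (0, 2, 3, 1)
               for NHWC (assumed here for illustration).
    """
    # Reverse the (before, after) pairs so dimension 0 comes first.
    pairs = [list(torch_pad[i : i + 2]) for i in range(0, len(torch_pad), 2)][::-1]
    flat = [p for pair in pairs for p in pair]
    # PyTorch omits leading dims that are not padded; zero-fill them.
    flat = [0] * (rank * 2 - len(torch_pad)) + flat
    # Scatter each dimension's pair into its slot in the output dim order.
    out = [0] * (rank * 2)
    for dim in range(rank):
        out_idx = dim_order.index(dim)
        out[out_idx * 2 : out_idx * 2 + 2] = flat[dim * 2 : dim * 2 + 2]
    return out


# NCHW pad (1, 0, 1, 1) pads W by (1, 0) and H by (1, 1); with an NHWC
# dim order the H and W pairs land in slots 1 and 2:
print(torch_pad_to_tosa_pad((1, 0, 1, 1), rank=4, dim_order=(0, 2, 3, 1)))
# [0, 0, 1, 1, 1, 0, 0, 0]  ->  N:(0,0) H:(1,1) W:(1,0) C:(0,0)
```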

backends/arm/test/ops/test_constant_pad_nd.py
Lines changed: 55 additions & 124 deletions
@@ -2,143 +2,74 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-
 #
 # Test the pad_constant_nd op which pads the input tensor at specific dimension(s).
 #
-import unittest
 from typing import Tuple
 
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
 from executorch.backends.arm.test import common
-from executorch.backends.arm.test.tester.arm_tester import ArmTester
-from parameterized import parameterized
-
-test_data_suite = [
-    ("4dim_last1dim", torch.rand(1, 1, 16, 16), (1, 1, 0, 0, 0, 0, 0, 0), 1),
-    ("4dim_last2dim", torch.rand(1, 1, 16, 16), (1, 0, 1, 0, 0, 0, 0, 0), 2),
-    ("4dim_last3dim", torch.rand(1, 1, 16, 16), (1, 1, 0, 2, 0, 2, 0, 0), 3),
-    ("4dim_last4dim", torch.rand(1, 1, 16, 16), (1, 0, 1, 1, 0, 2, 0, 2), 4),
-    ("3dim_last1dim", torch.rand(1, 1, 16), (1, 1, 0, 0, 0, 0), 1),
-    ("3dim_last2dim", torch.rand(1, 1, 16), (1, 0, 1, 1, 0, 0), 2),
-    ("3dim_last3dim", torch.rand(1, 1, 16), (1, 0, 1, 0, 1, 1), 3),
-    ("2dim_last1dim", torch.rand(1, 1, 16), (1, 1, 0, 0), 1),
-    ("2dim_last2dim", torch.rand(1, 1, 16), (1, 0, 1, 1), 2),
-]
-
-
-class TestConstantPadND(unittest.TestCase):
-    """Tests pad."""
-
-    class ConstantPadND(torch.nn.Module):
-        def __init__(self, pad: Tuple, value: float | None = None):
-            super().__init__()
-            self.dim = len(pad) // 2
-            self.value = value
-            in_channels = 1
-            # Only apply conv2d when the input dim = 4.
-            if self.dim == 4:
-                in_channels += pad[-3] + pad[-4]
-
-                self.conv2d = nn.Conv2d(
-                    in_channels=in_channels,
-                    out_channels=3,
-                    kernel_size=3,
-                    bias=True,
-                    stride=(2, 2),
-                    padding=0,
-                )
+from executorch.backends.arm.test.tester.test_pipeline import (
+    TosaPipelineBI,
+    TosaPipelineMI,
+)
 
-                in_channels = 3
-                in_channels += pad[-3] + pad[-4]
-                self.conv2d_1 = nn.Conv2d(
-                    in_channels=in_channels,
-                    out_channels=3,
-                    kernel_size=3,
-                    bias=True,
-                    padding="same",
-                )
+aten_op = "torch.ops.aten.pad.default"
+exir_op = "executorch_exir_dialects_edge__ops_aten_pad_default"
+input_t1 = Tuple[torch.Tensor]  # Input x
+test_data_suite = {
+    "4dim_last1dim": (torch.rand(1, 1, 16, 16), (1, 1, 0, 0, 0, 0, 0, 0), 1),
+    "4dim_last2dim": (torch.rand(1, 1, 16, 16), (1, 0, 1, 0, 0, 0, 0, 0), 2),
+    "4dim_last3dim": (torch.rand(1, 1, 16, 16), (1, 1, 0, 2, 0, 2, 0, 0), 3),
+    "4dim_last4dim": (torch.rand(1, 1, 16, 16), (1, 0, 1, 1, 0, 2, 0, 2), 4),
+    "3dim_last1dim": (torch.rand(1, 1, 16), (1, 1, 0, 0, 0, 0), 1),
+    "3dim_last2dim": (torch.rand(1, 1, 16), (1, 0, 1, 1, 0, 0), 2),
+    "3dim_last3dim": (torch.rand(1, 1, 16), (1, 0, 1, 0, 1, 1), 3),
+    "2dim_last1dim": (torch.rand(1, 1, 16), (1, 1, 0, 0), 1),
+    "2dim_last2dim": (torch.rand(1, 1, 16), (1, 0, 1, 1), 2),
+}
+"""Tests pad."""
 
-            nonzero_idx = len(pad)
-            for i in range(0, len(pad), 2):
-                if pad[i] + pad[i + 1] == 0:
-                    nonzero_idx = i
-                    break
-            self.pad = pad[:nonzero_idx]
-            self.relu = nn.ReLU()
-            self.sigmoid = nn.Sigmoid()
 
-        def forward(self, x: torch.Tensor):
-            x = F.pad(x, pad=self.pad, mode="constant", value=self.value)
-            if self.dim == 4:
-                x = self.conv2d(x)
-            x = self.relu(x)
+class ConstantPadND(torch.nn.Module):
+    def __init__(self, pad: Tuple, value: float | None = None):
+        super().__init__()
+        self.value = value
+        nonzero_idx = len(pad)
+        for i in range(0, len(pad), 2):
+            if pad[i] + pad[i + 1] == 0:
+                nonzero_idx = i
+                break
+        self.pad = pad[:nonzero_idx]
 
-            x = F.pad(x, pad=self.pad, mode="constant", value=self.value)
-            if self.dim == 4:
-                x = self.conv2d_1(x)
-            x = self.sigmoid(x)
-            return x
+    def forward(self, x: torch.Tensor):
+        x = F.pad(x, pad=self.pad, mode="constant", value=self.value)
+        return x
 
-    def _test_constant_pad_nd_tosa_MI_pipeline(
-        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
-    ):
-        (
-            ArmTester(
-                module,
-                example_inputs=test_data,
-                compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"),
-            )
-            .export()
-            .check_count({"torch.ops.aten.pad.default": 2})
-            .to_edge()
-            .partition()
-            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
-            .to_executorch()
-            .run_method_and_compare_outputs(inputs=test_data)
-        )
 
-    def _test_constant_pad_nd_tosa_BI_pipeline(
-        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
-    ):
-        (
-            ArmTester(
-                module,
-                example_inputs=test_data,
-                compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"),
-            )
-            .quantize()
-            .export()
-            .check_count({"torch.ops.aten.pad.default": 2})
-            .to_edge()
-            .partition()
-            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
-            .to_executorch()
-            .run_method_and_compare_outputs(inputs=test_data, qtol=1)
-        )
+@common.parametrize(
+    "test_data",
+    test_data_suite,
+)
+def test_constant_pad_nd_tosa_MI(test_data: Tuple):
+    test_data, padding, value = test_data
+    pipeline = TosaPipelineMI[input_t1](
+        ConstantPadND(padding, value),
+        (test_data,),
+        aten_op,
+        exir_op,
+    )
+    pipeline.run()
 
-    @parameterized.expand(test_data_suite)
-    def test_constant_pad_nd_tosa_MI(
-        self,
-        test_name: str,
-        test_data: torch.Tensor,
-        padding: Tuple,
-        value: float | None = None,
-    ):
-        self._test_constant_pad_nd_tosa_MI_pipeline(
-            self.ConstantPadND(padding, value), (test_data,)
-        )
 
-    @parameterized.expand(test_data_suite)
-    def test_constant_pad_nd_tosa_BI(
-        self,
-        test_name: str,
-        test_data: torch.Tensor,
-        padding: Tuple,
-        value: float | None = None,
-    ):
-        self._test_constant_pad_nd_tosa_BI_pipeline(
-            self.ConstantPadND(padding, value), (test_data,)
-        )
+@common.parametrize("test_data", test_data_suite)
+def test_constant_pad_nd_tosa_BI(test_data: Tuple):
+    test_data, padding, value = test_data
+    pipeline = TosaPipelineBI[input_t1](
+        ConstantPadND(padding, value),
+        (test_data,),
+        aten_op,
+        exir_op,
+    )
+    pipeline.run()
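For reference, here is a short sketch (not part of the commit) of what the rewritten ConstantPadND module computes for one suite entry; it shows how trailing all-zero (before, after) pairs are trimmed before calling F.pad.

```python
import torch
import torch.nn.functional as F

x = torch.rand(1, 1, 16, 16)
pad, value = (1, 0, 1, 0, 0, 0, 0, 0), 2  # the "4dim_last2dim" entry

# Same trimming loop as ConstantPadND.__init__: stop at the first
# (before, after) pair that sums to zero.
nonzero_idx = len(pad)
for i in range(0, len(pad), 2):
    if pad[i] + pad[i + 1] == 0:
        nonzero_idx = i
        break
effective_pad = pad[:nonzero_idx]  # (1, 0, 1, 0)

# Pads W by 1 on the left and H by 1 on top with the constant value 2.
y = F.pad(x, pad=effective_pad, mode="constant", value=value)
print(effective_pad, y.shape)  # (1, 0, 1, 0) torch.Size([1, 1, 17, 17])
```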
