Skip to content

Commit 8ee9f91

Browse files
authored
Arm backend: Adjust MaxPool2d padding when window is not divisible by stride (#10751)
* MaxPool2dVisitor will adjust padding if the pooling window is not divisible by the stride Signed-off-by: Tom Allsop <[email protected]>
1 parent e5c38e9 commit 8ee9f91

File tree

2 files changed

+47
-0
lines changed

2 files changed

+47
-0
lines changed

backends/arm/operators/op_max_pool2d.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,24 @@
2323
from executorch.backends.arm.tosa_specification import TosaSpecification
2424

2525

26+
# Similarly to Conv2d, the TOSA spec requires that following is exactly divisible:
27+
# `(input + 2 * pad - kernel_size) / stride`
28+
# PyTorch however, does not require this, so as needed, we must adjust the padding.
29+
def adjust_pad_if_needed(
30+
input_size: int, kernel_size: int, stride: int, pad: int
31+
) -> int:
32+
if pad == 0:
33+
return pad
34+
35+
mod_remainder = (input_size + 2 * pad - kernel_size) % stride
36+
37+
# No need to adjust
38+
if mod_remainder == 0:
39+
return pad
40+
41+
return pad - mod_remainder
42+
43+
2644
@register_node_visitor
2745
class MaxPool2dVisitor_0_80(NodeVisitor):
2846
target = "aten.max_pool2d.default"
@@ -61,6 +79,20 @@ def define_node(
6179
except IndexError:
6280
pad_size_list = [0, 0, 0, 0]
6381

82+
# Adjust the padding as necessary
83+
pad_size_list[1] = adjust_pad_if_needed(
84+
input_tensor.shape[2],
85+
kernel_size[0],
86+
stride[0],
87+
pad_size_list[1],
88+
)
89+
pad_size_list[3] = adjust_pad_if_needed(
90+
input_tensor.shape[3],
91+
kernel_size[1],
92+
stride[1],
93+
pad_size_list[3],
94+
)
95+
6496
accumulator_type = output.dtype
6597

6698
# Initilize zero point to zero.
@@ -131,6 +163,20 @@ def define_node(
131163
except IndexError:
132164
pad_size_list = [0, 0, 0, 0]
133165

166+
# Adjust the padding as necessary
167+
pad_size_list[1] = adjust_pad_if_needed(
168+
input_tensor.shape[2],
169+
kernel_size[0],
170+
stride[0],
171+
pad_size_list[1],
172+
)
173+
pad_size_list[3] = adjust_pad_if_needed(
174+
input_tensor.shape[3],
175+
kernel_size[1],
176+
stride[1],
177+
pad_size_list[3],
178+
)
179+
134180
attr = ts.TosaSerializerAttribute()
135181
attr.MaxPool2dAttribute(
136182
kernel=kernel_size, stride=stride, pad=pad_size_list, nan_mode=1

backends/arm/test/ops/test_max_pool.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
("zeros", torch.zeros(1, 1, 4, 8), [2, 2, 1]),
3232
("ones", torch.ones(1, 16, 50, 32), [4, 2, 0]),
3333
("rand", torch.rand(1, 16, 52, 16), [4, 3, 0]),
34+
("non_divisible", torch.rand(1, 16, 112, 112), [3, 2, 1]),
3435
]
3536

3637
test_data_suite_mult_batches = [

0 commit comments

Comments
 (0)