Skip to content

Commit f071ea5

Browse files
authored
Fix the stride issue in DecomposeAvgPool2d and add a test for it. (#13152)
1 parent 016eece commit f071ea5

File tree

2 files changed

+87
-2
lines changed

2 files changed

+87
-2
lines changed

backends/arm/_passes/decompose_avg_pool2d.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,10 @@ def call_operator(self, op, args, kwargs, meta):
4545
x = args[0]
4646
kernel_h, kernel_w = args[1]
4747
kernel_size = kernel_h * kernel_w
48-
stride_h, stride_w = args[2]
48+
if len(args) > 2 and args[2] is not None:
49+
stride_h, stride_w = args[2]
50+
else:
51+
stride_h, stride_w = kernel_h, kernel_w
4952
pad_h, pad_w = new_pad_h, new_pad_w = args[3] if len(args) > 3 else (0, 0)
5053
ceil_mode = args[4] if len(args) > 4 else False
5154
count_include_pad = args[5] if len(args) > 5 else True
@@ -108,7 +111,14 @@ def call_operator(self, op, args, kwargs, meta):
108111
x = super().call_operator(cat_op, (cat_nodes, 2), kwargs, meta)
109112
new_pad_h = 0
110113

111-
avgpool_args = (x, args[1], args[2], [new_pad_h, new_pad_w], ceil_mode, False)
114+
avgpool_args = (
115+
x,
116+
args[1],
117+
[stride_h, stride_w],
118+
[new_pad_h, new_pad_w],
119+
ceil_mode,
120+
False,
121+
)
112122
x = super().call_operator(avgpool_op, avgpool_args, kwargs, meta)
113123

114124
# Multiply by factor (kernel_size / divisor_override) if divisor_override
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Tuple
7+
8+
import torch
9+
from executorch.backends.arm._passes.decompose_avg_pool2d import DecomposeAvgPool2d
10+
from executorch.backends.arm.test import common
11+
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
12+
13+
input_t = Tuple[torch.Tensor] # Input x
14+
15+
16+
class AvgPool2dWithStride(torch.nn.Module):
17+
"""
18+
avg_pool2d model with explicit stride parameter
19+
"""
20+
21+
def get_inputs(self) -> input_t:
22+
return (torch.rand(1, 3, 8, 8),)
23+
24+
def forward(self, x):
25+
return torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
26+
27+
28+
class AvgPool2dWithoutStride(torch.nn.Module):
29+
"""
30+
avg_pool2d model without stride parameter (should default to kernel_size)
31+
"""
32+
33+
def get_inputs(self) -> input_t:
34+
return (torch.rand(1, 3, 8, 8),)
35+
36+
def forward(self, x):
37+
return torch.nn.functional.avg_pool2d(x, kernel_size=3)
38+
39+
40+
class AvgPool2dListKernel(torch.nn.Module):
41+
"""
42+
avg_pool2d model with list kernel_size and no stride
43+
"""
44+
45+
def get_inputs(self) -> input_t:
46+
return (torch.rand(1, 3, 8, 8),)
47+
48+
def forward(self, x):
49+
return torch.nn.functional.avg_pool2d(x, kernel_size=[2, 3])
50+
51+
52+
modules = {
53+
"avg_pool2d_with_stride": AvgPool2dWithStride(),
54+
"avg_pool2d_without_stride": AvgPool2dWithoutStride(),
55+
"avg_pool2d_list_kernel": AvgPool2dListKernel(),
56+
}
57+
58+
59+
@common.parametrize("module", modules)
60+
def test_decompose_avg_pool2d_tosa_MI(module):
61+
"""Test that DecomposeAvgPool2d pass works correctly with and without stride parameters."""
62+
pipeline = PassPipeline[input_t](
63+
module,
64+
module.get_inputs(),
65+
quantize=False,
66+
ops_before_pass={
67+
"executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 1,
68+
},
69+
ops_after_pass={
70+
# After decomposition, we should still see avg_pool2d (transformed)
71+
"executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 1,
72+
},
73+
pass_list=[DecomposeAvgPool2d],
74+
)
75+
pipeline.run()

0 commit comments

Comments
 (0)