Skip to content

Commit 5dc0b4f

Browse files
GregoryComerfacebook-github-bot
authored andcommitted
Add stride constraint to XNN MaxPool
Summary: Add an XNNPACK partitioner constraint for MaxPool2d to enforce stride <= kernel_size. See https://github.com/google/XNNPACK/blob/860f2b9ad9d3602599aff49a41d0131d2a350e00/src/subgraph/max-pooling-2d.c#L327. Differential Revision: D67380978
1 parent f28e9a5 commit 5dc0b4f

File tree

2 files changed

+32
-2
lines changed

2 files changed

+32
-2
lines changed

backends/xnnpack/partition/config/generic_node_configs.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -282,15 +282,23 @@ class MaxPool2dConfig(GenericNodePartitionerConfig):
282282

283283
def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
284284
"""
285-
XNNPACK's maxpool2d does not support ceil mode
285+
XNNPACK's maxpool2d does not support ceil mode and requires stride <= kernel_size
286286
"""
287287
if not self.check_common_constraints(node, ep):
288288
return False
289-
289+
290+
kernel_size = node.args[1]
291+
stride = node.args[2]
290292
is_ceil_mode = len(node.args) >= 6 and cast(bool, node.args[5])
293+
291294
if is_ceil_mode:
292295
why(node, reason="ceil mode is not supported")
293296
return False
297+
298+
if stride[0] > kernel_size[0] or stride[1] > kernel_size[1]:
299+
why(node, reason="stride must be less than or equal to kernel size")
300+
return False
301+
294302
return True
295303

296304
def supported_precision_types(self) -> List[ConfigPrecisionType]:

backends/xnnpack/test/ops/test_maxpool2d.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,28 @@ def test_fp32_maxpool2d_unsupported_ceilmode(self):
114114
.serialize()
115115
.run_method_and_compare_outputs()
116116
)
117+
118+
def test_fp32_maxpool2d_unsupported_stride(self):
119+
"""
120+
XNNPACK MaxPool2d requires stride <= kernel_size.
121+
"""
122+
inputs = (torch.randn(1, 32, 23, 23),)
123+
(
124+
Tester(self.MaxPool2d(kernel_size=2, stride=3), inputs)
125+
.export()
126+
.check_count({"torch.ops.aten.max_pool2d.default": 1})
127+
.to_edge_transform_and_lower()
128+
# We expect it not be be delegated.
129+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 0})
130+
.check_count(
131+
{
132+
"executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default": 1
133+
}
134+
)
135+
.to_executorch()
136+
.serialize()
137+
.run_method_and_compare_outputs()
138+
)
117139

118140
def test_qs8_maxpool2d(self):
119141
class MaxPool(torch.nn.Module):

0 commit comments

Comments
 (0)