Commit a37308f
Add stride constraint to XNN MaxPool (pytorch#7354)
Summary: Add an XNNPACK partitioner constraint for MaxPool2d to enforce stride <= kernel_size. See https://github.com/google/XNNPACK/blob/860f2b9ad9d3602599aff49a41d0131d2a350e00/src/subgraph/max-pooling-2d.c#L327.

Test Plan: Added an operator-level test (test_fp32_maxpool2d_unsupported_stride) to verify the new constraint.

```
buck test executorch/backends/xnnpack/test:test_xnnpack_ops -- maxpool2d
```

Reviewed By: digantdesai

Differential Revision: D67380978

Pulled By: GregoryComer
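For concreteness, here is a minimal eager-mode sketch of the configuration this constraint rejects (the module and shapes mirror the new test and are illustrative only). The pool is valid in eager PyTorch, but XNNPACK's max-pooling-2d subgraph validation refuses stride > kernel_size, so the partitioner must now leave it to the portable op:

```
import torch

# stride (3) exceeds kernel_size (2): fine in eager mode, but XNNPACK
# rejects this configuration, so it must not be delegated.
pool = torch.nn.MaxPool2d(kernel_size=2, stride=3)
out = pool(torch.randn(1, 32, 23, 23))
print(out.shape)  # torch.Size([1, 32, 8, 8])
```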
Parent: ed15042

3 files changed: +50 −12 lines

backends/xnnpack/partition/config/generic_node_configs.py

Lines changed: 17 additions & 12 deletions
@@ -7,7 +7,7 @@
 # pyre-unsafe
 
 import logging
-from typing import cast, List, Optional
+from typing import cast, List, Optional, Sequence
 
 import torch
 from executorch.backends.xnnpack.partition.config.xnnpack_config import (
@@ -19,7 +19,7 @@
 from executorch.exir.backend.canonical_partitioners.config_partitioner import (
     format_target_name,
 )
-from executorch.exir.backend.utils import WhyNoPartition
+from executorch.exir.backend.utils import is_shape_dynamic, WhyNoPartition
 from torch.export import ExportedProgram
 
 logger = logging.getLogger(__name__)
@@ -284,19 +284,27 @@ class MaxPool2dConfig(GenericNodePartitionerConfig):
 
     def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
         """
-        XNNPACK's maxpool2d does not support ceil mode
+        XNNPACK's maxpool2d does not support ceil mode and requires stride <= kernel_size
         """
         if not self.check_common_constraints(node, ep):
             return False
 
-        # Ceil mode is supported via op padding, which must be statically known.
+        kernel_size: Sequence[int] = node.args[1]
+        stride: Sequence[int] = node.args[2]
         is_ceil_mode = len(node.args) >= 6 and cast(bool, node.args[5])
-        is_dynamic = "val" in node.meta and any(
-            isinstance(d, torch.SymInt) for d in node.meta["val"].shape
-        )
-        if is_ceil_mode and is_dynamic:
+
+        # Ceil mode is supported via op padding, which must be statically known.
+        if is_ceil_mode and is_shape_dynamic(node):
             why(node, reason="ceil mode is not supported for dynamic shapes")
             return False
+
+        if stride[0] > kernel_size[0] or stride[1] > kernel_size[1]:
+            why(
+                node,
+                reason=f"stride ({stride}) must be less than or equal to kernel size ({kernel_size})",
+            )
+            return False
+
         return True
 
     def supported_precision_types(self) -> List[ConfigPrecisionType]:
@@ -316,10 +324,7 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
         if not self.check_common_constraints(node, ep):
             return False
 
-        is_output_dynamic = "val" in node.meta and any(
-            isinstance(d, torch.SymInt) for d in node.meta["val"].shape
-        )
-        if is_output_dynamic:
+        if is_shape_dynamic(node):
             why(node, reason="dynamic output sizes are not supported")
             return False
         return True
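As context for the `node.args[1]` / `node.args[2]` indexing above: when this config runs, the node is an `aten.max_pool2d` call whose positional schema is `(input, kernel_size, stride, padding, dilation, ceil_mode)`, with scalar sizes typically normalized to two-element lists by export. A minimal sketch of inspecting those arguments (module and shapes are illustrative):

```
import torch

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.pool = torch.nn.MaxPool2d(kernel_size=2, stride=3)

    def forward(self, x):
        return self.pool(x)

ep = torch.export.export(Model(), (torch.randn(1, 32, 23, 23),))
for node in ep.graph.nodes:
    if node.op == "call_function" and node.target == torch.ops.aten.max_pool2d.default:
        # Positional schema: (input, kernel_size, stride, padding, dilation, ceil_mode)
        print("kernel_size:", node.args[1], "stride:", node.args[2])
```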

backends/xnnpack/test/ops/test_maxpool2d.py

Lines changed: 22 additions & 0 deletions
@@ -163,6 +163,28 @@ def test_fp32_maxpool2d_unsupported_dynamic_ceilmode(self):
             .run_method_and_compare_outputs()
         )
 
+    def test_fp32_maxpool2d_unsupported_stride(self):
+        """
+        XNNPACK MaxPool2d requires stride <= kernel_size.
+        """
+        inputs = (torch.randn(1, 32, 23, 23),)
+        (
+            Tester(self.MaxPool2d(kernel_size=2, stride=3), inputs)
+            .export()
+            .check_count({"torch.ops.aten.max_pool2d.default": 1})
+            .to_edge_transform_and_lower()
+            # We expect it not to be delegated.
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 0})
+            .check_count(
+                {
+                    "executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default": 1
+                }
+            )
+            .to_executorch()
+            .serialize()
+            .run_method_and_compare_outputs()
+        )
+
     def test_qs8_maxpool2d(self):
         class MaxPool(torch.nn.Module):
             def __init__(self, maxpool_params):

exir/backend/utils.py

Lines changed: 11 additions & 0 deletions
@@ -417,6 +417,17 @@ def tag_mutated_buffer(edge_program: ExportedProgram) -> None:
                 node.meta["delegation_tag"] = user_tags.pop()
 
 
+def is_shape_dynamic(node: torch.fx.Node) -> bool:
+    """
+    Check if the node shape is dynamic.
+    """
+
+    # Shape is dynamic if any of the dimensions don't evaluate to a static value
+    return "val" in node.meta and any(
+        isinstance(d, torch.SymInt) for d in node.meta["val"].shape
+    )
+
+
 # TODO - style: use templated types
 class DelegateMappingBuilder:
     """
