Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 16 additions & 11 deletions backends/xnnpack/partition/config/generic_node_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from executorch.exir.backend.canonical_partitioners.config_partitioner import (
format_target_name,
)
from executorch.exir.backend.utils import WhyNoPartition
from executorch.exir.backend.utils import is_shape_dynamic, WhyNoPartition
from torch.export import ExportedProgram

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -284,19 +284,27 @@ class MaxPool2dConfig(GenericNodePartitionerConfig):

def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
"""
XNNPACK's maxpool2d does not support ceil mode
XNNPACK's maxpool2d does not support ceil mode and requires stride <= kernel_size
"""
if not self.check_common_constraints(node, ep):
return False

# Ceil mode is supported via op padding, which must be statically known.
kernel_size = node.args[1]
stride = node.args[2]
is_ceil_mode = len(node.args) >= 6 and cast(bool, node.args[5])
is_dynamic = "val" in node.meta and any(
isinstance(d, torch.SymInt) for d in node.meta["val"].shape
)
if is_ceil_mode and is_dynamic:

# Ceil mode is supported via op padding, which must be statically known.
if is_ceil_mode and is_shape_dynamic(node):
why(node, reason="ceil mode is not supported for dynamic shapes")
return False

if stride[0] > kernel_size[0] or stride[1] > kernel_size[1]: # pyre-ignore[16]
why(
node,
reason=f"stride ({stride}) must be less than or equal to kernel size ({kernel_size})",
)
return False

return True

def supported_precision_types(self) -> List[ConfigPrecisionType]:
Expand All @@ -316,10 +324,7 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
if not self.check_common_constraints(node, ep):
return False

is_output_dynamic = "val" in node.meta and any(
isinstance(d, torch.SymInt) for d in node.meta["val"].shape
)
if is_output_dynamic:
if is_shape_dynamic(node):
why(node, reason="dynamic output sizes are not supported")
return False
return True
Expand Down
22 changes: 22 additions & 0 deletions backends/xnnpack/test/ops/test_maxpool2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,28 @@ def test_fp32_maxpool2d_unsupported_dynamic_ceilmode(self):
.run_method_and_compare_outputs()
)

def test_fp32_maxpool2d_unsupported_stride(self):
"""
XNNPACK MaxPool2d requires stride <= kernel_size.
"""
inputs = (torch.randn(1, 32, 23, 23),)
(
Tester(self.MaxPool2d(kernel_size=2, stride=3), inputs)
.export()
.check_count({"torch.ops.aten.max_pool2d.default": 1})
.to_edge_transform_and_lower()
# We expect it not be be delegated.
.check_count({"torch.ops.higher_order.executorch_call_delegate": 0})
.check_count(
{
"executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default": 1
}
)
.to_executorch()
.serialize()
.run_method_and_compare_outputs()
)

def test_qs8_maxpool2d(self):
class MaxPool(torch.nn.Module):
def __init__(self, maxpool_params):
Expand Down
13 changes: 13 additions & 0 deletions exir/backend/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import logging
import operator
from collections import defaultdict
Expand Down Expand Up @@ -417,6 +419,17 @@ def tag_mutated_buffer(edge_program: ExportedProgram) -> None:
node.meta["delegation_tag"] = user_tags.pop()


def is_shape_dynamic(node: torch.fx.Node) -> bool:
"""
Check if the node shape is dynamic.
"""

# Shape is dynamic if any of the dimensions don't evaluate to a static value
return "val" in node.meta and any(
isinstance(d, torch.SymInt) for d in node.meta["val"].shape
)


# TODO - style: use templated types
class DelegateMappingBuilder:
"""
Expand Down