diff --git a/backends/xnnpack/partition/config/generic_node_configs.py b/backends/xnnpack/partition/config/generic_node_configs.py
index 3faf1b12066..559d1522275 100644
--- a/backends/xnnpack/partition/config/generic_node_configs.py
+++ b/backends/xnnpack/partition/config/generic_node_configs.py
@@ -589,6 +589,21 @@ def supported_precision_types(self) -> List[ConfigPrecisionType]:
 class ConstantPadConfig(GenericNodePartitionerConfig):
     target_name = "constant_pad_nd.default"
 
+    def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
+        """
+        XNNPACK does not support cropping with negative padding sizes.
+        """
+        if not self.check_common_constraints(node, ep):
+            return False
+
+        # Check for negative padding values
+        padding = cast(List[int], node.args[1])
+        if any(p < 0 for p in padding):
+            why(node, reason="XNNPACK does not support negative padding values")
+            return False
+
+        return True
+
     def supported_precision_types(self) -> List[ConfigPrecisionType]:
         return [ConfigPrecisionType.FP32]
 
diff --git a/backends/xnnpack/test/ops/test_static_constant_pad.py b/backends/xnnpack/test/ops/test_static_constant_pad.py
index 9613308f6a6..53224071edd 100644
--- a/backends/xnnpack/test/ops/test_static_constant_pad.py
+++ b/backends/xnnpack/test/ops/test_static_constant_pad.py
@@ -7,7 +7,10 @@
 import unittest
 
 import torch
+from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
 from executorch.backends.xnnpack.test.tester import Tester
+from executorch.exir import to_edge_transform_and_lower
+from torch.export import export
 
 
 class TestStaticConstantPad(unittest.TestCase):
@@ -125,6 +128,45 @@ def _test_static_constant_pad_functional(self, inputs):
             .run_method_and_compare_outputs()
         )
 
+    class NegativePadModel(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.pad = torch.nn.ConstantPad2d((0, 0, -2, 2), 0.0)
+
+        def forward(self, input):
+            input = self.pad(input)
+            return input
+
+    def test_negative_pad_model_with_ints(self):
+        """Test that negative padding with integer inputs falls back to the PyTorch implementation, as XNNPACK does not support negative padding dimensions"""
+        input_tensor = torch.tensor([[4], [5], [6]])
+        model = self.NegativePadModel()
+        model.eval()
+        model.to("cpu")
+
+        exported_model = export(model, (input_tensor,))
+
+        executorch_program = to_edge_transform_and_lower(
+            exported_model, partitioner=[XnnpackPartitioner()]
+        ).to_executorch()
+
+        self.assertIsNotNone(executorch_program)
+
+    def test_negative_pad_model_with_floats(self):
+        """Test that negative padding with float inputs is now rejected by the XNNPACK partitioner, as XNNPACK does not support negative padding dimensions"""
+        input_tensor = torch.tensor([[4.0], [5.0], [6.0]])
+        model = self.NegativePadModel()
+        model.eval()
+        model.to("cpu")
+
+        exported_model = export(model, (input_tensor,))
+
+        executorch_program = to_edge_transform_and_lower(
+            exported_model, partitioner=[XnnpackPartitioner()]
+        ).to_executorch()
+
+        self.assertIsNotNone(executorch_program)
+
     def test_fp16_static_constant_pad_functional(self):
         inputs = (
             torch.randn(size=(5, 4, 3, 2)).to(torch.float16),
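
For context, a minimal standalone sketch of the rule the new check_constraints enforces, assuming (as the patch does) that constant_pad_nd.default carries its padding sizes as a flat list of ints in node.args[1]. The helper name has_negative_padding and the sample values are illustrative only and not part of the change:

from typing import Sequence


def has_negative_padding(padding: Sequence[int]) -> bool:
    # Same rule the partitioner applies: any negative pad size means the op
    # would crop the tensor, which XNNPACK's static constant pad cannot express.
    return any(p < 0 for p in padding)


# torch.nn.ConstantPad2d((0, 0, -2, 2), 0.0) lowers to constant_pad_nd with
# padding [0, 0, -2, 2]; with this change such a node stays out of the XNNPACK
# delegate and runs on the portable implementation instead.
assert has_negative_padding([0, 0, -2, 2])
assert not has_negative_padding([1, 1, 2, 2])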