Skip to content

Commit 57f4364

Browse files
committed
Fixed bug with ConstantPad2d accepting negative padding dimensions
1 parent 0305ddc commit 57f4364

File tree

2 files changed

+57
-0
lines changed

2 files changed

+57
-0
lines changed

backends/xnnpack/partition/config/generic_node_configs.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -589,6 +589,21 @@ def supported_precision_types(self) -> List[ConfigPrecisionType]:
589589
class ConstantPadConfig(GenericNodePartitionerConfig):
590590
target_name = "constant_pad_nd.default"
591591

592+
def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
593+
"""
594+
XNNPACK does not support cropping with negative padding sizes.
595+
"""
596+
if not self.check_common_constraints(node, ep):
597+
return False
598+
599+
# Check for negative padding values
600+
padding = cast(List[int], node.args[1])
601+
if any(p < 0 for p in padding):
602+
why(node, reason="XNNPACK does not support negative padding values")
603+
return False
604+
605+
return True
606+
592607
def supported_precision_types(self) -> List[ConfigPrecisionType]:
593608
return [ConfigPrecisionType.FP32]
594609

backends/xnnpack/test/ops/test_static_constant_pad.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@
77
import unittest
88

99
import torch
10+
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
1011
from executorch.backends.xnnpack.test.tester import Tester
12+
from executorch.exir import to_edge_transform_and_lower
13+
from torch.export import export
1114

1215

1316
class TestStaticConstantPad(unittest.TestCase):
@@ -125,6 +128,45 @@ def _test_static_constant_pad_functional(self, inputs):
125128
.run_method_and_compare_outputs()
126129
)
127130

131+
class NegativePadModel(torch.nn.Module):
132+
def __init__(self):
133+
super().__init__()
134+
self.pad = torch.nn.ConstantPad2d((0, 0, -2, 2), 0.0)
135+
136+
def forward(self, input):
137+
input = self.pad(input)
138+
return input
139+
140+
def test_negative_pad_model_with_ints(self):
141+
"""Test that negative padding with integer inputs falls back to PyTorch implementation as XNNPACK does not support negative padding dimensions"""
142+
input_tensor = torch.tensor([[4], [5], [6]])
143+
model = self.NegativePadModel()
144+
model.eval()
145+
model.to("cpu")
146+
147+
exported_model = export(model, (input_tensor,))
148+
149+
executorch_program = to_edge_transform_and_lower(
150+
exported_model, partitioner=[XnnpackPartitioner()]
151+
).to_executorch()
152+
153+
self.assertIsNotNone(executorch_program)
154+
155+
def test_negative_pad_model_with_floats(self):
156+
"""Test that negative padding with float inputs is now rejected by XNNPACK partitioner as XNNPACK does not support negative padding dimensions"""
157+
input_tensor = torch.tensor([[4.0], [5.0], [6.0]])
158+
model = self.NegativePadModel()
159+
model.eval()
160+
model.to("cpu")
161+
162+
exported_model = export(model, (input_tensor,))
163+
164+
executorch_program = to_edge_transform_and_lower(
165+
exported_model, partitioner=[XnnpackPartitioner()]
166+
).to_executorch()
167+
168+
self.assertIsNotNone(executorch_program)
169+
128170
def test_fp16_static_constant_pad_functional(self):
129171
inputs = (
130172
torch.randn(size=(5, 4, 3, 2)).to(torch.float16),

0 commit comments

Comments
 (0)