@@ -14,6 +14,30 @@ class TestStaticConstantPad(unittest.TestCase):
1414 def setUp (self ):
1515 torch ._dynamo .reset ()
1616
17+ class NHWCStaticConstantPad (torch .nn .Module ):
18+ def __init__ (self ):
19+ super ().__init__ ()
20+ self .conv1 = torch .nn .Conv2d (in_channels = 2 , out_channels = 2 , kernel_size = 1 )
21+ self .conv2 = torch .nn .Conv2d (in_channels = 13 , out_channels = 13 , kernel_size = 1 )
22+
23+ def forward (self , x ):
24+ a = self .conv1 (x )
25+ pad_6 = (1 , 2 , 3 , 4 , 5 , 6 )
26+ a = torch .nn .functional .pad (
27+ input = a ,
28+ pad = pad_6 ,
29+ mode = "constant" ,
30+ value = 3.1 ,
31+ )
32+ # tensorshape = [1, 13, 10, 7]
33+ a = self .conv2 (a )
34+
35+ return a
36+
37+ def sample_inputs (self ):
38+ # NCHW
39+ return (torch .randn (1 , 2 , 3 , 4 ),)
40+
1741 class StaticConstantPadFunctional (torch .nn .Module ):
1842 def __init__ (self ):
1943 super ().__init__ ()
@@ -205,3 +229,24 @@ def test_qs8_static_constant_pad_2d(self):
205229 .serialize ()
206230 .run_method_and_compare_outputs ()
207231 )
232+
233+ def test_fp32_static_constant_pad_nhwc (self ):
234+ conv = self .NHWCStaticConstantPad ()
235+ inputs = conv .sample_inputs ()
236+ (
237+ Tester (conv , inputs )
238+ .export ()
239+ .check_count ({"torch.ops.aten.pad.default" : 1 })
240+ .dump_artifact ()
241+ .to_edge_transform_and_lower ()
242+ .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
243+ .check_not (
244+ [
245+ "executorch_exir_dialects_edge__ops_aten_constant_pad_nd_default" ,
246+ "executorch_exir_dialects_edge__ops_aten_convolution_default" ,
247+ ]
248+ )
249+ .to_executorch ()
250+ .serialize ()
251+ .run_method_and_compare_outputs ()
252+ )
0 commit comments