1111
1212# ==============================================================================
1313
14-
1514class MmModule (torch .nn .Module ):
1615 def __init__ (self ):
1716 super ().__init__ ()
@@ -38,7 +37,6 @@ def MmModule_chained(module, tu: TestUtils):
3837
3938# ==============================================================================
4039
41-
4240class BmmModule (torch .nn .Module ):
4341 def __init__ (self ):
4442 super ().__init__ ()
@@ -57,10 +55,8 @@ def forward(self, lhs, rhs):
5755def BmmModule_basic (module , tu : TestUtils ):
5856 module .forward (tu .rand (3 , 4 , 5 ), tu .rand (3 , 5 , 4 ))
5957
60-
6158# ==============================================================================
6259
63-
6460# A subgraph with multiple mm ops.
6561class MmDagModule (torch .nn .Module ):
6662 def __init__ (self ):
@@ -80,10 +76,8 @@ def forward(self, lhs, rhs):
8076def MmDagModule_basic (module , tu : TestUtils ):
8177 module .forward (tu .rand (4 , 4 ), tu .rand (4 , 4 ))
8278
83-
8479# ==============================================================================
8580
86-
8781class MmTanhModule (torch .nn .Module ):
8882 def __init__ (self ):
8983 super ().__init__ ()
@@ -100,16 +94,13 @@ def forward(self, lhs, rhs):
10094 def matmul (self , lhs , rhs ):
10195 return torch .mm (lhs , rhs )
10296
103- # ==============================================================================
104-
10597
10698@register_test_case (module_factory = lambda : MmTanhModule ())
10799def MmTanhModule_basic (module , tu : TestUtils ):
108100 module .forward (tu .rand (4 , 2 ), tu .rand (2 , 4 ))
109101
110102# ==============================================================================
111103
112-
113104class AddmmModuleFloat (torch .nn .Module ):
114105 def __init__ (self ):
115106 super ().__init__ ()
@@ -196,7 +187,6 @@ def AdaptiveAvgPool2dModule_basic(module, tu: TestUtils):
196187
197188# ==============================================================================
198189
199-
200190class FlattenStaticModule (torch .nn .Module ):
201191 def __init__ (self ):
202192 super ().__init__ ()
@@ -217,7 +207,6 @@ def FlattenStaticModule_basic(module, tu: TestUtils):
217207
218208# ==============================================================================
219209
220-
221210class FlattenRank0Module (torch .nn .Module ):
222211 def __init__ (self ):
223212 super ().__init__ ()
@@ -238,7 +227,6 @@ def FlattenRank0Module_basic(module, tu: TestUtils):
238227
239228# ==============================================================================
240229
241-
242230class FlattenDynamicModule (torch .nn .Module ):
243231 def __init__ (self ):
244232 super ().__init__ ()
@@ -259,7 +247,6 @@ def FlattenDynamicModule_basic(module, tu: TestUtils):
259247
260248# ==============================================================================
261249
262-
263250class MaxPool2dModule (torch .nn .Module ):
264251 def __init__ (self ):
265252 super ().__init__ ()
@@ -276,14 +263,86 @@ def __init__(self):
276263 def forward (self , x ):
277264 return self .mp2d (x )
278265
279- # ==============================================================================
280-
281266
282267@register_test_case (module_factory = lambda : MaxPool2dModule ())
283268def MaxPool2dModule_basic (module , tu : TestUtils ):
284269 module .forward (tu .rand (1 , 1 , 20 , 20 ) - 0.5 )
285270
286271
272+ class ConstantPad2dStaticModule (torch .nn .Module ):
273+ def __init__ (self ):
274+ super ().__init__ ()
275+ self .pad2d = torch .nn .ConstantPad2d ((0 , 1 , 2 , 3 ), - float ('inf' ))
276+
277+ @export
278+ @annotate_args ([
279+ None ,
280+ ([1 , 1 , 20 , 20 ], torch .float32 , True ),
281+ ])
282+ def forward (self , x ):
283+ return self .pad2d (x )
284+
285+
286+ @register_test_case (module_factory = lambda : ConstantPad2dStaticModule ())
287+ def ConstantPad2dStaticModule_basic (module , tu : TestUtils ):
288+ module .forward (tu .rand (1 , 1 , 20 , 20 ) - 0.5 )
289+
290+ # ==============================================================================
291+
292+ class ConstantPadNdModule (torch .nn .Module ):
293+ def __init__ (self ):
294+ super ().__init__ ()
295+
296+ @export
297+ @annotate_args ([
298+ None ,
299+ ([- 1 , - 1 , - 1 , - 1 , - 1 , - 1 ], torch .float32 , True ),
300+ ])
301+ def forward (self , x ):
302+ return torch .ops .aten .constant_pad_nd (x , (0 , 1 ), - float ('inf' ))
303+
304+
305+ @register_test_case (module_factory = lambda : ConstantPadNdModule ())
306+ def ConstantPadNdModule_basic (module , tu : TestUtils ):
307+ module .forward (tu .rand (1 , 1 , 20 , 20 , 4 , 4 ) - 0.5 )
308+
309+
310+ class ConstantPadNdStaticModule (torch .nn .Module ):
311+ def __init__ (self ):
312+ super ().__init__ ()
313+
314+ @export
315+ @annotate_args ([
316+ None ,
317+ ([1 , 1 , 20 , 20 , 4 , 4 ], torch .float32 , True ),
318+ ])
319+ def forward (self , x ):
320+ return torch .ops .aten .constant_pad_nd (x , (0 , 1 ), - float ('inf' ))
321+
322+
323+ @register_test_case (module_factory = lambda : ConstantPadNdStaticModule ())
324+ def ConstantPadNdStaticModule_basic (module , tu : TestUtils ):
325+ module .forward (tu .rand (1 , 1 , 20 , 20 , 4 , 4 ) - 0.5 )
326+
327+ class ConstantPadNdPartialStaticModule (torch .nn .Module ):
328+ def __init__ (self ):
329+ super ().__init__ ()
330+
331+ @export
332+ @annotate_args ([
333+ None ,
334+ ([1 , 1 , 20 , 20 , - 1 , - 1 ], torch .float32 , True ),
335+ ])
336+ def forward (self , x ):
337+ return torch .ops .aten .constant_pad_nd (x , (0 , 1 , 2 , 3 ), - float ('inf' ))
338+
339+
340+ @register_test_case (module_factory = lambda : ConstantPadNdPartialStaticModule ())
341+ def ConstantPadNdPartialStaticModule_basic (module , tu : TestUtils ):
342+ module .forward (tu .rand (1 , 1 , 20 , 20 , 4 , 4 ) - 0.5 )
343+
344+ # ==============================================================================
345+
287346class TransposeIntModule (torch .nn .Module ):
288347 def __init__ (self ):
289348 super ().__init__ ()
@@ -296,13 +355,13 @@ def __init__(self):
296355 def forward (self , x ):
297356 return torch .transpose (x , 0 , 1 )
298357
299- # ==============================================================================
300-
301358
302359@register_test_case (module_factory = lambda : TransposeIntModule ())
303360def TransposeIntModule_basic (module , tu : TestUtils ):
304361 module .forward (tu .rand (3 , 4 , 2 ))
305362
363+ # ==============================================================================
364+
306365class PermuteModule (torch .nn .Module ):
307366 def __init__ (self ):
308367 super ().__init__ ()
@@ -333,13 +392,12 @@ def __init__(self):
333392 def forward (self , x ):
334393 return torch .transpose (x , - 1 , - 2 )
335394
336- # ==============================================================================
337-
338395
339396@register_test_case (module_factory = lambda : TransposeIntNegDimsModule ())
340397def TransposeIntNegDimsModule_basic (module , tu : TestUtils ):
341398 module .forward (tu .rand (3 , 4 , 2 ))
342399
400+ # ==============================================================================
343401
344402class PermuteNegativeIndexModule (torch .nn .Module ):
345403 def __init__ (self ):
@@ -353,11 +411,12 @@ def __init__(self):
353411 def forward (self , x ):
354412 return x .permute (0 , - 1 , 1 )
355413
356- # ==============================================================================
357-
358414@register_test_case (module_factory = lambda : PermuteNegativeIndexModule ())
359415def PermuteNegativeIndexModule_basic (module , tu : TestUtils ):
360416 module .forward (tu .rand (3 , 4 , 2 ))
417+
418+ # ==============================================================================
419+
361420class TensorsConcatModule (torch .nn .Module ):
362421 def __init__ (self ):
363422 super ().__init__ ()
@@ -379,7 +438,6 @@ def TensorsConcatModule_basic(module, tu: TestUtils):
379438
380439# ==============================================================================
381440
382-
383441class GatherModule (torch .nn .Module ):
384442 def __init__ (self ):
385443 super ().__init__ ()
@@ -422,7 +480,6 @@ def AddSizeIntModule_basic(module, tu: TestUtils):
422480
423481# ==============================================================================
424482
425-
426483class AddSizeIntNegDimModule (torch .nn .Module ):
427484 def __init__ (self ):
428485 super ().__init__ ()
@@ -505,7 +562,6 @@ def _SoftmaxModule_basic(module, tu: TestUtils):
505562
506563# ==============================================================================
507564
508-
509565class SoftmaxIntNegDimModule (torch .nn .Module ):
510566 def __init__ (self ):
511567 super ().__init__ ()
@@ -527,7 +583,6 @@ def SoftmaxIntNegDimModule_basic(module, tu: TestUtils):
527583
528584# ==============================================================================
529585
530-
531586class SoftmaxIntArgTypeF64Module (torch .nn .Module ):
532587 def __init__ (self ):
533588 super ().__init__ ()
0 commit comments