@@ -100,6 +100,8 @@ def forward(self, lhs, rhs):
100100 def matmul (self , lhs , rhs ):
101101 return torch .mm (lhs , rhs )
102102
103+ # ==============================================================================
104+
103105
104106@register_test_case (module_factory = lambda : MmTanhModule ())
105107def MmTanhModule_basic (module , tu : TestUtils ):
@@ -192,6 +194,8 @@ def forward(self, x):
192194def AdaptiveAvgPool2dModule_basic (module , tu : TestUtils ):
193195 module .forward (tu .rand (10 , 3 , 8 , 9 ))
194196
197+ # ==============================================================================
198+
195199
196200class FlattenStaticModule (torch .nn .Module ):
197201 def __init__ (self ):
@@ -211,6 +215,8 @@ def forward(self, x):
211215def FlattenStaticModule_basic (module , tu : TestUtils ):
212216 module .forward (tu .rand (10 , 3 , 8 , 9 , 3 , 4 ))
213217
218+ # ==============================================================================
219+
214220
215221class FlattenRank0Module (torch .nn .Module ):
216222 def __init__ (self ):
@@ -230,6 +236,8 @@ def forward(self, x):
230236def FlattenRank0Module_basic (module , tu : TestUtils ):
231237 module .forward (torch .tensor (4.0 ))
232238
239+ # ==============================================================================
240+
233241
234242class FlattenDynamicModule (torch .nn .Module ):
235243 def __init__ (self ):
@@ -249,6 +257,8 @@ def forward(self, x):
249257def FlattenDynamicModule_basic (module , tu : TestUtils ):
250258 module .forward (tu .rand (10 , 3 , 8 , 9 , 3 , 4 ))
251259
260+ # ==============================================================================
261+
252262
253263class MaxPool2dModule (torch .nn .Module ):
254264 def __init__ (self ):
@@ -266,6 +276,8 @@ def __init__(self):
266276 def forward (self , x ):
267277 return self .mp2d (x )
268278
279+ # ==============================================================================
280+
269281
270282@register_test_case (module_factory = lambda : MaxPool2dModule ())
271283def MaxPool2dModule_basic (module , tu : TestUtils ):
@@ -284,6 +296,8 @@ def __init__(self):
284296 def forward (self , x ):
285297 return torch .transpose (x , 0 , 1 )
286298
299+ # ==============================================================================
300+
287301
288302@register_test_case (module_factory = lambda : TransposeIntModule ())
289303def TransposeIntModule_basic (module , tu : TestUtils ):
@@ -305,6 +319,8 @@ def forward(self, x):
305319def PermuteModule_basic (module , tu : TestUtils ):
306320 module .forward (tu .rand (3 , 4 , 2 ))
307321
322+ # ==============================================================================
323+
308324class TransposeIntNegDimsModule (torch .nn .Module ):
309325 def __init__ (self ):
310326 super ().__init__ ()
@@ -317,6 +333,8 @@ def __init__(self):
317333 def forward (self , x ):
318334 return torch .transpose (x , - 1 , - 2 )
319335
336+ # ==============================================================================
337+
320338
321339@register_test_case (module_factory = lambda : TransposeIntNegDimsModule ())
322340def TransposeIntNegDimsModule_basic (module , tu : TestUtils ):
@@ -335,6 +353,8 @@ def __init__(self):
335353 def forward (self , x ):
336354 return x .permute (0 , - 1 , 1 )
337355
356+ # ==============================================================================
357+
338358@register_test_case (module_factory = lambda : PermuteNegativeIndexModule ())
339359def PermuteNegativeIndexModule_basic (module , tu : TestUtils ):
340360 module .forward (tu .rand (3 , 4 , 2 ))
@@ -357,6 +377,8 @@ def forward(self, x, y, z):
357377def TensorsConcatModule_basic (module , tu : TestUtils ):
358378 module .forward (tu .rand (2 , 2 , 4 ), tu .rand (2 , 1 , 4 ), tu .rand (2 , 3 , 4 ))
359379
380+ # ==============================================================================
381+
360382
361383class GatherModule (torch .nn .Module ):
362384 def __init__ (self ):
@@ -376,6 +398,8 @@ def forward(self, tensor, indices):
376398def GatherModule_basic (module , tu : TestUtils ):
377399 module .forward (tu .rand (2 , 3 , 4 ), torch .tensor ([[[1 , 2 , 3 ], [1 , 2 , 3 ]]]))
378400
401+ # ==============================================================================
402+
379403class AddSizeIntModule (torch .nn .Module ):
380404 def __init__ (self ):
381405 super ().__init__ ()
@@ -396,6 +420,8 @@ def forward(self, tensor):
396420def AddSizeIntModule_basic (module , tu : TestUtils ):
397421 module .forward (torch .randn (3 , 3 ))
398422
423+ # ==============================================================================
424+
399425
400426class AddSizeIntNegDimModule (torch .nn .Module ):
401427 def __init__ (self ):
@@ -417,6 +443,8 @@ def forward(self, tensor):
417443def AddSizeIntNegDimModule_basic (module , tu : TestUtils ):
418444 module .forward (torch .randn (3 , 3 ))
419445
446+ # ==============================================================================
447+
420448class EmbeddingModule (torch .nn .Module ):
421449 def __init__ (self ):
422450 super ().__init__ ()
@@ -438,6 +466,7 @@ def forward(self, indices):
438466def EmbeddingModule_basic (module , tu : TestUtils ):
439467 module .forward (torch .randint (100 , (3 , 3 )))
440468
469+ # ==============================================================================
441470
442471class SoftmaxIntModule (torch .nn .Module ):
443472 def __init__ (self ):
@@ -474,6 +503,8 @@ def forward(self, tensor):
474503def _SoftmaxModule_basic (module , tu : TestUtils ):
475504 module .forward (torch .randn (3 , 2 , 4 ))
476505
506+ # ==============================================================================
507+
477508
478509class SoftmaxIntNegDimModule (torch .nn .Module ):
479510 def __init__ (self ):
@@ -494,6 +525,8 @@ def forward(self, tensor):
494525def SoftmaxIntNegDimModule_basic (module , tu : TestUtils ):
495526 module .forward (torch .randn (3 , 2 , 4 ))
496527
528+ # ==============================================================================
529+
497530
498531class SoftmaxIntArgTypeF64Module (torch .nn .Module ):
499532 def __init__ (self ):
@@ -513,6 +546,7 @@ def forward(self, tensor):
513546def SoftmaxIntArgTypeF64Module_basic (module , tu : TestUtils ):
514547 module .forward (torch .randn (3 , 2 , 4 ).double ())
515548
549+ # ==============================================================================
516550
517551class BroadcastToModule (torch .nn .Module ):
518552 def __init__ (self ):
@@ -531,6 +565,8 @@ def forward(self, x):
531565def BroadcastToModule_basic (module , tu : TestUtils ):
532566 module .forward (tu .rand (3 , 1 , 1 ))
533567
568+ # ==============================================================================
569+
534570class ExpandModule (torch .nn .Module ):
535571 def __init__ (self ):
536572 super ().__init__ ()
@@ -548,6 +584,9 @@ def forward(self, x):
548584def ExpandModule_basic (module , tu : TestUtils ):
549585 module .forward (tu .rand (3 , 1 , 1 ))
550586
587+ # ==============================================================================
588+
589+
551590class OnesModuleInt (torch .nn .Module ):
552591 def __init__ (self ):
553592 super ().__init__ ()
@@ -563,6 +602,8 @@ def forward(self):
563602def OnesModuleInt_basic (module , tu : TestUtils ):
564603 module .forward ()
565604
605+ # ==============================================================================
606+
566607class OnesModuleFloat (torch .nn .Module ):
567608 def __init__ (self ):
568609 super ().__init__ ()
@@ -594,6 +635,7 @@ def forward(self):
594635def OnesModuleFalsePinMemory_basic (module , tu : TestUtils ):
595636 module .forward ()
596637
638+ # ==============================================================================
597639
598640class ContiguousModule (torch .nn .Module ):
599641 def __init__ (self ):
@@ -611,7 +653,7 @@ def forward(self, x):
611653@register_test_case (module_factory = lambda : ContiguousModule ())
612654def ContiguousModule_basic (module , tu : TestUtils ):
613655 module .forward (tu .rand (3 , 1 ))
614-
656+
615657class TensorToInt (torch .nn .Module ):
616658 def __init__ (self ):
617659 super ().__init__ ()
@@ -681,6 +723,7 @@ def forward(self):
681723def NumToTensorFloatModule_basic (module , tu : TestUtils ):
682724 module .forward ()
683725
726+ # ==============================================================================
684727
685728# This test can be removed once we have one real op returning 3 float32 tensors
686729class ReturnThreeTensorFloat32 (torch .nn .Module ):
0 commit comments