@@ -586,144 +586,6 @@ def ExpandModule_basic(module, tu: TestUtils):
586586
587587# ==============================================================================
588588
589-
590- class OnesModuleInt (torch .nn .Module ):
591- def __init__ (self ):
592- super ().__init__ ()
593-
594- @export
595- @annotate_args ([
596- None ,
597- ])
598- def forward (self ):
599- return torch .ones (3 , 4 , dtype = torch .int64 )
600-
601- @register_test_case (module_factory = lambda : OnesModuleInt ())
602- def OnesModuleInt_basic (module , tu : TestUtils ):
603- module .forward ()
604-
605- # ==============================================================================
606-
607- class OnesModuleFloat (torch .nn .Module ):
608- def __init__ (self ):
609- super ().__init__ ()
610-
611- @export
612- @annotate_args ([
613- None ,
614- ])
615- def forward (self ):
616- return torch .ones (3 , 4 , dtype = torch .float32 )
617-
618- @register_test_case (module_factory = lambda : OnesModuleFloat ())
619- def OnesModuleFloat_basic (module , tu : TestUtils ):
620- module .forward ()
621-
622-
623- class OnesModuleFalsePinMemory (torch .nn .Module ):
624- def __init__ (self ):
625- super ().__init__ ()
626-
627- @export
628- @annotate_args ([
629- None ,
630- ])
631- def forward (self ):
632- return torch .ones (3 , 4 , dtype = torch .float32 , pin_memory = False )
633-
634- @register_test_case (module_factory = lambda : OnesModuleFalsePinMemory ())
635- def OnesModuleFalsePinMemory_basic (module , tu : TestUtils ):
636- module .forward ()
637-
638- # ==============================================================================
639-
640- class EmptyIntModule (torch .nn .Module ):
641- def __init__ (self ):
642- super ().__init__ ()
643-
644- @export
645- @annotate_args ([
646- None ,
647- ])
648- def forward (self ):
649- return 0 * torch .empty ((3 , 4 ), dtype = torch .int64 )
650-
651- @register_test_case (module_factory = lambda : EmptyIntModule ())
652- def EmptyModule_int (module , tu : TestUtils ):
653- module .forward ()
654-
655- # ==============================================================================
656-
657- class EmptyFloatModule (torch .nn .Module ):
658- def __init__ (self ):
659- super ().__init__ ()
660-
661- @export
662- @annotate_args ([
663- None ,
664- ])
665- def forward (self ):
666- return torch .pow (torch .empty ((3 , 4 ), dtype = torch .float32 ), 0 )
667-
668- @register_test_case (module_factory = lambda : EmptyFloatModule ())
669- def EmptyModule_float (module , tu : TestUtils ):
670- module .forward ()
671-
672-
673- class EmptyFalsePinMemoryModule (torch .nn .Module ):
674- def __init__ (self ):
675- super ().__init__ ()
676-
677- @export
678- @annotate_args ([
679- None ,
680- ])
681- def forward (self ):
682- return torch .pow (torch .empty ((3 , 4 ), dtype = torch .float32 ,
683- pin_memory = False ), 0 )
684-
685- @register_test_case (module_factory = lambda : EmptyFalsePinMemoryModule ())
686- def EmptyModule_falsePinMemory (module , tu : TestUtils ):
687- module .forward ()
688-
689- # ==============================================================================
690-
691- class EmptyLikeIntModule (torch .nn .Module ):
692- def __init__ (self ):
693- super ().__init__ ()
694-
695- @export
696- @annotate_args ([
697- None ,
698- ([- 1 , - 1 ], torch .int64 , True ),
699- ])
700- def forward (self , a ):
701- return 0 * torch .empty_like (a , dtype = torch .int64 )
702-
703- @register_test_case (module_factory = lambda : EmptyLikeIntModule ())
704- def EmptyLikeModule_int (module , tu : TestUtils ):
705- module .forward (torch .randint (10 , (3 , 5 )))
706-
707- # ==============================================================================
708-
709- class EmptyLikeFloatModule (torch .nn .Module ):
710- def __init__ (self ):
711- super ().__init__ ()
712-
713- @export
714- @annotate_args ([
715- None ,
716- ([- 1 , - 1 ], torch .float32 , True ),
717- ])
718- def forward (self , a ):
719- return torch .pow (torch .empty_like (a , dtype = torch .float32 ), 0 )
720-
721- @register_test_case (module_factory = lambda : EmptyLikeFloatModule ())
722- def EmptyLikeModule_float (module , tu : TestUtils ):
723- module .forward (tu .rand (4 , 5 ))
724-
725- # ==============================================================================
726-
727589class ContiguousModule (torch .nn .Module ):
728590 def __init__ (self ):
729591 super ().__init__ ()
@@ -926,57 +788,6 @@ def DropoutModule_basic(module, tu: TestUtils):
926788 module .forward (tu .rand (3 , 4 ))
927789
928790
929- class Fill_TensorFloat64WithFloat32 (torch .nn .Module ):
930- def __init__ (self ):
931- super ().__init__ ()
932-
933- @export
934- @annotate_args ([
935- None ,
936- ([- 1 , - 1 , - 1 ], torch .float32 , True ),
937- ])
938- def forward (self , tensor ):
939- return torch .ops .aten .fill_ (tensor , 3.0 )
940-
941- @register_test_case (module_factory = lambda : Fill_TensorFloat64WithFloat32 ())
942- def Fill_TensorFloat64WithFloat32_basic (module , tu : TestUtils ):
943- module .forward (torch .randn (3 , 2 , 4 ))
944-
945-
946- class Fill_TensorFloat64WithFloat64 (torch .nn .Module ):
947- def __init__ (self ):
948- super ().__init__ ()
949-
950- @export
951- @annotate_args ([
952- None ,
953- ([- 1 , - 1 , - 1 ], torch .float64 , True ),
954- ])
955- def forward (self , tensor ):
956- return torch .ops .aten .fill_ (tensor , 3.0 )
957-
958- @register_test_case (module_factory = lambda : Fill_TensorFloat64WithFloat64 ())
959- def Fill_TensorFloat64WithFloat64_basic (module , tu : TestUtils ):
960- module .forward (torch .randn (3 , 2 , 4 ).to (torch .float64 ))
961-
962-
963- class Fill_TensorFloat64WithInt64 (torch .nn .Module ):
964- def __init__ (self ):
965- super ().__init__ ()
966-
967- @export
968- @annotate_args ([
969- None ,
970- ([- 1 , - 1 , - 1 ], torch .float64 , True ),
971- ])
972- def forward (self , tensor ):
973- return torch .ops .aten .fill_ (tensor , 3 )
974-
975- @register_test_case (module_factory = lambda : Fill_TensorFloat64WithInt64 ())
976- def Fill_TensorFloat64WithInt64_basic (module , tu : TestUtils ):
977- module .forward (torch .randn (3 , 2 , 4 ).to (torch .float64 ))
978-
979-
980791class MeanModule (torch .nn .Module ):
981792 def __init__ (self ):
982793 super ().__init__ ()
@@ -1047,86 +858,6 @@ def NumelZeroRankModule_basic(module, tu: TestUtils):
1047858 module .forward (torch .randint (10 ,[]))
1048859
1049860
1050- class ZerosModuleInt2D (torch .nn .Module ):
1051- def __init__ (self ):
1052- super ().__init__ ()
1053-
1054- @export
1055- @annotate_args ([
1056- None ,
1057- ])
1058- def forward (self ):
1059- return torch .zeros (3 , 4 , dtype = torch .int64 )
1060-
1061- @register_test_case (module_factory = lambda : ZerosModuleInt2D ())
1062- def ZerosModuleInt2D_basic (module , tu : TestUtils ):
1063- module .forward ()
1064-
1065-
1066- class ZerosModuleInt3D (torch .nn .Module ):
1067- def __init__ (self ):
1068- super ().__init__ ()
1069-
1070- @export
1071- @annotate_args ([
1072- None ,
1073- ])
1074- def forward (self ):
1075- return torch .zeros (3 , 4 , 5 , dtype = torch .int64 )
1076-
1077- @register_test_case (module_factory = lambda : ZerosModuleInt3D ())
1078- def ZerosModuleInt3D_basic (module , tu : TestUtils ):
1079- module .forward ()
1080-
1081-
1082- class ZerosModuleFloat2D (torch .nn .Module ):
1083- def __init__ (self ):
1084- super ().__init__ ()
1085-
1086- @export
1087- @annotate_args ([
1088- None ,
1089- ])
1090- def forward (self ):
1091- return torch .zeros (3 , 4 , dtype = torch .float32 )
1092-
1093- @register_test_case (module_factory = lambda : ZerosModuleFloat2D ())
1094- def ZerosModuleFloat2D_basic (module , tu : TestUtils ):
1095- module .forward ()
1096-
1097-
1098- class ZerosModuleFloat3D (torch .nn .Module ):
1099- def __init__ (self ):
1100- super ().__init__ ()
1101-
1102- @export
1103- @annotate_args ([
1104- None ,
1105- ])
1106- def forward (self ):
1107- return torch .zeros (3 , 4 , 5 , dtype = torch .float32 )
1108-
1109- @register_test_case (module_factory = lambda : ZerosModuleFloat3D ())
1110- def ZerosModuleFloat3D_basic (module , tu : TestUtils ):
1111- module .forward ()
1112-
1113-
1114- class ZerosModuleFalsePinMemory (torch .nn .Module ):
1115- def __init__ (self ):
1116- super ().__init__ ()
1117-
1118- @export
1119- @annotate_args ([
1120- None ,
1121- ])
1122- def forward (self ):
1123- return torch .zeros (3 , 4 , dtype = torch .float32 , pin_memory = False )
1124-
1125- @register_test_case (module_factory = lambda : ZerosModuleFalsePinMemory ())
1126- def ZerosModuleFalsePinMemory_basic (module , tu : TestUtils ):
1127- module .forward ()
1128-
1129-
1130861class BoolTensorReturnFalseModule (torch .nn .Module ):
1131862 def __init__ (self ):
1132863 super ().__init__ ()
@@ -1181,6 +912,7 @@ def BoolTensorReturnMixedModule_basic(module, tu: TestUtils):
1181912 module .forward (torch .tensor ([[1 , 0 ], [0 ,1 ]], dtype = torch .bool ))
1182913
1183914# ==============================================================================
915+
1184916class TModuleRank2 (torch .nn .Module ):
1185917 def __init__ (self ):
1186918 super ().__init__ ()
@@ -1193,11 +925,11 @@ def __init__(self):
1193925 def forward (self , lhs ):
1194926 return torch .t (lhs )
1195927
1196-
1197928@register_test_case (module_factory = lambda : TModuleRank2 ())
1198929def TModuleRank2_basic (module , tu : TestUtils ):
1199930 module .forward (tu .rand (3 , 4 ))
1200931
932+
1201933class TModuleRank1 (torch .nn .Module ):
1202934 def __init__ (self ):
1203935 super ().__init__ ()
@@ -1210,11 +942,11 @@ def __init__(self):
1210942 def forward (self , lhs ):
1211943 return torch .t (lhs )
1212944
1213-
1214945@register_test_case (module_factory = lambda : TModuleRank1 ())
1215946def TModuleRank1_basic (module , tu : TestUtils ):
1216947 module .forward (tu .rand (3 ))
1217948
949+
1218950class TModuleRank0 (torch .nn .Module ):
1219951 def __init__ (self ):
1220952 super ().__init__ ()
@@ -1227,7 +959,6 @@ def __init__(self):
1227959 def forward (self , lhs ):
1228960 return torch .t (lhs )
1229961
1230-
1231962@register_test_case (module_factory = lambda : TModuleRank0 ())
1232963def TModuleRank0_basic (module , tu : TestUtils ):
1233964 module .forward (torch .tensor (7 , dtype = torch .float32 ))
0 commit comments