@@ -62,6 +62,28 @@ def ElementwiseBinaryModule_basic(module, tu: TestUtils):
6262# ==============================================================================
6363
6464
65+ class ElementwiseBinaryStaticShapeModule (torch .nn .Module ):
66+ def __init__ (self ):
67+ super ().__init__ ()
68+
69+ @export
70+ @annotate_args ([
71+ None ,
72+ ([5 , 4 , 3 , 3 , 1 ], torch .float32 , True ),
73+ ([4 , 3 , 1 , 2 ], torch .float32 , True ),
74+ ])
75+ def forward (self , a , b ):
76+ return a * b
77+
78+ @register_test_case (
79+ module_factory = lambda : ElementwiseBinaryStaticShapeModule ())
80+ def ElementwiseBinaryStaticShapeModule_basic (module , tu : TestUtils ):
81+ module .forward (tu .rand (5 , 4 , 3 , 3 , 1 ), tu .rand (4 , 3 , 1 , 2 ))
82+
83+
84+ # ==============================================================================
85+
86+
6587class ElementwiseTernaryModule (torch .nn .Module ):
6688 def __init__ (self ):
6789 super ().__init__ ()
@@ -171,8 +193,7 @@ def forward(self, a):
171193 return torch .unsqueeze (a , - 3 )
172194
173195
174- @register_test_case (
175- module_factory = lambda : ElementwiseUnsqueezeNegDimsModule ())
196+ @register_test_case (module_factory = lambda : ElementwiseUnsqueezeNegDimsModule ())
176197def ElementwiseUnsqueezeNegDimsModule_basic (module , tu : TestUtils ):
177198 module .forward (tu .rand (4 , 3 ))
178199
@@ -255,7 +276,7 @@ def forward(self, x):
255276
256277@register_test_case (module_factory = lambda : ElementwiseGeluModule ())
257278def ElementwiseGeluModule_basic (module , tu : TestUtils ):
258- module .forward (2 * tu .rand (5 , 3 ) - 0.5 )
279+ module .forward (2 * tu .rand (5 , 3 ) - 0.5 )
259280
260281
261282# ==============================================================================
@@ -359,7 +380,7 @@ def forward(self, x):
359380
360381@register_test_case (module_factory = lambda : ElementwiseGtIntScalarModule ())
361382def ElementwiseGtIntScalarModule_basic (module , tu : TestUtils ):
362- module .forward (torch .randint (- 10 , 15 , (3 ,4 )))
383+ module .forward (torch .randint (- 10 , 15 , (3 , 4 )))
363384
364385
365386class ElementwiseGtMixed2ScalarModule (torch .nn .Module ):
@@ -377,7 +398,7 @@ def forward(self, x):
377398
378399@register_test_case (module_factory = lambda : ElementwiseGtMixed2ScalarModule ())
379400def ElementwiseGtMixed2ScalarModule_basic (module , tu : TestUtils ):
380- module .forward (torch .randint (- 10 , 15 , (3 ,4 )).to (torch .int32 ))
401+ module .forward (torch .randint (- 10 , 15 , (3 , 4 )).to (torch .int32 ))
381402
382403
383404class ElementwiseGtFloatTensorModule (torch .nn .Module ):
@@ -415,10 +436,12 @@ def forward(self, x, y):
415436
416437@register_test_case (module_factory = lambda : ElementwiseGtIntTensorModule ())
417438def ElementwiseGtIntTensorModule_basic (module , tu : TestUtils ):
418- module .forward (torch .randint (10 , (3 , 5 )), torch .randint (10 , (5 ,)))
439+ module .forward (torch .randint (10 , (3 , 5 )), torch .randint (10 , (5 , )))
440+
419441
420442# ==============================================================================
421443
444+
422445class ElementwiseLtFloatScalarModule (torch .nn .Module ):
423446 def __init__ (self ):
424447 super ().__init__ ()
@@ -452,7 +475,7 @@ def forward(self, x):
452475
453476@register_test_case (module_factory = lambda : ElementwiseLtIntScalarModule ())
454477def ElementwiseLtIntScalarModule_basic (module , tu : TestUtils ):
455- module .forward (torch .randint (- 10 , 15 , (3 ,4 )))
478+ module .forward (torch .randint (- 10 , 15 , (3 , 4 )))
456479
457480
458481class ElementwiseLtDiffWidthScalarModule (torch .nn .Module ):
@@ -468,9 +491,10 @@ def forward(self, x):
468491 return torch .lt (x , 2 )
469492
470493
471- @register_test_case (module_factory = lambda : ElementwiseLtDiffWidthScalarModule ())
494+ @register_test_case (
495+ module_factory = lambda : ElementwiseLtDiffWidthScalarModule ())
472496def ElementwiseLtDiffWidthScalarModule_basic (module , tu : TestUtils ):
473- module .forward (torch .randint (- 10 , 15 , (3 ,4 )).to (torch .int32 ))
497+ module .forward (torch .randint (- 10 , 15 , (3 , 4 )).to (torch .int32 ))
474498
475499
476500class ElementwiseLtFloatTensorModule (torch .nn .Module ):
@@ -508,10 +532,12 @@ def forward(self, x, y):
508532
509533@register_test_case (module_factory = lambda : ElementwiseLtIntTensorModule ())
510534def ElementwiseLtIntTensorModule_basic (module , tu : TestUtils ):
511- module .forward (torch .randint (10 , (3 , 5 )), torch .randint (10 , (5 ,)))
535+ module .forward (torch .randint (10 , (3 , 5 )), torch .randint (10 , (5 , )))
536+
512537
513538# ==============================================================================
514539
540+
515541class ElementwiseEqFloatScalarModule (torch .nn .Module ):
516542 def __init__ (self ):
517543 super ().__init__ ()
@@ -527,8 +553,8 @@ def forward(self, x):
527553
528554@register_test_case (module_factory = lambda : ElementwiseEqFloatScalarModule ())
529555def ElementwiseEqFloatScalarModule_basic (module , tu : TestUtils ):
530- module .forward (torch . tensor ([[ 1.0 , 2.2 , 6.0 ], [ 6.0 , 2.0 , 3.1 ]])
531- .to (torch .float32 ))
556+ module .forward (
557+ torch . tensor ([[ 1.0 , 2.2 , 6.0 ], [ 6.0 , 2.0 , 3.1 ]]) .to (torch .float32 ))
532558
533559
534560class ElementwiseEqIntScalarModule (torch .nn .Module ):
@@ -546,7 +572,7 @@ def forward(self, x):
546572
547573@register_test_case (module_factory = lambda : ElementwiseEqIntScalarModule ())
548574def ElementwiseEqIntScalarModule_basic (module , tu : TestUtils ):
549- module .forward (torch .randint (2 , 4 , (5 ,8 )))
575+ module .forward (torch .randint (2 , 4 , (5 , 8 )))
550576
551577
552578class ElementwiseEqDiffWidthScalarModule (torch .nn .Module ):
@@ -562,9 +588,10 @@ def forward(self, x):
562588 return torch .eq (x , 2 )
563589
564590
565- @register_test_case (module_factory = lambda : ElementwiseEqDiffWidthScalarModule ())
591+ @register_test_case (
592+ module_factory = lambda : ElementwiseEqDiffWidthScalarModule ())
566593def ElementwiseEqDiffWidthScalarModule_basic (module , tu : TestUtils ):
567- module .forward (torch .randint (2 , 4 , (5 ,8 )).to (torch .int32 ))
594+ module .forward (torch .randint (2 , 4 , (5 , 8 )).to (torch .int32 ))
568595
569596
570597class ElementwiseEqFloatTensorModule (torch .nn .Module ):
@@ -583,9 +610,9 @@ def forward(self, x, y):
583610
584611@register_test_case (module_factory = lambda : ElementwiseEqFloatTensorModule ())
585612def ElementwiseEqFloatTensorModule_basic (module , tu : TestUtils ):
586- module .forward (torch . tensor ([[ 1.0 , 2.2 , 6.0 ], [ 6.0 , 2.0 , 3.1 ]])
587- . to (torch .float32 ),
588- torch .tensor ([1.0 , 2.4 , 6.0 ]).to (torch .float32 ))
613+ module .forward (
614+ torch . tensor ([[ 1.0 , 2.2 , 6.0 ], [ 6.0 , 2.0 , 3.1 ]]). to (torch .float32 ),
615+ torch .tensor ([1.0 , 2.4 , 6.0 ]).to (torch .float32 ))
589616
590617
591618class ElementwiseEqIntTensorModule (torch .nn .Module ):
@@ -604,10 +631,12 @@ def forward(self, x, y):
604631
605632@register_test_case (module_factory = lambda : ElementwiseEqIntTensorModule ())
606633def ElementwiseEqIntTensorModule_basic (module , tu : TestUtils ):
607- module .forward (torch .randint (2 , 4 , (8 , 5 )), torch .randint (2 , 4 , (5 ,)))
634+ module .forward (torch .randint (2 , 4 , (8 , 5 )), torch .randint (2 , 4 , (5 , )))
635+
608636
609637# ==============================================================================
610638
639+
611640class ElementwiseClampModule (torch .nn .Module ):
612641 def __init__ (self ):
613642 super ().__init__ ()
@@ -666,7 +695,7 @@ def forward(self, x):
666695@register_test_case (module_factory = lambda : RsubModule_noalpha ())
667696def RsubModule_noalpha_basic (module , tu : TestUtils ):
668697 module .forward (tu .rand (3 , 4 ))
669-
698+
670699# ==============================================================================
671700
672701class ElementwiseMulScalarIntModule (torch .nn .Module ):
@@ -734,12 +763,10 @@ def forward(self, a, b):
734763 return torch .mul (a , b )
735764
736765
737- @register_test_case (
738- module_factory = lambda : ElementwiseMulTensorFloatModule ())
766+ @register_test_case (module_factory = lambda : ElementwiseMulTensorFloatModule ())
739767def ElementwiseMulTensorFloatModule_basic (module , tu : TestUtils ):
740- module .forward (
741- tu .rand (4 ),
742- tu .rand (4 ).type (torch .float64 ))
768+ module .forward (tu .rand (4 ), tu .rand (4 ).type (torch .float64 ))
769+
743770
744771class ElementwiseMulTensorIntModule (torch .nn .Module ):
745772 def __init__ (self ):
@@ -755,12 +782,10 @@ def forward(self, a, b):
755782 return torch .mul (a , b )
756783
757784
758- @register_test_case (
759- module_factory = lambda : ElementwiseMulTensorIntModule ())
785+ @register_test_case (module_factory = lambda : ElementwiseMulTensorIntModule ())
760786def ElementwiseMulTensorIntModule_basic (module , tu : TestUtils ):
761787 module .forward (
762- torch .randint (10 , [4 ]).type (torch .int32 ),
763- torch .randint (10 , [4 ]))
788+ torch .randint (10 , [4 ]).type (torch .int32 ), torch .randint (10 , [4 ]))
764789
765790
766791# ==============================================================================
@@ -783,7 +808,7 @@ def ElementwiseLogModule_basic(module, tu: TestUtils):
783808
784809
785810class ElementwiseSqrtModule (torch .nn .Module ):
786- def __init__ (self ):
811+ def __init__ (self ):
787812 super ().__init__ ()
788813
789814 @export
@@ -898,7 +923,7 @@ def ElementwiseLog2Module_basic(module, tu: TestUtils):
898923 module .forward (tu .rand (3 , 4 ))
899924
900925class ElementwiseRsqrtModule (torch .nn .Module ):
901- def __init__ (self ):
926+ def __init__ (self ):
902927 super ().__init__ ()
903928
904929 @export
@@ -984,12 +1009,9 @@ def forward(self, a, b):
9841009 return torch .div (a , b )
9851010
9861011
987- @register_test_case (
988- module_factory = lambda : ElementwiseDivTensorFloatModule ())
1012+ @register_test_case (module_factory = lambda : ElementwiseDivTensorFloatModule ())
9891013def ElementwiseDivTensorFloatModule_basic (module , tu : TestUtils ):
990- module .forward (
991- tu .rand (4 ),
992- tu .rand (4 ).type (torch .float64 ))
1014+ module .forward (tu .rand (4 ), tu .rand (4 ).type (torch .float64 ))
9931015
9941016
9951017# ==============================================================================
@@ -1005,15 +1027,15 @@ def __init__(self):
10051027 ([- 1 , - 1 ], torch .int32 , True ),
10061028 ([- 1 , - 1 ], torch .int64 , True ),
10071029 ])
1008-
10091030 def forward (self , x , y ):
10101031 return torch .bitwise_and (x , y )
10111032
10121033
10131034@register_test_case (module_factory = lambda : ElementwiseAndIntegerModule ())
10141035def ElementwiseAndIntegerModule_basic (module , tu : TestUtils ):
1015- module .forward (torch .randint (- 10 , 10 , (3 , 4 )).to (torch .int32 ),
1016- torch .randint (- 10 , 10 , (3 , 4 )))
1036+ module .forward (
1037+ torch .randint (- 10 , 10 , (3 , 4 )).to (torch .int32 ),
1038+ torch .randint (- 10 , 10 , (3 , 4 )))
10171039
10181040
10191041class ElementwiseSubScalarIntModule (torch .nn .Module ):
@@ -1026,7 +1048,8 @@ def __init__(self):
10261048 ([- 1 , - 1 ], torch .int64 , True ),
10271049 ])
10281050 def forward (self , x ):
1029- return torch .sub (x , 2.1 , alpha = 2 )
1051+ return torch .sub (x , 2.1 , alpha = 2 )
1052+
10301053
10311054@register_test_case (module_factory = lambda : ElementwiseSubScalarIntModule ())
10321055def ElementwiseSubScalarIntModule_basic (module , tu : TestUtils ):
@@ -1077,7 +1100,8 @@ def __init__(self):
10771100 ([- 1 , - 1 ], torch .float32 , True ),
10781101 ])
10791102 def forward (self , x ):
1080- return torch .add (x , 3.0 , alpha = 2 )
1103+ return torch .add (x , 3.0 , alpha = 2 )
1104+
10811105
10821106@register_test_case (module_factory = lambda : ElementwiseAddScalarFloatModule ())
10831107def ElementwiseAddScalarFloatModule_basic (module , tu : TestUtils ):
0 commit comments