@@ -612,21 +612,90 @@ def test_binary_op_elementwise(op):
     inputs = [torch.randn(1, 3, 3).cuda(), torch.randn(1, 3, 3).cuda()]
     cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
 
-@pytest.mark.parametrize("op", ["min", "max", "mean", "sum"])
-def test_reduce_op(op):
-    if op == "max":
-        fn = lambda x: torch.max(x)
-    elif op == "min":
-        fn = lambda x: torch.min(x)
-    elif op == "mean":
-        fn = lambda x: torch.mean(x)
-    elif op == "sum":
-        fn = lambda x: torch.sum(x)
-
-    module = UnaryModule(fn).cuda().eval()
+unary_1d_randn_ops = {
+    "torch.max": lambda x: torch.max(x),
+    "torch.Tensor.max": lambda x: x.max(),
+    "torch.min": lambda x: torch.min(x),
+    "torch.Tensor.min": lambda x: x.min(),
+    "torch.mean": lambda x: torch.mean(x),
+    "torch.Tensor.mean": lambda x: x.mean(),
+    "torch.sum": lambda x: torch.sum(x),
+    "torch.Tensor.sum": lambda x: x.sum(),
+    "torch.prod": lambda x: torch.prod(x),
+    "torch.Tensor.prod": lambda x: x.prod(),
+    "torch.relu": lambda x: torch.relu(x),
+    "torch.nn.functional.relu": lambda x: torch.nn.functional.relu(x),
+    "torch.Tensor.relu": lambda x: x.relu(),
+    "torch.nn.functional.relu6": lambda x: torch.nn.functional.relu6(x),
+    "torch.sigmoid": lambda x: torch.sigmoid(x),
+    "torch.nn.functional.sigmoid": lambda x: torch.nn.functional.sigmoid(x),
+    "torch.Tensor.sigmoid": lambda x: x.sigmoid(),
+    "torch.nn.functional.silu": lambda x: torch.nn.functional.silu(x),
+    "torch.Tensor.softmax": lambda x: x.softmax(1),
+    "torch.nn.functional.softmax": lambda x: torch.nn.functional.softmax(x, 1),
+    "torch.Tensor.squeeze": lambda x: x.squeeze(),
+    "torch.squeeze": lambda x: torch.squeeze(x),
+    "torch.stack": lambda x: torch.stack([x, x], dim=1),
+    "torch.sub": lambda x: torch.sub(x, x),
+    "torch.Tensor.__sub__": lambda x: x - x,
+    "torch.Tensor.__rsub__[int]": lambda x: 1 - x,
+    "torch.Tensor.__rsub__[float]": lambda x: 1.0 - x,
+    "torch.tanh": lambda x: torch.tanh(x),
+    "torch.nn.functional.tanh": lambda x: torch.nn.functional.tanh(x),
+    "torch.tensor": lambda x: torch.tensor(x),
+    "torch.Tensor.transpose": lambda x: x.transpose(1, 2),
+    "torch.transpose": lambda x: torch.transpose(x, 1, 2),
+    "torch.exp": lambda x: torch.exp(x),
+    "torch.Tensor.exp": lambda x: x.exp(),
+    "torch.abs": lambda x: torch.abs(x),
+    "torch.Tensor.abs": lambda x: x.abs(),
+    "torch.neg": lambda x: torch.neg(x),
+    "torch.Tensor.neg": lambda x: -x,
+    "torch.sin": lambda x: torch.sin(x),
+    "torch.Tensor.sin": lambda x: x.sin(),
+    "torch.cos": lambda x: torch.cos(x),
+    "torch.Tensor.cos": lambda x: x.cos(),
+    "torch.sinh": lambda x: torch.sinh(x),
+    "torch.Tensor.sinh": lambda x: x.sinh(),
+    "torch.cosh": lambda x: torch.cosh(x),
+    "torch.Tensor.cosh": lambda x: x.cosh(),
+    "torch.atan": lambda x: torch.atan(x),
+    "torch.Tensor.atan": lambda x: x.atan(),
+    "torch.ceil": lambda x: torch.ceil(x),
+    "torch.Tensor.ceil": lambda x: x.ceil(),
+    "torch.floor": lambda x: torch.floor(x),
+    "torch.Tensor.floor": lambda x: x.floor()
+}
+
+@pytest.mark.parametrize("op", unary_1d_randn_ops.keys())
+def test_unary_1d_randn(op):
+    module = UnaryModule(unary_1d_randn_ops[op]).cuda().eval()
     inputs = [torch.randn(1, 3, 3).cuda()]
     cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
 
+
+unary_1d_positive_ops = {
+    "torch.log": lambda x: torch.log(x),
+    "torch.Tensor.log": lambda x: x.log(),
+    "torch.sqrt": lambda x: torch.sqrt(x),
+    "torch.Tensor.sqrt": lambda x: x.sqrt(),
+    "torch.reciprocal": lambda x: torch.reciprocal(x),
+    "torch.Tensor.reciprocal": lambda x: x.reciprocal(),
+    "torch.tan": lambda x: torch.tan(x),
+    "torch.Tensor.tan": lambda x: x.tan(),
+    "torch.asin": lambda x: torch.asin(x),
+    "torch.Tensor.asin": lambda x: x.asin(),
+    "torch.acos": lambda x: torch.acos(x),
+    "torch.Tensor.acos": lambda x: x.acos(),
+}
+
+@pytest.mark.parametrize("op", unary_1d_positive_ops.keys())
+def test_unary_1d_ones(op):
+    module = UnaryModule(unary_1d_positive_ops[op]).cuda().eval()
+    inputs = [0.5 * torch.ones(1, 3, 3).cuda()]
+    cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
+
+
 @pytest.mark.parametrize("op", ["mul", "__mul__", "__rmul__"])
 @pytest.mark.parametrize("scalar", [2, 2.])
 def test_mul_scalar(op, scalar):
@@ -699,4 +768,36 @@ def test_permute(permutation):
     sizes = [i + 1 for i in range(len(permutation))]
 
     inputs = [torch.randn(*sizes).cuda()]
-    cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
+    cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
+
+@pytest.mark.parametrize("op", [
+    "torch.pow",
+    "torch.Tensor.__ipow__",
+    "torch.Tensor.__pow__",
+    "torch.Tensor.__rpow__"
+])
+@pytest.mark.parametrize("scalar", [2, 2.])
+def test_scalar_op(op, scalar):
+    if op == "torch.pow":
+        fn = lambda x: torch.pow(x, scalar)
+    elif op == "torch.Tensor.__ipow__":
+        def ipow(x):
+            x **= scalar
+            return x
+        fn = ipow
+    elif op == "torch.Tensor.__pow__":
+        fn = lambda x: x ** scalar
+    elif op == "torch.Tensor.__rpow__":
+        fn = lambda x: scalar ** x
+
+    module = UnaryModule(fn).cuda().eval()
+    inputs = [torch.randn(1, 2).cuda()]
+    cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
+
+
+def test_prelu():
+    module = nn.PReLU(4).cuda().eval()
+    inputs = [torch.randn(1, 4, 3, 3).cuda()]
+    cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
+
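For context (not part of the commit): the tests above rely on a UnaryModule helper and a cross_validate utility defined earlier in this test file, outside this hunk. As a rough sketch of the assumed interface only, UnaryModule presumably wraps a single-tensor callable so it can be built, converted, and compared like any other nn.Module:

import torch
import torch.nn as nn

class UnaryModule(nn.Module):
    """Hypothetical sketch of the helper assumed by the tests above.
    The real definition lives elsewhere in the file and may differ;
    this only illustrates the assumed behavior: forward() applies the
    wrapped callable to its single tensor input."""
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x)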