@@ -598,3 +598,105 @@ def test_max_pool(nd, kernel_size, stride, padding, dilation, ceil_mode):
     inputs = [torch.randn(*input_size).cuda()]
     cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
 
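+# Elementwise binary ops: min/max/fmod applied pairwise to two random tensors.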
+@pytest.mark.parametrize("op", ["min", "max", "fmod"])
+def test_binary_op_elementwise(op):
+    if op == "max":
+        fn = lambda x, y: torch.max(x, y)
+    elif op == "min":
+        fn = lambda x, y: torch.min(x, y)
+    elif op == "fmod":
+        fn = lambda x, y: torch.fmod(x, y)
+
+    module = BinaryModule(fn).cuda().eval()
+    inputs = [torch.randn(1, 3, 3).cuda(), torch.randn(1, 3, 3).cuda()]
+    cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
+
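+# Full-tensor reductions: min/max/mean/sum with no dim argument reduce to a scalar.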
+@pytest.mark.parametrize("op", ["min", "max", "mean", "sum"])
+def test_reduce_op(op):
+    if op == "max":
+        fn = lambda x: torch.max(x)
+    elif op == "min":
+        fn = lambda x: torch.min(x)
+    elif op == "mean":
+        fn = lambda x: torch.mean(x)
+    elif op == "sum":
+        fn = lambda x: torch.sum(x)
+
+    module = UnaryModule(fn).cuda().eval()
+    inputs = [torch.randn(1, 3, 3).cuda()]
+    cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
+
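+# Scalar multiplication through torch.mul and the __mul__/__rmul__ operator
+# protocol, with both int and float scalars.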
+@pytest.mark.parametrize("op", ["mul", "__mul__", "__rmul__"])
+@pytest.mark.parametrize("scalar", [2, 2.])
+def test_mul_scalar(op, scalar):
+    op_map = {
+        "mul": lambda x: torch.mul(x, scalar),
+        "__mul__": lambda x: x * scalar,
+        "__rmul__": lambda x: scalar * x
+    }
+
+    module = UnaryModule(op_map[op]).cuda().eval()
+    inputs = [torch.randn(1, 3, 3).cuda()]
+    cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
+
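+# torch.narrow takes `length` elements starting at `start` along `dim`;
+# negative dim and start index from the end.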
+@pytest.mark.parametrize("dim,start,length", [
+    (0, 0, 2),
+    (1, 1, 2),
+    (-1, -1, 1)
+])
+def test_narrow(dim, start, length):
+    module = UnaryModule(lambda x: torch.narrow(x, dim, start, length)).cuda().eval()
+    inputs = [torch.randn(3, 3).cuda()]
+    cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
+
+
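+# Inequality comparison producing a boolean tensor.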
+def test_ne_binary():
+    module = BinaryModule(lambda x, y: x != y).cuda().eval()
+    # start from two identical zero tensors, then make the first input differ
+    # so the comparison is non-trivial
+    inputs = [torch.zeros(1, 3, 3).cuda()]
+    inputs.append(inputs[0].clone())
+    inputs[0][0] = 1
+    cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
+
+
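+# F.normalize: Lp-normalization (p = 1 or 2) along a given dim.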
+@pytest.mark.parametrize("p", [1, 2])
+@pytest.mark.parametrize("dim", [1, 2])
+def test_normalize(p, dim):
+    module = UnaryModule(lambda x: torch.nn.functional.normalize(x, p, dim)).cuda().eval()
+    # use a random (nonzero) input; an all-zeros tensor normalizes to zeros
+    # either way, which would make the comparison vacuous
+    inputs = [torch.randn(1, 3, 3).cuda()]
+    cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
+
+
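+# Constant padding of the last one, two, and three dimensions of a 4D input.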
+@pytest.mark.parametrize("pad,mode,value", [
+    ((1, 1), "constant", 0.),
+    ((1, 1, 2, 2), "constant", 0.),
+    ((0, 1, 2, 1, 3, 3), "constant", 0.),
+])
+def test_pad(pad, mode, value):
+    module = UnaryModule(
+        lambda x: torch.nn.functional.pad(
+            x, pad, mode, value
+        )
+    ).cuda().eval()
+    inputs = [torch.randn(3, 3, 4, 2).cuda()]
+    cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
+
+
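+# Permute on 3D and 4D tensors; distinct dim sizes (1, 2, 3, ...) make a wrong
+# axis order show up as a shape mismatch.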
+@pytest.mark.parametrize("permutation", [
+    (0, 2, 1),
+    (0, 2, 1, 3)
+])
+def test_permute(permutation):
+    module = UnaryModule(
+        lambda x: x.permute(*permutation)
+    ).cuda().eval()
+    sizes = [i + 1 for i in range(len(permutation))]
+    inputs = [torch.randn(*sizes).cuda()]
+    cross_validate(module, inputs, fp16_mode=False, tol=1e-1)