@@ -877,73 +877,124 @@ def test_quantized_conv_per_tensor(
877877 @expand (
878878 [
879879 # Test case 1: Basic int8 case with negative scale
880- (
881- "basic_int8" ,
882- torch .tensor ([- 1 , 0 , 1 , 3 ], dtype = torch .int8 ), # input
883- torch .tensor ([0 ], dtype = torch .int8 ), # X_zero_point (scalar broadcast)
884- 0 , # out_zero_point
885- torch .tensor ([1073741824 ]), # out_multiplier (0.5 * 2^31)
886- torch .tensor ([0 ]), # out_shift
887- torch .int8 , # dtype
888- torch .tensor (
889- [0 , 0 , 0 , - 2 ], dtype = torch .int8
890- ), # expected: relu(-1,0,1,3) = (0,0,1,3) * (-0.5) + 0 = (0,0,-0.5,-1.5) -> (0,0,0,-2)
891- ),
880+ * [
881+ (
882+ "basic_int8" ,
883+ torch .tensor ([- 1 , 0 , 1 , 3 ], dtype = dtype ), # input
884+ 0 , # X_zero_point (scalar broadcast)
885+ 0 , # out_zero_point
886+ 1073741824 , # out_multiplier (0.5 * 2^31)
887+ 0 , # out_shift
888+ dtype , # dtype
889+ torch .tensor (
890+ [0 , 0 , 0 , - 2 ], dtype = dtype
891+ ), # expected: relu(-1,0,1,3) = (0,0,1,3) * (-0.5) + 0 = (0,0,-0.5,-1.5) -> (0,0,0,-2)
892+ )
893+ for dtype in [torch .int8 ]
894+ ],
892895 # Test case 2: uint8 with non-zero zero point
893- (
894- "uint8_with_zp" ,
895- torch .tensor ([126 , 128 , 130 , 132 ], dtype = torch .uint8 ), # input
896- torch .tensor ([128 ], dtype = torch .uint8 ), # X_zero_point
897- 64 , # out_zero_point
898- torch .tensor ([536870912 ]), # out_multiplier (0.25 * 2^31)
899- torch .tensor ([0 ]), # out_shift
900- torch .uint8 , # dtype
901- torch .tensor (
902- [64 , 64 , 64 , 63 ], dtype = torch .uint8
903- ), # expected: relu(-2,0,2,4) = (0,0,2,4) * (-0.25) + 64 = (64,64,63.5,63) -> (64,64,64,63)
904- ),
896+ * [
897+ (
898+ "uint8_with_zp" ,
899+ torch .tensor ([126 , 128 , 130 , 132 ], dtype = dtype ), # input
900+ 128 , # X_zero_point
901+ 64 , # out_zero_point
902+ 536870912 , # out_multiplier (0.25 * 2^31)
903+ 0 , # out_shift
904+ dtype , # dtype
905+ torch .tensor (
906+ [64 , 64 , 64 , 63 ], dtype = dtype
907+ ), # expected: relu(-2,0,2,4) = (0,0,2,4) * (-0.25) + 64 = (64,64,63.5,63) -> (64,64,64,63)
908+ )
909+ for dtype in [torch .uint8 ]
910+ ],
905911 # Test case 3: All negative values (should all become zero after ReLU)
906- (
907- "all_negative_int8" ,
908- torch .tensor ([- 5 , - 3 , - 1 ], dtype = torch .int8 ), # input
909- torch .tensor ([0 ], dtype = torch .int8 ), # X_zero_point
910- 10 , # out_zero_point
911- torch .tensor ([1073741824 ]), # out_multiplier (0.5 * 2^31)
912- torch .tensor ([0 ]), # out_shift
913- torch .int8 , # dtype
914- torch .tensor (
915- [10 , 10 , 10 ], dtype = torch .int8
916- ), # expected: relu(-5,-3,-1) = (0,0,0) * (-0.5) + 10 = (10,10,10)
917- ),
912+ * [
913+ (
914+ "all_negative_int8" ,
915+ torch .tensor ([- 5 , - 3 , - 1 ], dtype = dtype ), # input
916+ 0 , # X_zero_point
917+ 10 , # out_zero_point
918+ 1073741824 , # out_multiplier (0.5 * 2^31)
919+ 0 , # out_shift
920+ dtype , # dtype
921+ torch .tensor (
922+ [10 , 10 , 10 ], dtype = dtype
923+ ), # expected: relu(-5,-3,-1) = (0,0,0) * (-0.5) + 10 = (10,10,10)
924+ )
925+ for dtype in [torch .int8 ]
926+ ],
918927 # Test case 4: All positive values with shift (scale becomes -1.0)
919- (
920- "positive_with_shift" ,
921- torch .tensor ([2 , 4 , 6 , 8 ], dtype = torch .int8 ), # input
922- torch .tensor ([1 ], dtype = torch .int8 ), # X_zero_point
923- 5 , # out_zero_point
924- torch .tensor ([1073741824 ]), # out_multiplier (0.5 * 2^31)
925- torch .tensor ([1 ]), # out_shift (multiply by 2^1 = 2)
926- torch .int8 , # dtype
927- torch .tensor (
928- [4 , 2 , 0 , - 2 ], dtype = torch .int8
929- ), # expected: relu(1,3,5,7) = (1,3,5,7) * (-1.0) + 5 = (4,2,0,-2)
930- ),
928+ * [
929+ (
930+ "positive_with_shift" ,
931+ torch .tensor ([2 , 4 , 6 , 8 ], dtype = dtype ), # input
932+ 1 , # X_zero_point
933+ 5 , # out_zero_point
934+ 1073741824 , # out_multiplier (0.5 * 2^31)
935+ 1 , # out_shift (multiply by 2^1 = 2)
936+ dtype , # dtype
937+ torch .tensor (
938+ [4 , 2 , 0 , - 2 ], dtype = dtype
939+ ), # expected: relu(1,3,5,7) = (1,3,5,7) * (-1.0) + 5 = (4,2,0,-2)
940+ )
941+ for dtype in [torch .int8 , torch .uint8 ]
942+ ],
943+ # Test case 5: Non-per-tensor
944+ * [
945+ (
946+ "non_per_tensor" ,
947+ torch .tensor ([- 1 , - 2 , - 3 , 1 , 2 , 3 ], dtype = dtype ), # input
948+ torch .tensor ([0 , 0 , 0 , 1 , 1 , 1 ]), # X_zero_point
949+ 5 , # out_zero_point
950+ torch .tensor ([1073741824 ]), # out_multiplier (0.5 * 2^31)
951+ torch .tensor ([1 ]), # out_shift (multiply by 2^1 = 2)
952+ dtype , # dtype
953+ torch .tensor ([5 , 5 , 5 , 5 , 4 , 3 ], dtype = dtype ),
954+ )
955+ for dtype in [torch .int8 ]
956+ ],
931957 ]
932958 )
933959 def test_quantized_relu (
934960 self ,
935961 name : str ,
936962 X : torch .Tensor ,
937- X_zero_point : torch .Tensor ,
963+ X_zero_point : torch .Tensor | int ,
938964 out_zero_point : int ,
939- out_multiplier : torch .Tensor ,
940- out_shift : torch .Tensor ,
965+ out_multiplier : torch .Tensor | int ,
966+ out_shift : torch .Tensor | int ,
941967 dtype : torch .dtype ,
942968 expected_output : torch .Tensor ,
943969 ) -> None :
944- output = torch .ops .cadence .quantized_relu (
945- X , X_zero_point , out_zero_point , out_multiplier , out_shift
946- )
970+
971+ if isinstance (X_zero_point , int ):
972+ assert isinstance (out_multiplier , int )
973+ assert isinstance (out_shift , int )
974+
975+ match dtype :
976+ case torch .int8 :
977+ quantized_relu = (
978+ torch .ops .cadence .quantized_relu_asym8s_asym8s .per_tensor
979+ )
980+ case torch .uint8 :
981+ quantized_relu = (
982+ torch .ops .cadence .quantized_relu_asym8u_asym8u .per_tensor
983+ )
984+ case _:
985+ quantized_relu = torch .ops .cadence .quantized_relu_per_tensor
986+
987+ output = quantized_relu (
988+ X ,
989+ X_zero_point ,
990+ out_zero_point ,
991+ out_multiplier ,
992+ out_shift ,
993+ )
994+ else :
995+ output = torch .ops .cadence .quantized_relu (
996+ X , X_zero_point , out_zero_point , out_multiplier , out_shift
997+ )
947998
948999 # Verify output properties
9491000 self .assertEqual (output .dtype , dtype , f"Output dtype should be { dtype } " )
0 commit comments