@@ -884,73 +884,124 @@ def test_quantized_conv_per_tensor(
     @expand(
         [
             # Test case 1: Basic int8 case with negative scale
-            (
-                "basic_int8",
-                torch.tensor([-1, 0, 1, 3], dtype=torch.int8),  # input
-                torch.tensor([0], dtype=torch.int8),  # X_zero_point (scalar broadcast)
-                0,  # out_zero_point
-                torch.tensor([1073741824]),  # out_multiplier (0.5 * 2^31)
-                torch.tensor([0]),  # out_shift
-                torch.int8,  # dtype
-                torch.tensor(
-                    [0, 0, 0, -2], dtype=torch.int8
-                ),  # expected: relu(-1,0,1,3) = (0,0,1,3) * (-0.5) + 0 = (0,0,-0.5,-1.5) -> (0,0,0,-2)
-            ),
+            *[
+                (
+                    "basic_int8",
+                    torch.tensor([-1, 0, 1, 3], dtype=dtype),  # input
+                    0,  # X_zero_point (scalar broadcast)
+                    0,  # out_zero_point
+                    1073741824,  # out_multiplier (0.5 * 2^31)
+                    0,  # out_shift
+                    dtype,  # dtype
+                    torch.tensor(
+                        [0, 0, 0, -2], dtype=dtype
+                    ),  # expected: relu(-1,0,1,3) = (0,0,1,3) * (-0.5) + 0 = (0,0,-0.5,-1.5) -> (0,0,0,-2)
+                )
+                for dtype in [torch.int8]
+            ],
             # Test case 2: uint8 with non-zero zero point
-            (
-                "uint8_with_zp",
-                torch.tensor([126, 128, 130, 132], dtype=torch.uint8),  # input
-                torch.tensor([128], dtype=torch.uint8),  # X_zero_point
-                64,  # out_zero_point
-                torch.tensor([536870912]),  # out_multiplier (0.25 * 2^31)
-                torch.tensor([0]),  # out_shift
-                torch.uint8,  # dtype
-                torch.tensor(
-                    [64, 64, 64, 63], dtype=torch.uint8
-                ),  # expected: relu(-2,0,2,4) = (0,0,2,4) * (-0.25) + 64 = (64,64,63.5,63) -> (64,64,64,63)
-            ),
+            *[
+                (
+                    "uint8_with_zp",
+                    torch.tensor([126, 128, 130, 132], dtype=dtype),  # input
+                    128,  # X_zero_point
+                    64,  # out_zero_point
+                    536870912,  # out_multiplier (0.25 * 2^31)
+                    0,  # out_shift
+                    dtype,  # dtype
+                    torch.tensor(
+                        [64, 64, 64, 63], dtype=dtype
+                    ),  # expected: relu(-2,0,2,4) = (0,0,2,4) * (-0.25) + 64 = (64,64,63.5,63) -> (64,64,64,63)
+                )
+                for dtype in [torch.uint8]
+            ],
             # Test case 3: All negative values (should all become zero after ReLU)
-            (
-                "all_negative_int8",
-                torch.tensor([-5, -3, -1], dtype=torch.int8),  # input
-                torch.tensor([0], dtype=torch.int8),  # X_zero_point
-                10,  # out_zero_point
-                torch.tensor([1073741824]),  # out_multiplier (0.5 * 2^31)
-                torch.tensor([0]),  # out_shift
-                torch.int8,  # dtype
-                torch.tensor(
-                    [10, 10, 10], dtype=torch.int8
-                ),  # expected: relu(-5,-3,-1) = (0,0,0) * (-0.5) + 10 = (10,10,10)
-            ),
+            *[
+                (
+                    "all_negative_int8",
+                    torch.tensor([-5, -3, -1], dtype=dtype),  # input
+                    0,  # X_zero_point
+                    10,  # out_zero_point
+                    1073741824,  # out_multiplier (0.5 * 2^31)
+                    0,  # out_shift
+                    dtype,  # dtype
+                    torch.tensor(
+                        [10, 10, 10], dtype=dtype
+                    ),  # expected: relu(-5,-3,-1) = (0,0,0) * (-0.5) + 10 = (10,10,10)
+                )
+                for dtype in [torch.int8]
+            ],
             # Test case 4: All positive values with shift (scale becomes -1.0)
-            (
-                "positive_with_shift",
-                torch.tensor([2, 4, 6, 8], dtype=torch.int8),  # input
-                torch.tensor([1], dtype=torch.int8),  # X_zero_point
-                5,  # out_zero_point
-                torch.tensor([1073741824]),  # out_multiplier (0.5 * 2^31)
-                torch.tensor([1]),  # out_shift (multiply by 2^1 = 2)
-                torch.int8,  # dtype
-                torch.tensor(
-                    [4, 2, 0, -2], dtype=torch.int8
-                ),  # expected: relu(1,3,5,7) = (1,3,5,7) * (-1.0) + 5 = (4,2,0,-2)
-            ),
+            *[
+                (
+                    "positive_with_shift",
+                    torch.tensor([2, 4, 6, 8], dtype=dtype),  # input
+                    1,  # X_zero_point
+                    5,  # out_zero_point
+                    1073741824,  # out_multiplier (0.5 * 2^31)
+                    1,  # out_shift (multiply by 2^1 = 2)
+                    dtype,  # dtype
+                    torch.tensor(
+                        [4, 2, 0, -2], dtype=dtype
+                    ),  # expected: relu(1,3,5,7) = (1,3,5,7) * (-1.0) + 5 = (4,2,0,-2)
+                )
+                for dtype in [torch.int8, torch.uint8]
+            ],
+            # Test case 5: Non-per-tensor
+            *[
+                (
+                    "non_per_tensor",
+                    torch.tensor([-1, -2, -3, 1, 2, 3], dtype=dtype),  # input
+                    torch.tensor([0, 0, 0, 1, 1, 1]),  # X_zero_point
+                    5,  # out_zero_point
+                    torch.tensor([1073741824]),  # out_multiplier (0.5 * 2^31)
+                    torch.tensor([1]),  # out_shift (multiply by 2^1 = 2)
+                    dtype,  # dtype
+                    torch.tensor([5, 5, 5, 5, 4, 3], dtype=dtype),
+                )
+                for dtype in [torch.int8]
+            ],
         ]
     )
     def test_quantized_relu(
         self,
         name: str,
         X: torch.Tensor,
-        X_zero_point: torch.Tensor,
+        X_zero_point: torch.Tensor | int,
         out_zero_point: int,
-        out_multiplier: torch.Tensor,
-        out_shift: torch.Tensor,
+        out_multiplier: torch.Tensor | int,
+        out_shift: torch.Tensor | int,
         dtype: torch.dtype,
         expected_output: torch.Tensor,
     ) -> None:
-        output = torch.ops.cadence.quantized_relu(
-            X, X_zero_point, out_zero_point, out_multiplier, out_shift
-        )
+
+        if isinstance(X_zero_point, int):
+            assert isinstance(out_multiplier, int)
+            assert isinstance(out_shift, int)
+
+            match dtype:
+                case torch.int8:
+                    quantized_relu = (
+                        torch.ops.cadence.quantized_relu_asym8s_asym8s.per_tensor
+                    )
+                case torch.uint8:
+                    quantized_relu = (
+                        torch.ops.cadence.quantized_relu_asym8u_asym8u.per_tensor
+                    )
+                case _:
+                    quantized_relu = torch.ops.cadence.quantized_relu_per_tensor
+
+            output = quantized_relu(
+                X,
+                X_zero_point,
+                out_zero_point,
+                out_multiplier,
+                out_shift,
+            )
+        else:
+            output = torch.ops.cadence.quantized_relu(
+                X, X_zero_point, out_zero_point, out_multiplier, out_shift
+            )
 
         # Verify output properties
         self.assertEqual(output.dtype, dtype, f"Output dtype should be {dtype}")
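
For reference, below is a minimal sketch of the fixed-point arithmetic that the expected-value comments in these test cases assume: out = relu(X - X_zero_point) scaled by -(out_multiplier / 2^31) * 2^out_shift, rounded, offset by out_zero_point, and clamped to the output dtype. The negative effective scale and the round-half-to-even behaviour are inferred from the test expectations, not taken from the kernel source, and reference_quantized_relu is a hypothetical helper for illustration only.

import torch


def reference_quantized_relu(
    X: torch.Tensor,
    X_zero_point: int,
    out_zero_point: int,
    out_multiplier: int,
    out_shift: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    # ReLU on the zero-point-shifted input, computed in wider precision.
    relu = torch.clamp(X.to(torch.int32) - X_zero_point, min=0).to(torch.float64)
    # Effective (negative) scale encoded by out_multiplier and out_shift.
    scale = -(out_multiplier / 2**31) * 2**out_shift
    # torch.round uses round-half-to-even, matching e.g. -0.5 -> 0 and -1.5 -> -2.
    out = torch.round(relu * scale) + out_zero_point
    info = torch.iinfo(dtype)
    return out.clamp(info.min, info.max).to(dtype)


# e.g. the "positive_with_shift" case: relu(1,3,5,7) * (-1.0) + 5 -> (4, 2, 0, -2)
print(reference_quantized_relu(
    torch.tensor([2, 4, 6, 8], dtype=torch.int8), 1, 5, 1073741824, 1, torch.int8
))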