@@ -884,73 +884,124 @@ def test_quantized_conv_per_tensor(
884
884
@expand (
885
885
[
886
886
# Test case 1: Basic int8 case with negative scale
887
- (
888
- "basic_int8" ,
889
- torch .tensor ([- 1 , 0 , 1 , 3 ], dtype = torch .int8 ), # input
890
- torch .tensor ([0 ], dtype = torch .int8 ), # X_zero_point (scalar broadcast)
891
- 0 , # out_zero_point
892
- torch .tensor ([1073741824 ]), # out_multiplier (0.5 * 2^31)
893
- torch .tensor ([0 ]), # out_shift
894
- torch .int8 , # dtype
895
- torch .tensor (
896
- [0 , 0 , 0 , - 2 ], dtype = torch .int8
897
- ), # expected: relu(-1,0,1,3) = (0,0,1,3) * (-0.5) + 0 = (0,0,-0.5,-1.5) -> (0,0,0,-2)
898
- ),
887
+ * [
888
+ (
889
+ "basic_int8" ,
890
+ torch .tensor ([- 1 , 0 , 1 , 3 ], dtype = dtype ), # input
891
+ 0 , # X_zero_point (scalar broadcast)
892
+ 0 , # out_zero_point
893
+ 1073741824 , # out_multiplier (0.5 * 2^31)
894
+ 0 , # out_shift
895
+ dtype , # dtype
896
+ torch .tensor (
897
+ [0 , 0 , 0 , - 2 ], dtype = dtype
898
+ ), # expected: relu(-1,0,1,3) = (0,0,1,3) * (-0.5) + 0 = (0,0,-0.5,-1.5) -> (0,0,0,-2)
899
+ )
900
+ for dtype in [torch .int8 ]
901
+ ],
899
902
# Test case 2: uint8 with non-zero zero point
900
- (
901
- "uint8_with_zp" ,
902
- torch .tensor ([126 , 128 , 130 , 132 ], dtype = torch .uint8 ), # input
903
- torch .tensor ([128 ], dtype = torch .uint8 ), # X_zero_point
904
- 64 , # out_zero_point
905
- torch .tensor ([536870912 ]), # out_multiplier (0.25 * 2^31)
906
- torch .tensor ([0 ]), # out_shift
907
- torch .uint8 , # dtype
908
- torch .tensor (
909
- [64 , 64 , 64 , 63 ], dtype = torch .uint8
910
- ), # expected: relu(-2,0,2,4) = (0,0,2,4) * (-0.25) + 64 = (64,64,63.5,63) -> (64,64,64,63)
911
- ),
903
+ * [
904
+ (
905
+ "uint8_with_zp" ,
906
+ torch .tensor ([126 , 128 , 130 , 132 ], dtype = dtype ), # input
907
+ 128 , # X_zero_point
908
+ 64 , # out_zero_point
909
+ 536870912 , # out_multiplier (0.25 * 2^31)
910
+ 0 , # out_shift
911
+ dtype , # dtype
912
+ torch .tensor (
913
+ [64 , 64 , 64 , 63 ], dtype = dtype
914
+ ), # expected: relu(-2,0,2,4) = (0,0,2,4) * (-0.25) + 64 = (64,64,63.5,63) -> (64,64,64,63)
915
+ )
916
+ for dtype in [torch .uint8 ]
917
+ ],
912
918
# Test case 3: All negative values (should all become zero after ReLU)
913
- (
914
- "all_negative_int8" ,
915
- torch .tensor ([- 5 , - 3 , - 1 ], dtype = torch .int8 ), # input
916
- torch .tensor ([0 ], dtype = torch .int8 ), # X_zero_point
917
- 10 , # out_zero_point
918
- torch .tensor ([1073741824 ]), # out_multiplier (0.5 * 2^31)
919
- torch .tensor ([0 ]), # out_shift
920
- torch .int8 , # dtype
921
- torch .tensor (
922
- [10 , 10 , 10 ], dtype = torch .int8
923
- ), # expected: relu(-5,-3,-1) = (0,0,0) * (-0.5) + 10 = (10,10,10)
924
- ),
919
+ * [
920
+ (
921
+ "all_negative_int8" ,
922
+ torch .tensor ([- 5 , - 3 , - 1 ], dtype = dtype ), # input
923
+ 0 , # X_zero_point
924
+ 10 , # out_zero_point
925
+ 1073741824 , # out_multiplier (0.5 * 2^31)
926
+ 0 , # out_shift
927
+ dtype , # dtype
928
+ torch .tensor (
929
+ [10 , 10 , 10 ], dtype = dtype
930
+ ), # expected: relu(-5,-3,-1) = (0,0,0) * (-0.5) + 10 = (10,10,10)
931
+ )
932
+ for dtype in [torch .int8 ]
933
+ ],
925
934
# Test case 4: All positive values with shift (scale becomes -0.25)
926
- (
927
- "positive_with_shift" ,
928
- torch .tensor ([2 , 4 , 6 , 8 ], dtype = torch .int8 ), # input
929
- torch .tensor ([1 ], dtype = torch .int8 ), # X_zero_point
930
- 5 , # out_zero_point
931
- torch .tensor ([1073741824 ]), # out_multiplier (0.5 * 2^31)
932
- torch .tensor ([1 ]), # out_shift (multiply by 2^1 = 2)
933
- torch .int8 , # dtype
934
- torch .tensor (
935
- [4 , 2 , 0 , - 2 ], dtype = torch .int8
936
- ), # expected: relu(1,3,5,7) = (1,3,5,7) * (-1.0) + 5 = (4,2,0,-2)
937
- ),
935
+ * [
936
+ (
937
+ "positive_with_shift" ,
938
+ torch .tensor ([2 , 4 , 6 , 8 ], dtype = dtype ), # input
939
+ 1 , # X_zero_point
940
+ 5 , # out_zero_point
941
+ 1073741824 , # out_multiplier (0.5 * 2^31)
942
+ 1 , # out_shift (multiply by 2^1 = 2)
943
+ dtype , # dtype
944
+ torch .tensor (
945
+ [4 , 2 , 0 , - 2 ], dtype = dtype
946
+ ), # expected: relu(1,3,5,7) = (1,3,5,7) * (-1.0) + 5 = (4,2,0,-2)
947
+ )
948
+ for dtype in [torch .int8 , torch .uint8 ]
949
+ ],
950
+ # Test case 4: Non-per-tensor
951
+ * [
952
+ (
953
+ "non_per_tensor" ,
954
+ torch .tensor ([- 1 , - 2 , - 3 , 1 , 2 , 3 ], dtype = dtype ), # input
955
+ torch .tensor ([0 , 0 , 0 , 1 , 1 , 1 ]), # X_zero_point
956
+ 5 , # out_zero_point
957
+ torch .tensor ([1073741824 ]), # out_multiplier (0.5 * 2^31)
958
+ torch .tensor ([1 ]), # out_shift (multiply by 2^1 = 2)
959
+ dtype , # dtype
960
+ torch .tensor ([5 , 5 , 5 , 5 , 4 , 3 ], dtype = dtype ),
961
+ )
962
+ for dtype in [torch .int8 ]
963
+ ],
938
964
]
939
965
)
940
966
def test_quantized_relu (
941
967
self ,
942
968
name : str ,
943
969
X : torch .Tensor ,
944
- X_zero_point : torch .Tensor ,
970
+ X_zero_point : torch .Tensor | int ,
945
971
out_zero_point : int ,
946
- out_multiplier : torch .Tensor ,
947
- out_shift : torch .Tensor ,
972
+ out_multiplier : torch .Tensor | int ,
973
+ out_shift : torch .Tensor | int ,
948
974
dtype : torch .dtype ,
949
975
expected_output : torch .Tensor ,
950
976
) -> None :
951
- output = torch .ops .cadence .quantized_relu (
952
- X , X_zero_point , out_zero_point , out_multiplier , out_shift
953
- )
977
+
978
+ if isinstance (X_zero_point , int ):
979
+ assert isinstance (out_multiplier , int )
980
+ assert isinstance (out_shift , int )
981
+
982
+ match dtype :
983
+ case torch .int8 :
984
+ quantized_relu = (
985
+ torch .ops .cadence .quantized_relu_asym8s_asym8s .per_tensor
986
+ )
987
+ case torch .uint8 :
988
+ quantized_relu = (
989
+ torch .ops .cadence .quantized_relu_asym8u_asym8u .per_tensor
990
+ )
991
+ case _:
992
+ quantized_relu = torch .ops .cadence .quantized_relu_per_tensor
993
+
994
+ output = quantized_relu (
995
+ X ,
996
+ X_zero_point ,
997
+ out_zero_point ,
998
+ out_multiplier ,
999
+ out_shift ,
1000
+ )
1001
+ else :
1002
+ output = torch .ops .cadence .quantized_relu (
1003
+ X , X_zero_point , out_zero_point , out_multiplier , out_shift
1004
+ )
954
1005
955
1006
# Verify output properties
956
1007
self .assertEqual (output .dtype , dtype , f"Output dtype should be { dtype } " )
0 commit comments