Commit c97c2c1

All variants of quantized relu
Differential Revision: D81948125
Pull Request resolved: #14080
1 parent: 9af908d

File tree

2 files changed: +183 -60 lines

backends/cadence/aot/ref_implementations.py

Lines changed: 78 additions & 6 deletions
@@ -748,13 +748,12 @@ def quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor() -> torch.Tens
 def quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_relu")
-def quantized_relu(
+def quantized_relu_common(
     X: torch.Tensor,
-    X_zero_point: torch.Tensor,
+    X_zero_point: torch.Tensor | int,
     out_zero_point: int,
-    out_multiplier: torch.Tensor,
-    out_shift: torch.Tensor,
+    out_multiplier: int,
+    out_shift: int,
 ) -> torch.Tensor:
     """
     Quantized ReLU operation followed by requantization.
@@ -770,7 +769,7 @@ def quantized_relu(
     if X.dtype not in supported_dtypes:
         raise ValueError(f"X dtype must be one of {supported_dtypes}. Got {X.dtype}")

-    out_scale = -out_multiplier * (1 / (1 << 31)) * (2 ** out_shift[0])
+    out_scale = -out_multiplier * (1 / (1 << 31)) * (2**out_shift)
     dequantized_X = torch.where(X > X_zero_point, X - X_zero_point, torch.zeros_like(X))
     return quantize_per_tensor(
         dequantized_X,
@@ -782,6 +781,79 @@ def quantized_relu(
     )


+def quantized_relu_variant(
+    per_tensor: bool,
+    dtype: torch.dtype | None = None,
+) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]:
+    """Create a quantized relu variant with type checking."""
+
+    def decorator(_: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
+        def variant(
+            X: torch.Tensor,
+            X_zero_point: torch.Tensor | int,
+            out_zero_point: int,
+            out_multiplier: torch.Tensor | int,
+            out_shift: torch.Tensor | int,
+        ) -> torch.Tensor:
+            if per_tensor:
+                if dtype and X.dtype != dtype:
+                    raise ValueError(f"X dtype must be {dtype}. Got {X.dtype}")
+
+                assert isinstance(out_shift, int)
+                assert isinstance(out_multiplier, int)
+                _out_shift = out_shift
+                _out_multiplier = out_multiplier
+            else:
+                assert isinstance(out_multiplier, torch.Tensor)
+                if out_multiplier.numel() > 1:
+                    raise ValueError("Only scalar out_multiplier is supported")
+
+                assert isinstance(out_shift, torch.Tensor)
+                if out_shift.numel() > 1:
+                    raise ValueError("Only scalar out_shift is supported")
+
+                assert isinstance(X_zero_point, torch.Tensor)
+                if X_zero_point.shape != X.shape:
+                    raise ValueError(
+                        f"X_zero_point shape must be {X.shape}. Got {X_zero_point.shape}"
+                    )
+
+                _out_multiplier = int(out_multiplier.item())
+                _out_shift = int(out_shift.item())
+
+            return quantized_relu_common(
+                X,
+                X_zero_point,
+                out_zero_point,
+                _out_multiplier,
+                _out_shift,
+            )
+
+        return variant
+
+    return decorator
+
+
+@impl(m, "quantized_relu")
+@quantized_relu_variant(False)
+def quantized_relu() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_relu.per_tensor")
+@quantized_relu_variant(True)
+def quantized_relu_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_relu_asym8s_asym8s.per_tensor")
+@quantized_relu_variant(True, torch.int8)
+def quantized_relu_asym8s_asym8s_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_relu_asym8u_asym8u.per_tensor")
+@quantized_relu_variant(True, torch.uint8)
+def quantized_relu_asym8u_asym8u_per_tensor() -> torch.Tensor: ...
+
+
 @impl(m, "requantize")
 def requantize(
     input: torch.Tensor,
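
For reference, the math all of these variants share (via quantized_relu_common above) reduces to out = round(relu(X - X_zero_point) * out_scale) + out_zero_point, with out_scale = -out_multiplier * 2^-31 * 2^out_shift. The standalone sketch below is not the reference kernel itself; the helper name, the round-half-to-even rounding via torch.round, and the final clamp are assumptions, but it reproduces the "basic_int8" expectation from the tests in this commit.

import torch

def quantized_relu_sketch(
    x: torch.Tensor,
    x_zero_point: int,
    out_zero_point: int,
    out_multiplier: int,
    out_shift: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    # Requantization scale, as in quantized_relu_common: out_multiplier is a
    # Q31 fixed-point multiplier, out_shift a power-of-two exponent.
    out_scale = -out_multiplier * (1 / (1 << 31)) * (2**out_shift)
    # ReLU relative to the input zero point.
    relu_x = torch.where(x > x_zero_point, x - x_zero_point, torch.zeros_like(x))
    # Requantize; the rounding mode and clamp are assumptions (no test case in
    # this commit exercises saturation).
    info = torch.iinfo(dtype)
    out = torch.round(relu_x.to(torch.float32) * out_scale) + out_zero_point
    return out.clamp(info.min, info.max).to(dtype)

# Mirrors the "basic_int8" case: relu(-1, 0, 1, 3) * (-0.5) + 0 -> (0, 0, 0, -2)
x = torch.tensor([-1, 0, 1, 3], dtype=torch.int8)
print(quantized_relu_sketch(x, 0, 0, 1073741824, 0, torch.int8))
# tensor([ 0,  0,  0, -2], dtype=torch.int8)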

backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 105 additions & 54 deletions
@@ -884,73 +884,124 @@ def test_quantized_conv_per_tensor(
     @expand(
         [
             # Test case 1: Basic int8 case with negative scale
-            (
-                "basic_int8",
-                torch.tensor([-1, 0, 1, 3], dtype=torch.int8),  # input
-                torch.tensor([0], dtype=torch.int8),  # X_zero_point (scalar broadcast)
-                0,  # out_zero_point
-                torch.tensor([1073741824]),  # out_multiplier (0.5 * 2^31)
-                torch.tensor([0]),  # out_shift
-                torch.int8,  # dtype
-                torch.tensor(
-                    [0, 0, 0, -2], dtype=torch.int8
-                ),  # expected: relu(-1,0,1,3) = (0,0,1,3) * (-0.5) + 0 = (0,0,-0.5,-1.5) -> (0,0,0,-2)
-            ),
+            *[
+                (
+                    "basic_int8",
+                    torch.tensor([-1, 0, 1, 3], dtype=dtype),  # input
+                    0,  # X_zero_point (scalar broadcast)
+                    0,  # out_zero_point
+                    1073741824,  # out_multiplier (0.5 * 2^31)
+                    0,  # out_shift
+                    dtype,  # dtype
+                    torch.tensor(
+                        [0, 0, 0, -2], dtype=dtype
+                    ),  # expected: relu(-1,0,1,3) = (0,0,1,3) * (-0.5) + 0 = (0,0,-0.5,-1.5) -> (0,0,0,-2)
+                )
+                for dtype in [torch.int8]
+            ],
             # Test case 2: uint8 with non-zero zero point
-            (
-                "uint8_with_zp",
-                torch.tensor([126, 128, 130, 132], dtype=torch.uint8),  # input
-                torch.tensor([128], dtype=torch.uint8),  # X_zero_point
-                64,  # out_zero_point
-                torch.tensor([536870912]),  # out_multiplier (0.25 * 2^31)
-                torch.tensor([0]),  # out_shift
-                torch.uint8,  # dtype
-                torch.tensor(
-                    [64, 64, 64, 63], dtype=torch.uint8
-                ),  # expected: relu(-2,0,2,4) = (0,0,2,4) * (-0.25) + 64 = (64,64,63.5,63) -> (64,64,64,63)
-            ),
+            *[
+                (
+                    "uint8_with_zp",
+                    torch.tensor([126, 128, 130, 132], dtype=dtype),  # input
+                    128,  # X_zero_point
+                    64,  # out_zero_point
+                    536870912,  # out_multiplier (0.25 * 2^31)
+                    0,  # out_shift
+                    dtype,  # dtype
+                    torch.tensor(
+                        [64, 64, 64, 63], dtype=dtype
+                    ),  # expected: relu(-2,0,2,4) = (0,0,2,4) * (-0.25) + 64 = (64,64,63.5,63) -> (64,64,64,63)
+                )
+                for dtype in [torch.uint8]
+            ],
             # Test case 3: All negative values (should all become zero after ReLU)
-            (
-                "all_negative_int8",
-                torch.tensor([-5, -3, -1], dtype=torch.int8),  # input
-                torch.tensor([0], dtype=torch.int8),  # X_zero_point
-                10,  # out_zero_point
-                torch.tensor([1073741824]),  # out_multiplier (0.5 * 2^31)
-                torch.tensor([0]),  # out_shift
-                torch.int8,  # dtype
-                torch.tensor(
-                    [10, 10, 10], dtype=torch.int8
-                ),  # expected: relu(-5,-3,-1) = (0,0,0) * (-0.5) + 10 = (10,10,10)
-            ),
+            *[
+                (
+                    "all_negative_int8",
+                    torch.tensor([-5, -3, -1], dtype=dtype),  # input
+                    0,  # X_zero_point
+                    10,  # out_zero_point
+                    1073741824,  # out_multiplier (0.5 * 2^31)
+                    0,  # out_shift
+                    dtype,  # dtype
+                    torch.tensor(
+                        [10, 10, 10], dtype=dtype
+                    ),  # expected: relu(-5,-3,-1) = (0,0,0) * (-0.5) + 10 = (10,10,10)
+                )
+                for dtype in [torch.int8]
+            ],
             # Test case 4: All positive values with shift (scale becomes -0.25)
-            (
-                "positive_with_shift",
-                torch.tensor([2, 4, 6, 8], dtype=torch.int8),  # input
-                torch.tensor([1], dtype=torch.int8),  # X_zero_point
-                5,  # out_zero_point
-                torch.tensor([1073741824]),  # out_multiplier (0.5 * 2^31)
-                torch.tensor([1]),  # out_shift (multiply by 2^1 = 2)
-                torch.int8,  # dtype
-                torch.tensor(
-                    [4, 2, 0, -2], dtype=torch.int8
-                ),  # expected: relu(1,3,5,7) = (1,3,5,7) * (-1.0) + 5 = (4,2,0,-2)
-            ),
+            *[
+                (
+                    "positive_with_shift",
+                    torch.tensor([2, 4, 6, 8], dtype=dtype),  # input
+                    1,  # X_zero_point
+                    5,  # out_zero_point
+                    1073741824,  # out_multiplier (0.5 * 2^31)
+                    1,  # out_shift (multiply by 2^1 = 2)
+                    dtype,  # dtype
+                    torch.tensor(
+                        [4, 2, 0, -2], dtype=dtype
+                    ),  # expected: relu(1,3,5,7) = (1,3,5,7) * (-1.0) + 5 = (4,2,0,-2)
+                )
+                for dtype in [torch.int8, torch.uint8]
+            ],
+            # Test case 5: Non-per-tensor
+            *[
+                (
+                    "non_per_tensor",
+                    torch.tensor([-1, -2, -3, 1, 2, 3], dtype=dtype),  # input
+                    torch.tensor([0, 0, 0, 1, 1, 1]),  # X_zero_point
+                    5,  # out_zero_point
+                    torch.tensor([1073741824]),  # out_multiplier (0.5 * 2^31)
+                    torch.tensor([1]),  # out_shift (multiply by 2^1 = 2)
+                    dtype,  # dtype
+                    torch.tensor([5, 5, 5, 5, 4, 3], dtype=dtype),
+                )
+                for dtype in [torch.int8]
+            ],
         ]
     )
     def test_quantized_relu(
         self,
         name: str,
         X: torch.Tensor,
-        X_zero_point: torch.Tensor,
+        X_zero_point: torch.Tensor | int,
         out_zero_point: int,
-        out_multiplier: torch.Tensor,
-        out_shift: torch.Tensor,
+        out_multiplier: torch.Tensor | int,
+        out_shift: torch.Tensor | int,
         dtype: torch.dtype,
         expected_output: torch.Tensor,
     ) -> None:
-        output = torch.ops.cadence.quantized_relu(
-            X, X_zero_point, out_zero_point, out_multiplier, out_shift
-        )
+
+        if isinstance(X_zero_point, int):
+            assert isinstance(out_multiplier, int)
+            assert isinstance(out_shift, int)
+
+            match dtype:
+                case torch.int8:
+                    quantized_relu = (
+                        torch.ops.cadence.quantized_relu_asym8s_asym8s.per_tensor
+                    )
+                case torch.uint8:
+                    quantized_relu = (
+                        torch.ops.cadence.quantized_relu_asym8u_asym8u.per_tensor
+                    )
+                case _:
+                    quantized_relu = torch.ops.cadence.quantized_relu_per_tensor
+
+            output = quantized_relu(
+                X,
+                X_zero_point,
+                out_zero_point,
+                out_multiplier,
+                out_shift,
+            )
+        else:
+            output = torch.ops.cadence.quantized_relu(
+                X, X_zero_point, out_zero_point, out_multiplier, out_shift
+            )

         # Verify output properties
         self.assertEqual(output.dtype, dtype, f"Output dtype should be {dtype}")
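
The new "non_per_tensor" case is the only one whose expected tensor has no inline derivation; a quick hand-check with the same arithmetic as above (a sketch, assuming the elementwise zero-point comparison and torch.round rounding) gives the listed values:

import torch

# Hand-check of the "non_per_tensor" expectation (not the op itself).
x = torch.tensor([-1, -2, -3, 1, 2, 3], dtype=torch.int8)
zp = torch.tensor([0, 0, 0, 1, 1, 1])
relu_x = torch.where(x > zp, x - zp, torch.zeros_like(x))  # [0, 0, 0, 0, 1, 2]
out_scale = -(1073741824 / 2**31) * 2**1                   # -1.0
out = torch.round(relu_x.to(torch.float32) * out_scale) + 5
print(out.to(torch.int8))  # tensor([5, 5, 5, 5, 4, 3], dtype=torch.int8), matching expected_output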
