
Commit 6bc73e6

Andrew Grebenisan authored and facebook-github-bot committed
All variants of quantized relu (pytorch#14080)
Summary: Create a generic quantized relu and decorators for all custom quantized relu ops. Differential Revision: D81948125
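The summary describes a decorator-factory pattern: one shared implementation plus thin, type-checked variants whose stub bodies are replaced by the decorator. Below is a minimal standalone sketch of that pattern (the names, the stand-in body, and the example values are illustrative; the Cadence @impl registration and the real requantization math are omitted):

import torch
from typing import Callable


def relu_variant(
    dtype: torch.dtype | None = None,
) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]:
    """Decorator factory: the decorated stub's body is discarded and replaced by
    a variant that type-checks its input, then delegates to a shared implementation."""

    def decorator(_: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
        def variant(x: torch.Tensor) -> torch.Tensor:
            if dtype is not None and x.dtype != dtype:
                raise ValueError(f"expected {dtype}, got {x.dtype}")
            # Stand-in for the shared quantized implementation.
            return torch.clamp(x, min=0)

        return variant

    return decorator


@relu_variant(torch.int8)
def relu_int8() -> torch.Tensor: ...  # stub body; the decorator supplies the real one


print(relu_int8(torch.tensor([-3, 0, 2], dtype=torch.int8)))  # tensor([0, 0, 2], dtype=torch.int8)

In the diff below, the same shape is used, with @impl(m, ...) stacked on top of each stub so every generated variant is also registered as a Cadence op.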
1 parent: 2f588c5 · commit: 6bc73e6

File tree

2 files changed: +183 −60 lines

backends/cadence/aot/ref_implementations.py

Lines changed: 78 additions & 6 deletions
@@ -735,13 +735,12 @@ def quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
 def quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
 
 
-@impl(m, "quantized_relu")
-def quantized_relu(
+def quantized_relu_common(
     X: torch.Tensor,
-    X_zero_point: torch.Tensor,
+    X_zero_point: torch.Tensor | int,
     out_zero_point: int,
-    out_multiplier: torch.Tensor,
-    out_shift: torch.Tensor,
+    out_multiplier: int,
+    out_shift: int,
 ) -> torch.Tensor:
     """
     Quantized ReLU operation followed by requantization.
@@ -757,7 +756,7 @@ def quantized_relu(
     if X.dtype not in supported_dtypes:
         raise ValueError(f"X dtype must be one of {supported_dtypes}. Got {X.dtype}")
 
-    out_scale = -out_multiplier * (1 / (1 << 31)) * (2 ** out_shift[0])
+    out_scale = -out_multiplier * (1 / (1 << 31)) * (2**out_shift)
     dequantized_X = torch.where(X > X_zero_point, X - X_zero_point, torch.zeros_like(X))
     return quantize_per_tensor(
         dequantized_X,
@@ -769,6 +768,79 @@ def quantized_relu(
     )
 
 
+def quantized_relu_variant(
+    per_tensor: bool,
+    dtype: torch.dtype | None = None,
+) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]:
+    """Create a quantized relu variant with type checking."""
+
+    def decorator(_: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
+        def variant(
+            X: torch.Tensor,
+            X_zero_point: torch.Tensor | int,
+            out_zero_point: int,
+            out_multiplier: torch.Tensor | int,
+            out_shift: torch.Tensor | int,
+        ) -> torch.Tensor:
+            if per_tensor:
+                if dtype and X.dtype != dtype:
+                    raise ValueError(f"X dtype must be {dtype}. Got {X.dtype}")
+
+                assert isinstance(out_shift, int)
+                assert isinstance(out_multiplier, int)
+                _out_shift = out_shift
+                _out_multiplier = out_multiplier
+            else:
+                assert isinstance(out_multiplier, torch.Tensor)
+                if out_multiplier.numel() > 1:
+                    raise ValueError("Only scalar out_multiplier is supported")
+
+                assert isinstance(out_shift, torch.Tensor)
+                if out_shift.numel() > 1:
+                    raise ValueError("Only scalar out_shift is supported")
+
+                assert isinstance(X_zero_point, torch.Tensor)
+                if X_zero_point.shape != X.shape:
+                    raise ValueError(
+                        f"X_zero_point shape must be {X.shape}. Got {X_zero_point.shape}"
+                    )
+
+                _out_multiplier = int(out_multiplier.item())
+                _out_shift = int(out_shift.item())
+
+            return quantized_relu_common(
+                X,
+                X_zero_point,
+                out_zero_point,
+                _out_multiplier,
+                _out_shift,
+            )
+
+        return variant
+
+    return decorator
+
+
+@impl(m, "quantized_relu")
+@quantized_relu_variant(False)
+def quantized_relu() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_relu.per_tensor")
+@quantized_relu_variant(True)
+def quantized_relu_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_relu_asym8s_asym8s.per_tensor")
+@quantized_relu_variant(True, torch.int8)
+def quantized_relu_asym8s_asym8s_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_relu_asym8u_asym8u.per_tensor")
+@quantized_relu_variant(True, torch.uint8)
+def quantized_relu_asym8u_asym8u_per_tensor() -> torch.Tensor: ...
+
+
 @impl(m, "requantize")
 def requantize(
     input: torch.Tensor,
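As a quick sanity check of the out_scale formula above, here is the arithmetic for the "basic_int8" case used in the tests below (a worked example with the test's own numbers, not additional library code):

out_multiplier = 1073741824  # 0.5 * 2**31
out_shift = 0
out_scale = -out_multiplier * (1 / (1 << 31)) * (2**out_shift)
print(out_scale)  # -0.5
# Per the test's expected-value comment: relu([-1, 0, 1, 3]) keeps [0, 0, 1, 3];
# scaling by -0.5 with out_zero_point 0 gives [0, 0, -0.5, -1.5], which rounds
# to the expected int8 output [0, 0, 0, -2].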

backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 105 additions & 54 deletions
@@ -877,73 +877,124 @@ def test_quantized_conv_per_tensor(
     @expand(
         [
             # Test case 1: Basic int8 case with negative scale
-            (
-                "basic_int8",
-                torch.tensor([-1, 0, 1, 3], dtype=torch.int8),  # input
-                torch.tensor([0], dtype=torch.int8),  # X_zero_point (scalar broadcast)
-                0,  # out_zero_point
-                torch.tensor([1073741824]),  # out_multiplier (0.5 * 2^31)
-                torch.tensor([0]),  # out_shift
-                torch.int8,  # dtype
-                torch.tensor(
-                    [0, 0, 0, -2], dtype=torch.int8
-                ),  # expected: relu(-1,0,1,3) = (0,0,1,3) * (-0.5) + 0 = (0,0,-0.5,-1.5) -> (0,0,0,-2)
-            ),
+            *[
+                (
+                    "basic_int8",
+                    torch.tensor([-1, 0, 1, 3], dtype=dtype),  # input
+                    0,  # X_zero_point (scalar broadcast)
+                    0,  # out_zero_point
+                    1073741824,  # out_multiplier (0.5 * 2^31)
+                    0,  # out_shift
+                    dtype,  # dtype
+                    torch.tensor(
+                        [0, 0, 0, -2], dtype=dtype
+                    ),  # expected: relu(-1,0,1,3) = (0,0,1,3) * (-0.5) + 0 = (0,0,-0.5,-1.5) -> (0,0,0,-2)
+                )
+                for dtype in [torch.int8]
+            ],
             # Test case 2: uint8 with non-zero zero point
-            (
-                "uint8_with_zp",
-                torch.tensor([126, 128, 130, 132], dtype=torch.uint8),  # input
-                torch.tensor([128], dtype=torch.uint8),  # X_zero_point
-                64,  # out_zero_point
-                torch.tensor([536870912]),  # out_multiplier (0.25 * 2^31)
-                torch.tensor([0]),  # out_shift
-                torch.uint8,  # dtype
-                torch.tensor(
-                    [64, 64, 64, 63], dtype=torch.uint8
-                ),  # expected: relu(-2,0,2,4) = (0,0,2,4) * (-0.25) + 64 = (64,64,63.5,63) -> (64,64,64,63)
-            ),
+            *[
+                (
+                    "uint8_with_zp",
+                    torch.tensor([126, 128, 130, 132], dtype=dtype),  # input
+                    128,  # X_zero_point
+                    64,  # out_zero_point
+                    536870912,  # out_multiplier (0.25 * 2^31)
+                    0,  # out_shift
+                    dtype,  # dtype
+                    torch.tensor(
+                        [64, 64, 64, 63], dtype=dtype
+                    ),  # expected: relu(-2,0,2,4) = (0,0,2,4) * (-0.25) + 64 = (64,64,63.5,63) -> (64,64,64,63)
+                )
+                for dtype in [torch.uint8]
+            ],
             # Test case 3: All negative values (should all become zero after ReLU)
-            (
-                "all_negative_int8",
-                torch.tensor([-5, -3, -1], dtype=torch.int8),  # input
-                torch.tensor([0], dtype=torch.int8),  # X_zero_point
-                10,  # out_zero_point
-                torch.tensor([1073741824]),  # out_multiplier (0.5 * 2^31)
-                torch.tensor([0]),  # out_shift
-                torch.int8,  # dtype
-                torch.tensor(
-                    [10, 10, 10], dtype=torch.int8
-                ),  # expected: relu(-5,-3,-1) = (0,0,0) * (-0.5) + 10 = (10,10,10)
-            ),
+            *[
+                (
+                    "all_negative_int8",
+                    torch.tensor([-5, -3, -1], dtype=dtype),  # input
+                    0,  # X_zero_point
+                    10,  # out_zero_point
+                    1073741824,  # out_multiplier (0.5 * 2^31)
+                    0,  # out_shift
+                    dtype,  # dtype
+                    torch.tensor(
+                        [10, 10, 10], dtype=dtype
+                    ),  # expected: relu(-5,-3,-1) = (0,0,0) * (-0.5) + 10 = (10,10,10)
+                )
+                for dtype in [torch.int8]
+            ],
             # Test case 4: All positive values with shift (scale becomes -0.25)
-            (
-                "positive_with_shift",
-                torch.tensor([2, 4, 6, 8], dtype=torch.int8),  # input
-                torch.tensor([1], dtype=torch.int8),  # X_zero_point
-                5,  # out_zero_point
-                torch.tensor([1073741824]),  # out_multiplier (0.5 * 2^31)
-                torch.tensor([1]),  # out_shift (multiply by 2^1 = 2)
-                torch.int8,  # dtype
-                torch.tensor(
-                    [4, 2, 0, -2], dtype=torch.int8
-                ),  # expected: relu(1,3,5,7) = (1,3,5,7) * (-1.0) + 5 = (4,2,0,-2)
-            ),
+            *[
+                (
+                    "positive_with_shift",
+                    torch.tensor([2, 4, 6, 8], dtype=dtype),  # input
+                    1,  # X_zero_point
+                    5,  # out_zero_point
+                    1073741824,  # out_multiplier (0.5 * 2^31)
+                    1,  # out_shift (multiply by 2^1 = 2)
+                    dtype,  # dtype
+                    torch.tensor(
+                        [4, 2, 0, -2], dtype=dtype
+                    ),  # expected: relu(1,3,5,7) = (1,3,5,7) * (-1.0) + 5 = (4,2,0,-2)
+                )
+                for dtype in [torch.int8, torch.uint8]
+            ],
+            # Test case 4: Non-per-tensor
+            *[
+                (
+                    "non_per_tensor",
+                    torch.tensor([-1, -2, -3, 1, 2, 3], dtype=dtype),  # input
+                    torch.tensor([0, 0, 0, 1, 1, 1]),  # X_zero_point
+                    5,  # out_zero_point
+                    torch.tensor([1073741824]),  # out_multiplier (0.5 * 2^31)
+                    torch.tensor([1]),  # out_shift (multiply by 2^1 = 2)
+                    dtype,  # dtype
+                    torch.tensor([5, 5, 5, 5, 4, 3], dtype=dtype),
+                )
+                for dtype in [torch.int8]
+            ],
         ]
     )
     def test_quantized_relu(
         self,
         name: str,
         X: torch.Tensor,
-        X_zero_point: torch.Tensor,
+        X_zero_point: torch.Tensor | int,
         out_zero_point: int,
-        out_multiplier: torch.Tensor,
-        out_shift: torch.Tensor,
+        out_multiplier: torch.Tensor | int,
+        out_shift: torch.Tensor | int,
         dtype: torch.dtype,
         expected_output: torch.Tensor,
     ) -> None:
-        output = torch.ops.cadence.quantized_relu(
-            X, X_zero_point, out_zero_point, out_multiplier, out_shift
-        )
+
+        if isinstance(X_zero_point, int):
+            assert isinstance(out_multiplier, int)
+            assert isinstance(out_shift, int)
+
+            match dtype:
+                case torch.int8:
+                    quantized_relu = (
+                        torch.ops.cadence.quantized_relu_asym8s_asym8s.per_tensor
+                    )
+                case torch.uint8:
+                    quantized_relu = (
+                        torch.ops.cadence.quantized_relu_asym8u_asym8u.per_tensor
+                    )
+                case _:
+                    quantized_relu = torch.ops.cadence.quantized_relu_per_tensor
+
+            output = quantized_relu(
+                X,
+                X_zero_point,
+                out_zero_point,
+                out_multiplier,
+                out_shift,
+            )
+        else:
+            output = torch.ops.cadence.quantized_relu(
+                X, X_zero_point, out_zero_point, out_multiplier, out_shift
+            )
 
         # Verify output properties
         self.assertEqual(output.dtype, dtype, f"Output dtype should be {dtype}")
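For reference, the new per-tensor overloads are called exactly as in the test above. A hypothetical standalone invocation, reusing the "uint8_with_zp" numbers and assuming the Cadence AOT op library has been imported so that torch.ops.cadence.* is registered:

import torch

x = torch.tensor([126, 128, 130, 132], dtype=torch.uint8)
out = torch.ops.cadence.quantized_relu_asym8u_asym8u.per_tensor(
    x,          # X
    128,        # X_zero_point
    64,         # out_zero_point
    536870912,  # out_multiplier (0.25 * 2**31)
    0,          # out_shift
)
# Expected, per the "uint8_with_zp" test case: tensor([64, 64, 64, 63], dtype=torch.uint8)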
