Skip to content

Commit f783d97

Browse files
Andrew Grebenisan authored and facebook-github-bot committed
Removed support for non-per-tensor quantized relu (pytorch#14788)
Summary: Not supporting the non-per-tensor quantized relu default, so removing it from ref_implementations. Differential Revision: D83874866
1 parent 705fccc commit f783d97

File tree

1 file changed

+10
-39
lines changed

1 file changed

+10
-39
lines changed

backends/cadence/aot/ref_implementations.py

Lines changed: 10 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1125,75 +1125,46 @@ def quantized_relu_common(
11251125

11261126

11271127
def quantized_relu_variant(
1128-
per_tensor: bool,
11291128
dtype: torch.dtype | None = None,
11301129
) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]:
11311130
"""Create a quantized relu variant with type checking."""
11321131

11331132
def decorator(_: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
11341133
def variant(
11351134
X: torch.Tensor,
1136-
X_zero_point: torch.Tensor | int,
1135+
X_zero_point: int,
11371136
out_zero_point: int,
1138-
out_multiplier: torch.Tensor | int,
1139-
out_shift: torch.Tensor | int,
1137+
out_multiplier: int,
1138+
out_shift: int,
11401139
) -> torch.Tensor:
1141-
if per_tensor:
1142-
if dtype and X.dtype != dtype:
1143-
raise ValueError(f"X dtype must be {dtype}. Got {X.dtype}")
1144-
1145-
assert isinstance(out_shift, int)
1146-
assert isinstance(out_multiplier, int)
1147-
_out_shift = out_shift
1148-
_out_multiplier = out_multiplier
1149-
else:
1150-
assert isinstance(out_multiplier, torch.Tensor)
1151-
if out_multiplier.numel() > 1:
1152-
raise ValueError("Only scalar out_multiplier is supported")
1153-
1154-
assert isinstance(out_shift, torch.Tensor)
1155-
if out_shift.numel() > 1:
1156-
raise ValueError("Only scalar out_shift is supported")
1157-
1158-
assert isinstance(X_zero_point, torch.Tensor)
1159-
if X_zero_point.shape != X.shape:
1160-
raise ValueError(
1161-
f"X_zero_point shape must be {X.shape}. Got {X_zero_point.shape}"
1162-
)
1163-
1164-
_out_multiplier = int(out_multiplier.item())
1165-
_out_shift = int(out_shift.item())
1140+
if dtype and X.dtype != dtype:
1141+
raise ValueError(f"X dtype must be {dtype}. Got {X.dtype}")
11661142

11671143
return quantized_relu_common(
11681144
X,
11691145
X_zero_point,
11701146
out_zero_point,
1171-
_out_multiplier,
1172-
_out_shift,
1147+
out_multiplier,
1148+
out_shift,
11731149
)
11741150

11751151
return variant
11761152

11771153
return decorator
11781154

11791155

1180-
@impl(m, "quantized_relu")
1181-
@quantized_relu_variant(False)
1182-
def quantized_relu() -> torch.Tensor: ...
1183-
1184-
11851156
@impl(m, "quantized_relu.per_tensor")
1186-
@quantized_relu_variant(True)
1157+
@quantized_relu_variant()
11871158
def quantized_relu_per_tensor() -> torch.Tensor: ...
11881159

11891160

11901161
@impl(m, "quantized_relu_asym8s_asym8s.per_tensor")
1191-
@quantized_relu_variant(True, torch.int8)
1162+
@quantized_relu_variant(torch.int8)
11921163
def quantized_relu_asym8s_asym8s_per_tensor() -> torch.Tensor: ...
11931164

11941165

11951166
@impl(m, "quantized_relu_asym8u_asym8u.per_tensor")
1196-
@quantized_relu_variant(True, torch.uint8)
1167+
@quantized_relu_variant(torch.uint8)
11971168
def quantized_relu_asym8u_asym8u_per_tensor() -> torch.Tensor: ...
11981169

11991170

0 commit comments

Comments (0)