
Commit 09408e3

Added support for quantized_softmax in ref_implementations
Differential Revision: D85188129
Pull Request resolved: #15426
1 parent 48c4e45 commit 09408e3

File tree

3 files changed: +225 −22


backends/cadence/aot/ops_registrations.py

Lines changed: 1 addition & 22 deletions
@@ -6,7 +6,6 @@
 
 # pyre-strict
 
-import logging
 from math import prod
 from typing import Callable, Optional, Tuple
 
@@ -49,36 +48,16 @@ def _validate_ref_impl_exists() -> None:
         "cadence::roi_align_box_processor",
     }
 
-    # All of these should either
-    # 1. be removed
-    # 2. have a reference implementation added to ref_implementations.py
-    _WARN_ONLY = {
-        "cadence::quantized_softmax.per_tensor",
-        "cadence::quantized_softmax",
-    }
-
     ref_impls = get_registered_ref_implementations()
-    warn_impls = []
     error_impls = []
     for op_name in _REGISTERED_META_KERNELS:
         # Strip the namespace prefix if present (e.g., "cadence::" -> "")
        op_name_clean = op_name.split("::")[-1] if "::" in op_name else op_name
 
         if op_name_clean not in ref_impls:
-            if op_name in _WARN_ONLY:
-                warn_impls.append(op_name)
-            elif op_name not in _SKIP_OPS:
+            if op_name not in _SKIP_OPS:
                 error_impls.append(op_name)
 
-    if warn_impls:
-        warn_msg = (
-            f"The following {len(warn_impls)} meta kernel registrations are missing reference implementations:\n"
-            + "\n".join(f"  - {op}" for op in warn_impls)
-            + "\n\nPlease add reference implementations in ref_implementations.py using "
-            + "@impl_tracked(m, '<op_name>')."
-        )
-        logging.warning(warn_msg)
-
     if error_impls:
         error_msg = (
             f"The following {len(error_impls)} meta kernel registrations are missing reference implementations:\n"

backends/cadence/aot/ref_implementations.py

Lines changed: 92 additions & 0 deletions
@@ -2054,3 +2054,95 @@ def softmax_f32_f32(
     assert input_tensor.dtype == torch.float32, "input_tensor must be float32"
     assert not half_to_float, "half_to_float is not supported"
     return torch.nn.functional.softmax(input_tensor, dim=dim, dtype=torch.float32)
+
+
+def quantized_softmax_per_tensor_common(
+    input_tensor: torch.Tensor,
+    mask: torch.Tensor | None,
+    dim: int,
+    in_scale: float,
+    in_zero_point: int,
+    out_scale: float,
+    out_zero_point: int,
+) -> torch.Tensor:
+    """
+    Quantized softmax operation.
+
+    Args:
+        - input_tensor (Tensor): The quantized input tensor
+        - mask (Tensor): Mask tensor
+        - dim (int): The dimension along which softmax is computed
+        - in_scale (float): The scale of the input quantization
+        - in_zero_point (int): The zero point of the input quantization
+        - out_scale (float): The scale of the output quantization
+        - out_zero_point (int): The zero point of the output quantization
+    """
+    # TODO: T228751479 - Add support for mask parameter in softmax
+    assert mask is None
+    supported_dtypes = [torch.int8, torch.uint8, torch.int16]
+    if input_tensor.dtype not in supported_dtypes:
+        raise ValueError(
+            f"Input dtype must be one of {supported_dtypes}. Got {input_tensor.dtype}"
+        )
+
+    float_input_tensor = dequantize_per_tensor(
+        input_tensor,
+        in_scale,
+        in_zero_point,
+        torch.iinfo(input_tensor.dtype).min,
+        torch.iinfo(input_tensor.dtype).max,
+        input_tensor.dtype,
+    )
+
+    softmax_output = torch.nn.functional.softmax(float_input_tensor, dim=dim)
+
+    return quantize_per_tensor(
+        softmax_output,
+        out_scale,
+        out_zero_point,
+        torch.iinfo(input_tensor.dtype).min,
+        torch.iinfo(input_tensor.dtype).max,
+        input_tensor.dtype,
+    )
+
+
+@impl_tracked(m, "quantized_softmax.per_tensor")
+def quantized_softmax_per_tensor(
+    input_tensor: torch.Tensor,
+    mask: torch.Tensor | None,
+    dim: int,
+    in_scale: float,
+    in_zero_point: int,
+    out_scale: float,
+    out_zero_point: int,
+) -> torch.Tensor:
+    return quantized_softmax_per_tensor_common(
+        input_tensor,
+        mask,
+        dim,
+        in_scale,
+        in_zero_point,
+        out_scale,
+        out_zero_point,
+    )
+
+
+@impl_tracked(m, "quantized_softmax")
+def quantized_softmax(
+    input_tensor: torch.Tensor,
+    mask: torch.Tensor | None,
+    dim: int,
+    in_scale: torch.Tensor,
+    in_zero_point: torch.Tensor,
+    out_scale: float,
+    out_zero_point: int,
+) -> torch.Tensor:
+    return quantized_softmax_per_tensor_common(
+        input_tensor,
+        mask,
+        dim,
+        float(in_scale.item()),
+        int(in_zero_point.item()),
+        out_scale,
+        out_zero_point,
+    )
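
The reference implementation is a dequantize → float softmax → requantize round trip. Below is a self-contained sketch of that flow with the dequantize_per_tensor / quantize_per_tensor helpers replaced by plain torch arithmetic (the exact rounding and clamping behavior of those helpers is an assumption here), checked against the basic_int8_dim_1 case from the tests:

import torch

def quantized_softmax_sketch(q, dim, in_scale, in_zero_point, out_scale, out_zero_point):
    # Dequantize: x = (q - in_zero_point) * in_scale
    x = (q.to(torch.float32) - in_zero_point) * in_scale
    # Softmax in float along the requested dimension
    y = torch.nn.functional.softmax(x, dim=dim)
    # Requantize: round(y / out_scale) + out_zero_point, clamped to the dtype range
    info = torch.iinfo(q.dtype)
    q_out = torch.round(y / out_scale) + out_zero_point
    return q_out.clamp(info.min, info.max).to(q.dtype)

# basic_int8_dim_1: [[10, 20, 30]] * 0.1 -> [[1.0, 2.0, 3.0]];
# softmax -> [[0.090, 0.245, 0.665]]; / 0.004 -> [[22.5, 61.2, 166.3]];
# rounded and clamped to int8 -> [[23, 61, 127]]
q = torch.tensor([[10, 20, 30]], dtype=torch.int8)
print(quantized_softmax_sketch(q, dim=1, in_scale=0.1, in_zero_point=0,
                               out_scale=0.004, out_zero_point=0))
# tensor([[ 23,  61, 127]], dtype=torch.int8)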

backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 132 additions & 0 deletions
@@ -3079,3 +3079,135 @@ def test_quantized_w8a32_gru_invalid_hidden_dim(self) -> None:
         self.assertIn(
             "Hidden dimension must be a multiple of 4", str(context.exception)
         )
+
+    @expand(
+        [
+            (
+                "basic_int8_dim_1",
+                torch.tensor([[10, 20, 30]], dtype=torch.int8),
+                None,
+                1,
+                0.1,
+                0,
+                0.004,
+                0,
+                torch.int8,
+                torch.tensor([[23, 61, 127]], dtype=torch.int8),
+            ),
+            (
+                "uint8_with_zero_points",
+                torch.tensor([[128, 130, 132]], dtype=torch.uint8),
+                None,
+                1,
+                0.1,
+                128,
+                0.004,
+                128,
+                torch.uint8,
+                torch.tensor([[195, 210, 228]], dtype=torch.uint8),
+            ),
+            (
+                "basic_int16",
+                torch.tensor([[100, 200, 300]], dtype=torch.int16),
+                None,
+                1,
+                0.01,
+                0,
+                0.004,
+                0,
+                torch.int16,
+                torch.tensor([[23, 61, 166]], dtype=torch.int16),
+            ),
+            (
+                "multi_row_int8",
+                torch.tensor([[10, 20, 30], [5, 10, 15]], dtype=torch.int8),
+                None,
+                1,
+                0.1,
+                0,
+                0.004,
+                0,
+                torch.int8,
+                torch.tensor([[23, 61, 127], [47, 77, 127]], dtype=torch.int8),
+            ),
+            (
+                "softmax_dim_0",
+                torch.tensor([[10, 20], [30, 40]], dtype=torch.int8),
+                None,
+                0,
+                0.1,
+                0,
+                0.004,
+                0,
+                torch.int8,
+                torch.tensor([[30, 30], [127, 127]], dtype=torch.int8),
+            ),
+        ]
+    )
+    def test_quantized_softmax_per_tensor(
+        self,
+        name: str,
+        input_tensor: torch.Tensor,
+        mask: torch.Tensor | None,
+        dim: int,
+        in_scale: float,
+        in_zero_point: int,
+        out_scale: float,
+        out_zero_point: int,
+        dtype: torch.dtype,
+        expected_output: torch.Tensor,
+    ) -> None:
+        output = torch.ops.cadence.quantized_softmax.per_tensor(
+            input_tensor,
+            mask,
+            dim,
+            in_scale,
+            in_zero_point,
+            out_scale,
+            out_zero_point,
+        )
+
+        # Verify output properties
+        self.assertEqual(
+            output.dtype, dtype, f"Output dtype should be {dtype} in {name}"
+        )
+        self.assertEqual(
+            output.shape,
+            input_tensor.shape,
+            f"Output shape should match input shape in {name}",
+        )
+
+        # Verify output matches expected values (allowing for small quantization errors)
+        # For softmax, we expect outputs to be in [0, 1] range when dequantized
+        self.assertTrue(
+            torch.allclose(
+                output.to(torch.float32),
+                expected_output.to(torch.float32),
+                rtol=0.05,
+                atol=5.0,
+            ),
+            f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
+        )
+
+    def test_quantized_softmax(self) -> None:
+        # Test quantized_softmax (default variant with tensor scale/zero_point)
+        input_tensor = torch.tensor([[10, 20, 30]], dtype=torch.int8)
+        in_scale = torch.tensor([0.1])
+        in_zero_point = torch.tensor([0])
+        output = torch.ops.cadence.quantized_softmax(
+            input_tensor,
+            None,  # mask
+            1,  # dim
+            in_scale,
+            in_zero_point,
+            0.004,  # out_scale
+            0,  # out_zero_point
+        )
+
+        # Verify output properties
+        self.assertEqual(output.dtype, torch.int8, "Output dtype should be int8")
+        self.assertEqual(
+            output.shape,
+            input_tensor.shape,
+            "Output shape should match input shape",
+        )
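
For the uint8_with_zero_points case, the zero points shift both the dequantize and the requantize steps. A quick check of that arithmetic (plain torch, not the test's own code, assuming round-then-clamp requantization):

import torch

q = torch.tensor([[128, 130, 132]], dtype=torch.uint8)
x = (q.to(torch.float32) - 128) * 0.1       # dequantize: [[0.0, 0.2, 0.4]]
y = torch.nn.functional.softmax(x, dim=1)   # ~ [[0.269, 0.329, 0.402]]
q_out = torch.round(y / 0.004) + 128        # requantize with out_zero_point = 128
print(q_out.clamp(0, 255).to(torch.uint8))  # tensor([[195, 210, 228]], dtype=torch.uint8)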
