Commit 269e6b5

Eashan Garg authored and facebook-github-bot committed

Added support for quantized_softmax in ref_implementations

Summary: Add support for quantized_softmax ref implementation

Reviewed By: DrJessop

Differential Revision: D85188129
1 parent 94def70 commit 269e6b5

File tree

3 files changed: +223 -2 lines changed

backends/cadence/aot/ops_registrations.py
backends/cadence/aot/ref_implementations.py
backends/cadence/aot/tests/test_ref_implementations.py

backends/cadence/aot/ops_registrations.py

Lines changed: 0 additions & 2 deletions

@@ -53,8 +53,6 @@ def _validate_ref_impl_exists() -> None:
     # 1. be removed
     # 2. have a reference implementation added to ref_implementations.py
     _WARN_ONLY = {
-        "cadence::quantized_softmax.per_tensor",
-        "cadence::quantized_softmax",
     }
 
     ref_impls = get_registered_ref_implementations()

backends/cadence/aot/ref_implementations.py

Lines changed: 89 additions & 0 deletions

@@ -2054,3 +2054,92 @@ def softmax_f32_f32(
     assert input_tensor.dtype == torch.float32, "input_tensor must be float32"
     assert not half_to_float, "half_to_float is not supported"
     return torch.nn.functional.softmax(input_tensor, dim=dim, dtype=torch.float32)
+
+
+def quantized_softmax_per_tensor_common(
+    input_tensor: torch.Tensor,
+    mask: torch.Tensor | None,
+    dim: int,
+    in_scale: float,
+    in_zero_point: int,
+    out_scale: float,
+    out_zero_point: int,
+) -> torch.Tensor:
+    """
+    Quantized softmax operation.
+
+    Args:
+        - input_tensor (Tensor): The quantized input tensor
+        - mask (Tensor): Mask tensor
+        - dim (int): The dimension along which softmax is computed
+        - in_scale (float): The scale of the input quantization
+        - in_zero_point (int): The zero point of the input quantization
+        - out_scale (float): The scale of the output quantization
+        - out_zero_point (int): The zero point of the output quantization
+    """
+    supported_dtypes = [torch.int8, torch.uint8, torch.int16]
+    if input_tensor.dtype not in supported_dtypes:
+        raise ValueError(
+            f"Input dtype must be one of {supported_dtypes}. Got {input_tensor.dtype}"
+        )
+
+    float_input_tensor = dequantize_per_tensor(
+        input_tensor,
+        in_scale,
+        in_zero_point,
+        torch.iinfo(input_tensor.dtype).min,
+        torch.iinfo(input_tensor.dtype).max,
+        input_tensor.dtype,
+    )
+
+    softmax_output = torch.nn.functional.softmax(float_input_tensor, dim=dim)
+
+    return quantize_per_tensor(
+        softmax_output,
+        out_scale,
+        out_zero_point,
+        torch.iinfo(input_tensor.dtype).min,
+        torch.iinfo(input_tensor.dtype).max,
+        input_tensor.dtype,
+    )
+
+@impl_tracked(m, "quantized_softmax.per_tensor")
+def quantized_softmax_per_tensor(
+    input_tensor: torch.Tensor,
+    mask: torch.Tensor | None,
+    dim: int,
+    in_scale: float,
+    in_zero_point: int,
+    out_scale: float,
+    out_zero_point: int,
+) -> torch.Tensor:
+    return quantized_softmax_per_tensor_common(
+        input_tensor,
+        mask,
+        dim,
+        in_scale,
+        in_zero_point,
+        out_scale,
+        out_zero_point,
+    )
+
+
+@impl_tracked(m, "quantized_softmax")
+def quantized_softmax(
+    input_tensor: torch.Tensor,
+    mask: torch.Tensor | None,
+    dim: int,
+    in_scale: torch.Tensor,
+    in_zero_point: torch.Tensor,
+    out_scale: float,
+    out_zero_point: int,
+) -> torch.Tensor:
+    return quantized_softmax_per_tensor_common(
+        input_tensor,
+        mask,
+        dim,
+        float(in_scale.item()),
+        int(in_zero_point.item()),
+        out_scale,
+        out_zero_point,
+    )
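For intuition, the reference implementation above is simply dequantize, float softmax, requantize, and the mask argument is accepted but not used. Below is a minimal standalone sketch of that arithmetic using only plain torch ops; the function name quantized_softmax_sketch and the inline quantize/dequantize math are illustrative assumptions, not the repo's dequantize_per_tensor/quantize_per_tensor API.

import torch

def quantized_softmax_sketch(
    x: torch.Tensor,
    dim: int,
    in_scale: float,
    in_zero_point: int,
    out_scale: float,
    out_zero_point: int,
) -> torch.Tensor:
    # Hypothetical sketch, not the repo's helpers.
    # Dequantize: (q - zero_point) * scale
    x_f = (x.to(torch.float32) - in_zero_point) * in_scale
    # Softmax in float along the requested dimension
    y_f = torch.nn.functional.softmax(x_f, dim=dim)
    # Requantize: round(y / scale) + zero_point, clamped to the input dtype's range
    info = torch.iinfo(x.dtype)
    q = torch.round(y_f / out_scale) + out_zero_point
    return q.clamp(info.min, info.max).to(x.dtype)

# Example (mirrors the basic_int8_dim_1 test case below):
# quantized_softmax_sketch(torch.tensor([[10, 20, 30]], dtype=torch.int8),
#                          dim=1, in_scale=0.1, in_zero_point=0,
#                          out_scale=0.004, out_zero_point=0)
# -> tensor([[23, 61, 127]], dtype=torch.int8)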

backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 134 additions & 0 deletions

@@ -3079,3 +3079,137 @@ def test_quantized_w8a32_gru_invalid_hidden_dim(self) -> None:
         self.assertIn(
             "Hidden dimension must be a multiple of 4", str(context.exception)
         )
+
+    @expand(
+        [
+            (
+                "basic_int8_dim_1",
+                torch.tensor([[10, 20, 30]], dtype=torch.int8),
+                None,
+                1,
+                0.1,
+                0,
+                0.004,
+                0,
+                torch.int8,
+                torch.tensor([[23, 61, 127]], dtype=torch.int8),
+            ),
+            (
+                "uint8_with_zero_points",
+                torch.tensor([[128, 130, 132]], dtype=torch.uint8),
+                None,
+                1,
+                0.1,
+                128,
+                0.004,
+                128,
+                torch.uint8,
+                torch.tensor([[195, 210, 228]], dtype=torch.uint8),
+            ),
+            (
+                "basic_int16",
+                torch.tensor([[100, 200, 300]], dtype=torch.int16),
+                None,
+                1,
+                0.01,
+                0,
+                0.004,
+                0,
+                torch.int16,
+                torch.tensor([[23, 61, 166]], dtype=torch.int16),
+            ),
+            (
+                "multi_row_int8",
+                torch.tensor(
+                    [[10, 20, 30], [5, 10, 15]], dtype=torch.int8
+                ),
+                None,
+                1,
+                0.1,
+                0,
+                0.004,
+                0,
+                torch.int8,
+                torch.tensor(
+                    [[23, 61, 127], [47, 77, 127]], dtype=torch.int8
+                ),
+            ),
+            (
+                "softmax_dim_0",
+                torch.tensor([[10, 20], [30, 40]], dtype=torch.int8),
+                None,
+                0,
+                0.1,
+                0,
+                0.004,
+                0,
+                torch.int8,
+                torch.tensor([[30, 30], [127, 127]], dtype=torch.int8),
+            ),
+        ]
+    )
+    def test_quantized_softmax_per_tensor(
+        self,
+        name: str,
+        input_tensor: torch.Tensor,
+        mask: torch.Tensor | None,
+        dim: int,
+        in_scale: float,
+        in_zero_point: int,
+        out_scale: float,
+        out_zero_point: int,
+        dtype: torch.dtype,
+        expected_output: torch.Tensor,
+    ) -> None:
+        output = torch.ops.cadence.quantized_softmax.per_tensor(
+            input_tensor,
+            mask,
+            dim,
+            in_scale,
+            in_zero_point,
+            out_scale,
+            out_zero_point,
+        )
+
+        # Verify output properties
+        self.assertEqual(output.dtype, dtype, f"Output dtype should be {dtype} in {name}")
+        self.assertEqual(
+            output.shape,
+            input_tensor.shape,
+            f"Output shape should match input shape in {name}",
+        )
+
+        # Verify output matches expected values (allowing for small quantization errors)
+        # For softmax, we expect outputs to be in [0, 1] range when dequantized
+        self.assertTrue(
+            torch.allclose(
+                output.to(torch.float32),
+                expected_output.to(torch.float32),
+                rtol=0.05,
+                atol=5.0,
+            ),
+            f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
+        )
+
+    def test_quantized_softmax(self) -> None:
+        # Test quantized_softmax (default variant with tensor scale/zero_point)
+        input_tensor = torch.tensor([[10, 20, 30]], dtype=torch.int8)
+        in_scale = torch.tensor([0.1])
+        in_zero_point = torch.tensor([0])
+        output = torch.ops.cadence.quantized_softmax(
+            input_tensor,
+            None,  # mask
+            1,  # dim
+            in_scale,
+            in_zero_point,
+            0.004,  # out_scale
+            0,  # out_zero_point
+        )
+
+        # Verify output properties
+        self.assertEqual(output.dtype, torch.int8, "Output dtype should be int8")
+        self.assertEqual(
+            output.shape,
+            input_tensor.shape,
+            "Output shape should match input shape",
+        )
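As a quick sanity check on the basic_int8_dim_1 expectation above: with in_scale = 0.1, the input [10, 20, 30] dequantizes to [1.0, 2.0, 3.0]; softmax over that row is roughly [0.090, 0.245, 0.665]; dividing by out_scale = 0.004 gives about [22.5, 61.2, 166.3], which rounds to [23, 61, 166] and then clamps to the int8 maximum of 127, matching the expected [[23, 61, 127]].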
