diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py
index e4be6a09641..0220baa593f 100644
--- a/backends/cadence/aot/ops_registrations.py
+++ b/backends/cadence/aot/ops_registrations.py
@@ -6,7 +6,6 @@
 
 # pyre-strict
 
-import logging
 from math import prod
 from typing import Callable, Optional, Tuple
 
@@ -49,36 +48,16 @@ def _validate_ref_impl_exists() -> None:
         "cadence::roi_align_box_processor",
     }
 
-    # All of these should either
-    # 1. be removed
-    # 2. have a reference implementation added to ref_implementations.py
-    _WARN_ONLY = {
-        "cadence::quantized_softmax.per_tensor",
-        "cadence::quantized_softmax",
-    }
-
     ref_impls = get_registered_ref_implementations()
 
-    warn_impls = []
     error_impls = []
     for op_name in _REGISTERED_META_KERNELS:
         # Strip the namespace prefix if present (e.g., "cadence::" -> "")
         op_name_clean = op_name.split("::")[-1] if "::" in op_name else op_name
         if op_name_clean not in ref_impls:
-            if op_name in _WARN_ONLY:
-                warn_impls.append(op_name)
-            elif op_name not in _SKIP_OPS:
+            if op_name not in _SKIP_OPS:
                 error_impls.append(op_name)
 
-    if warn_impls:
-        warn_msg = (
-            f"The following {len(warn_impls)} meta kernel registrations are missing reference implementations:\n"
-            + "\n".join(f" - {op}" for op in warn_impls)
-            + "\n\nPlease add reference implementations in ref_implementations.py using "
-            + "@impl_tracked(m, '')."
-        )
-        logging.warning(warn_msg)
-
     if error_impls:
         error_msg = (
             f"The following {len(error_impls)} meta kernel registrations are missing reference implementations:\n"
diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py
index 3e08cdc358c..5a8cba0361d 100644
--- a/backends/cadence/aot/ref_implementations.py
+++ b/backends/cadence/aot/ref_implementations.py
@@ -2054,3 +2054,95 @@ def softmax_f32_f32(
     assert input_tensor.dtype == torch.float32, "input_tensor must be float32"
     assert not half_to_float, "half_to_float is not supported"
     return torch.nn.functional.softmax(input_tensor, dim=dim, dtype=torch.float32)
+
+
+def quantized_softmax_per_tensor_common(
+    input_tensor: torch.Tensor,
+    mask: torch.Tensor | None,
+    dim: int,
+    in_scale: float,
+    in_zero_point: int,
+    out_scale: float,
+    out_zero_point: int,
+) -> torch.Tensor:
+    """
+    Quantized softmax operation.
+
+    Args:
+    - input_tensor (Tensor): The quantized input tensor
+    - mask (Tensor): Mask tensor
+    - dim (int): The dimension along which softmax is computed
+    - in_scale (float): The scale of the input quantization
+    - in_zero_point (int): The zero point of the input quantization
+    - out_scale (float): The scale of the output quantization
+    - out_zero_point (int): The zero point of the output quantization
+    """
+    # TODO: T228751479 - Add support for mask parameter in softmax
+    assert mask is None
+    supported_dtypes = [torch.int8, torch.uint8, torch.int16]
+    if input_tensor.dtype not in supported_dtypes:
+        raise ValueError(
+            f"Input dtype must be one of {supported_dtypes}. Got {input_tensor.dtype}"
+        )
+
+    float_input_tensor = dequantize_per_tensor(
+        input_tensor,
+        in_scale,
+        in_zero_point,
+        torch.iinfo(input_tensor.dtype).min,
+        torch.iinfo(input_tensor.dtype).max,
+        input_tensor.dtype,
+    )
+
+    softmax_output = torch.nn.functional.softmax(float_input_tensor, dim=dim)
+
+    return quantize_per_tensor(
+        softmax_output,
+        out_scale,
+        out_zero_point,
+        torch.iinfo(input_tensor.dtype).min,
+        torch.iinfo(input_tensor.dtype).max,
+        input_tensor.dtype,
+    )
+
+
+@impl_tracked(m, "quantized_softmax.per_tensor")
+def quantized_softmax_per_tensor(
+    input_tensor: torch.Tensor,
+    mask: torch.Tensor | None,
+    dim: int,
+    in_scale: float,
+    in_zero_point: int,
+    out_scale: float,
+    out_zero_point: int,
+) -> torch.Tensor:
+    return quantized_softmax_per_tensor_common(
+        input_tensor,
+        mask,
+        dim,
+        in_scale,
+        in_zero_point,
+        out_scale,
+        out_zero_point,
+    )
+
+
+@impl_tracked(m, "quantized_softmax")
+def quantized_softmax(
+    input_tensor: torch.Tensor,
+    mask: torch.Tensor | None,
+    dim: int,
+    in_scale: torch.Tensor,
+    in_zero_point: torch.Tensor,
+    out_scale: float,
+    out_zero_point: int,
+) -> torch.Tensor:
+    return quantized_softmax_per_tensor_common(
+        input_tensor,
+        mask,
+        dim,
+        float(in_scale.item()),
+        int(in_zero_point.item()),
+        out_scale,
+        out_zero_point,
+    )
diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py
index c38668b76c6..5629ed518e5 100644
--- a/backends/cadence/aot/tests/test_ref_implementations.py
+++ b/backends/cadence/aot/tests/test_ref_implementations.py
@@ -3079,3 +3079,135 @@ def test_quantized_w8a32_gru_invalid_hidden_dim(self) -> None:
         self.assertIn(
             "Hidden dimension must be a multiple of 4", str(context.exception)
         )
+
+    @expand(
+        [
+            (
+                "basic_int8_dim_1",
+                torch.tensor([[10, 20, 30]], dtype=torch.int8),
+                None,
+                1,
+                0.1,
+                0,
+                0.004,
+                0,
+                torch.int8,
+                torch.tensor([[23, 61, 127]], dtype=torch.int8),
+            ),
+            (
+                "uint8_with_zero_points",
+                torch.tensor([[128, 130, 132]], dtype=torch.uint8),
+                None,
+                1,
+                0.1,
+                128,
+                0.004,
+                128,
+                torch.uint8,
+                torch.tensor([[195, 210, 228]], dtype=torch.uint8),
+            ),
+            (
+                "basic_int16",
+                torch.tensor([[100, 200, 300]], dtype=torch.int16),
+                None,
+                1,
+                0.01,
+                0,
+                0.004,
+                0,
+                torch.int16,
+                torch.tensor([[23, 61, 166]], dtype=torch.int16),
+            ),
+            (
+                "multi_row_int8",
+                torch.tensor([[10, 20, 30], [5, 10, 15]], dtype=torch.int8),
+                None,
+                1,
+                0.1,
+                0,
+                0.004,
+                0,
+                torch.int8,
+                torch.tensor([[23, 61, 127], [47, 77, 127]], dtype=torch.int8),
+            ),
+            (
+                "softmax_dim_0",
+                torch.tensor([[10, 20], [30, 40]], dtype=torch.int8),
+                None,
+                0,
+                0.1,
+                0,
+                0.004,
+                0,
+                torch.int8,
+                torch.tensor([[30, 30], [127, 127]], dtype=torch.int8),
+            ),
+        ]
+    )
+    def test_quantized_softmax_per_tensor(
+        self,
+        name: str,
+        input_tensor: torch.Tensor,
+        mask: torch.Tensor | None,
+        dim: int,
+        in_scale: float,
+        in_zero_point: int,
+        out_scale: float,
+        out_zero_point: int,
+        dtype: torch.dtype,
+        expected_output: torch.Tensor,
+    ) -> None:
+        output = torch.ops.cadence.quantized_softmax.per_tensor(
+            input_tensor,
+            mask,
+            dim,
+            in_scale,
+            in_zero_point,
+            out_scale,
+            out_zero_point,
+        )
+
+        # Verify output properties
+        self.assertEqual(
+            output.dtype, dtype, f"Output dtype should be {dtype} in {name}"
+        )
+        self.assertEqual(
+            output.shape,
+            input_tensor.shape,
+            f"Output shape should match input shape in {name}",
+        )
+
+        # Verify output matches expected values (allowing for small quantization errors)
+        # For softmax, we expect outputs to be in [0, 1] range when dequantized
+        self.assertTrue(
+            torch.allclose(
+                output.to(torch.float32),
+                expected_output.to(torch.float32),
+                rtol=0.05,
+                atol=5.0,
+            ),
+            f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
+        )
+
+    def test_quantized_softmax(self) -> None:
+        # Test quantized_softmax (default variant with tensor scale/zero_point)
+        input_tensor = torch.tensor([[10, 20, 30]], dtype=torch.int8)
+        in_scale = torch.tensor([0.1])
+        in_zero_point = torch.tensor([0])
+        output = torch.ops.cadence.quantized_softmax(
+            input_tensor,
+            None,  # mask
+            1,  # dim
+            in_scale,
+            in_zero_point,
+            0.004,  # out_scale
+            0,  # out_zero_point
+        )
+
+        # Verify output properties
+        self.assertEqual(output.dtype, torch.int8, "Output dtype should be int8")
+        self.assertEqual(
+            output.shape,
+            input_tensor.shape,
+            "Output shape should match input shape",
+        )
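
A quick standalone sanity check of the new reference op's math (not part of the patch): it reproduces the "basic_int8_dim_1" expected values by hand, assuming plain affine dequantization, float softmax, and round-then-clamp requantization; the exact rounding used inside quantize_per_tensor may differ by one LSB.

import torch

# Dequantize -> softmax -> requantize for the "basic_int8_dim_1" test case.
inp = torch.tensor([[10, 20, 30]], dtype=torch.int8)
in_scale, in_zero_point = 0.1, 0
out_scale, out_zero_point = 0.004, 0

deq = (inp.to(torch.float32) - in_zero_point) * in_scale  # [[1.0, 2.0, 3.0]]
probs = torch.softmax(deq, dim=1)                         # ~[[0.090, 0.245, 0.665]]
quant = (
    torch.round(probs / out_scale + out_zero_point)
    .clamp(torch.iinfo(torch.int8).min, torch.iinfo(torch.int8).max)
    .to(torch.int8)
)
print(quant)  # tensor([[23, 61, 127]], dtype=torch.int8), matching the expected output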