From fa0260a13f912c60c3f23e0a5c111b5949735aa9 Mon Sep 17 00:00:00 2001
From: Andrew Grebenisan
Date: Tue, 21 Oct 2025 09:48:38 -0700
Subject: [PATCH] Cadence ops: Support strongly typed softmax (#15201)

Summary: Add a reference implementation and fake (meta) kernel for the
strongly typed cadence::_softmax_f32_f32 op, remove it from the warn-only
list, and cover it with a smoke test.

Differential Revision: D84845481
---
 backends/cadence/aot/ops_registrations.py         | 12 ++++++------
 backends/cadence/aot/ref_implementations.py       | 11 +++++++++++
 .../cadence/aot/tests/test_ref_implementations.py |  9 +++++++++
 3 files changed, 26 insertions(+), 6 deletions(-)

diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py
index 145dc16557a..d8db866fa4e 100644
--- a/backends/cadence/aot/ops_registrations.py
+++ b/backends/cadence/aot/ops_registrations.py
@@ -53,7 +53,6 @@ def _validate_ref_impl_exists() -> None:
     # 1. be removed
     # 2. have a reference implementation added to ref_implementations.py
     _WARN_ONLY = {
-        "cadence::_softmax_f32_f32",
         "cadence::quantized_softmax.per_tensor",
         "cadence::quantized_softmax",
         "cadence::quantized_w8a32_gru",
@@ -640,10 +639,10 @@ def register_fake(
     "int sampling_ratio, bool aligned) -> (Tensor out)"
 )
 lib.define(
-    "_softmax_f32_f32(Tensor self, int dim, bool? half_to_float) -> (Tensor out)"
+    "_softmax_f32_f32(Tensor self, int dim, bool? half_to_float = None) -> (Tensor out)"
 )
 lib.define(
-    "_softmax_f32_f32.out(Tensor self, int dim, bool? half_to_float, *, Tensor(a!) out) -> Tensor(a!)"
+    "_softmax_f32_f32.out(Tensor self, int dim, bool? half_to_float = None, *, Tensor(a!) out) -> Tensor(a!)"
 )
 
 lib.define(
@@ -2652,12 +2651,13 @@ def quantized_conv1d_nlc_asym8uxsym8u_asym8u_per_tensor_meta(
 
 @register_fake("cadence::_softmax_f32_f32")
 def softmax_f32_f32_meta(
-    self: torch.Tensor,
+    input_tensor: torch.Tensor,
     dim: int,
-    dtype: torch.dtype,
     half_to_float: Optional[bool] = None,
 ) -> torch.Tensor:
-    return self.new_empty(self.size(), dtype=self.dtype)
+    assert input_tensor.dtype == torch.float32, "input_tensor must be float32"
+    assert not half_to_float, "half_to_float is not supported"
+    return input_tensor.new_empty(input_tensor.size(), dtype=torch.float32)
 
 
 @register_fake("cadence::quantized_softmax")
diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py
index e284a7e639b..6e0c116ad45 100644
--- a/backends/cadence/aot/ref_implementations.py
+++ b/backends/cadence/aot/ref_implementations.py
@@ -1979,3 +1979,14 @@ def linalg_svd(
     assert compute_uv
     U, S, Vh = torch.linalg.svd(A, full_matrices=full_matrices, driver=driver)
     return U.contiguous(), S.contiguous(), Vh.contiguous()
+
+
+@impl_tracked(m, "_softmax_f32_f32")
+def softmax_f32_f32(
+    input_tensor: torch.Tensor,
+    dim: int,
+    half_to_float: bool | None = None,
+) -> torch.Tensor:
+    assert input_tensor.dtype == torch.float32, "input_tensor must be float32"
+    assert not half_to_float, "half_to_float is not supported"
+    return torch.nn.functional.softmax(input_tensor, dim=dim, dtype=torch.float32)
diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py
index 02485b0ae09..e9ba52c58b9 100644
--- a/backends/cadence/aot/tests/test_ref_implementations.py
+++ b/backends/cadence/aot/tests/test_ref_implementations.py
@@ -2885,3 +2885,12 @@ def test_quantized_layer_norm(self) -> None:
             output_scale,
             output_zero_point,
         )
+
+    def test_softmax_f32_f32(self) -> None:
+        # Thin wrapper around torch.nn.functional.softmax, so just ensure that it runs
+        input_tensor = torch.tensor(
+            [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32
+        )
+        output = torch.ops.cadence._softmax_f32_f32(input_tensor, dim=1)
+        self.assertEqual(output.dtype, torch.float32)
+        self.assertEqual(output.shape, input_tensor.shape)
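
Note (not part of the patch): a minimal usage sketch of the new op,
mirroring the smoke test above. It assumes the patch is applied and that
the Cadence AOT op registrations have been imported; the exact module
path used below is an assumption inferred from the file layout in this
patch, not something the patch itself specifies.

    import torch

    # Assumed registration import: the module path is a guess based on
    # backends/cadence/aot/ops_registrations.py above. Importing it is
    # expected to run the lib.define(...) calls that register the op.
    import executorch.backends.cadence.aot.ops_registrations  # noqa: F401

    x = torch.randn(2, 3, dtype=torch.float32)
    # half_to_float defaults to None and the implementation asserts it
    # is falsy, so the call is strictly float32 in, float32 out.
    y = torch.ops.cadence._softmax_f32_f32(x, dim=1)
    assert y.dtype == torch.float32 and y.shape == x.shape
    # Softmax normalizes along `dim`, so each row sums to 1.
    assert torch.allclose(y.sum(dim=1), torch.ones(2))

The strict dtype asserts in both the reference implementation and the
fake kernel are what make the op "strongly typed": any non-float32 input
fails fast at trace time rather than silently changing precision.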