From a5599b3bc418c011f5a523e9c591716602c3f5e5 Mon Sep 17 00:00:00 2001 From: Andrew Grebenisan Date: Mon, 29 Sep 2025 20:14:02 -0700 Subject: [PATCH] Rope custom op (#14399) Summary: Continued support of cadence custom ops Reviewed By: ethansfng, hsharma35 Differential Revision: D82702247 --- backends/cadence/aot/ref_implementations.py | 48 +++++++++ .../aot/tests/test_ref_implementations.py | 100 ++++++++++++++++++ 2 files changed, 148 insertions(+) diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py index 781f04ae1da..52776c55c54 100644 --- a/backends/cadence/aot/ref_implementations.py +++ b/backends/cadence/aot/ref_implementations.py @@ -1099,3 +1099,51 @@ def where_Scalar( raise ValueError("condition must be a bool tensor") return torch.where(condition, if_true, if_false) + + +@impl(m, "rope") +def rope( + input_tensor: torch.Tensor, + sin_tensor: torch.Tensor, + cos_tensor: torch.Tensor, + pos: torch.Tensor | None, +) -> torch.Tensor: + original_shape = input_tensor.shape + + if len(original_shape) not in [4, 5]: + raise ValueError( + f"Input tensor must be 4D or 5D. Got {len(original_shape)}D tensor" + ) + if original_shape[0] != 1: + raise ValueError("Input tensor must have batch size 1") + if len(original_shape) == 5: + input_tensor = input_tensor.view( + input_tensor.shape[0], input_tensor.shape[1], input_tensor.shape[2], -1 + ) + + _, s, h, hd = input_tensor.shape + + if hd % 2: + raise ValueError("Hidden dimension must be divisible by 2") + + if sin_tensor.shape != (s, hd // 2) or cos_tensor.shape != (s, hd // 2): + raise ValueError( + f"sin_tensor and cos_tensor must have shape {s, hd // 2}. Got {sin_tensor.shape} and {cos_tensor.shape}" + ) + + if pos is not None: + if pos.shape != (input_tensor.shape[1],): + raise ValueError( + f"pos must have shape {input_tensor.shape[1]}. Got {pos.shape}" + ) + sin_tensor = sin_tensor[pos] + cos_tensor = cos_tensor[pos] + + sin_tensor = sin_tensor.unsqueeze(1) + cos_tensor = cos_tensor.unsqueeze(1) + + x0, x1 = input_tensor[..., ::2], input_tensor[..., 1::2] + rotated = torch.cat( + [x0 * cos_tensor - x1 * sin_tensor, x0 * sin_tensor + x1 * cos_tensor], dim=-1 + ) + return rotated.view(original_shape) diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py index 26281b70216..2858f9781e5 100644 --- a/backends/cadence/aot/tests/test_ref_implementations.py +++ b/backends/cadence/aot/tests/test_ref_implementations.py @@ -1156,3 +1156,103 @@ def test_where_Scalar(self) -> None: torch.ops.cadence.where_Scalar(input_tensor, 1.0, 0.0) self.assertIn("condition must be a bool tensor", str(context.exception)) + + @expand( + [ + ( + "h1xhd4", + torch.tensor([[[[1.0, 2.0, 3.0, 4.0]]]], dtype=torch.float32), + torch.tensor([[0.0, 0.0]], dtype=torch.float32), + torch.tensor([[1.0, 1.0]], dtype=torch.float32), + torch.tensor([[[[1.0, 3.0, 2.0, 4.0]]]], dtype=torch.float32), + ), + ( + "h2xhd4", + torch.tensor( + [[[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]]], + dtype=torch.float32, + ), + torch.tensor([[0.0, 1.0]], dtype=torch.float32), + torch.tensor([[1.0, 0.0]], dtype=torch.float32), + torch.tensor( + [[[[1.0, -4.0, 2.0, 3.0], [5, -8.0, 6.0, 7.0]]]], + dtype=torch.float32, + ), + ), + ( + "s2xh2xhd4", + torch.tensor( + [ + [ + [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]], + [[9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]], + ] + ], + dtype=torch.float32, + ), + torch.tensor([[0.0, 1.0], [0.0, 1.0]], dtype=torch.float32), + torch.tensor([[1.0, 0.0], [1.0, 0.0]], dtype=torch.float32), + torch.tensor( + [ + [ + [[1.0, -4.0, 2.0, 3.0], [5.0, -8.0, 6.0, 7.0]], + [[9.0, -12.0, 10.0, 11.0], [13.0, -16.0, 14.0, 15.0]], + ] + ], + dtype=torch.float32, + ), + ), + ( + "pos_not_none", + torch.tensor( + [ + [ + [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]], + [[9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]], + ] + ], + dtype=torch.float32, + ), + torch.tensor([[1.0, 0.0], [0.0, 1.0]], dtype=torch.float32), + torch.tensor([[0.0, 1.0], [1.0, 0.0]], dtype=torch.float32), + torch.tensor( + [ + [ + [[1.0, -4.0, 2.0, 3.0], [5.0, -8.0, 6.0, 7.0]], + [[-10.0, 11.0, 9.0, 12.0], [-14.0, 15.0, 13.0, 16.0]], + ] + ], + dtype=torch.float32, + ), + torch.tensor([1, 0]), + ), + ] + ) + def test_rope( + self, + name: str, + input_tensor: torch.Tensor, + sin_tensor: torch.Tensor, + cos_tensor: torch.Tensor, + expected_output: torch.Tensor, + pos: torch.Tensor | None = None, + ) -> None: + output = torch.ops.cadence.rope(input_tensor, sin_tensor, cos_tensor, pos) + + # Verify output properties + self.assertEqual( + output.dtype, + input_tensor.dtype, + f"Output dtype should match input dtype in {name}", + ) + self.assertEqual( + output.shape, + input_tensor.shape, + f"Output shape should match input shape in {name}", + ) + + # Verify output matches expected values + self.assertTrue( + torch.allclose(output, expected_output, rtol=1e-4, atol=1e-4), + f"Output values don't match expected in {name}. Got {output}, expected {expected_output}", + )