Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions backends/cadence/aot/ref_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1099,3 +1099,51 @@ def where_Scalar(
raise ValueError("condition must be a bool tensor")

return torch.where(condition, if_true, if_false)


@impl(m, "rope")
def rope(
input_tensor: torch.Tensor,
sin_tensor: torch.Tensor,
cos_tensor: torch.Tensor,
pos: torch.Tensor | None,
) -> torch.Tensor:
original_shape = input_tensor.shape

if len(original_shape) not in [4, 5]:
raise ValueError(
f"Input tensor must be 4D or 5D. Got {len(original_shape)}D tensor"
)
if original_shape[0] != 1:
raise ValueError("Input tensor must have batch size 1")
if len(original_shape) == 5:
input_tensor = input_tensor.view(
input_tensor.shape[0], input_tensor.shape[1], input_tensor.shape[2], -1
)

_, s, h, hd = input_tensor.shape

if hd % 2:
raise ValueError("Hidden dimension must be divisible by 2")

if sin_tensor.shape != (s, hd // 2) or cos_tensor.shape != (s, hd // 2):
raise ValueError(
f"sin_tensor and cos_tensor must have shape {s, hd // 2}. Got {sin_tensor.shape} and {cos_tensor.shape}"
)

if pos is not None:
if pos.shape != (input_tensor.shape[1],):
raise ValueError(
f"pos must have shape {input_tensor.shape[1]}. Got {pos.shape}"
)
sin_tensor = sin_tensor[pos]
cos_tensor = cos_tensor[pos]

sin_tensor = sin_tensor.unsqueeze(1)
cos_tensor = cos_tensor.unsqueeze(1)

x0, x1 = input_tensor[..., ::2], input_tensor[..., 1::2]
rotated = torch.cat(
[x0 * cos_tensor - x1 * sin_tensor, x0 * sin_tensor + x1 * cos_tensor], dim=-1
)
return rotated.view(original_shape)
100 changes: 100 additions & 0 deletions backends/cadence/aot/tests/test_ref_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1156,3 +1156,103 @@ def test_where_Scalar(self) -> None:
torch.ops.cadence.where_Scalar(input_tensor, 1.0, 0.0)

self.assertIn("condition must be a bool tensor", str(context.exception))

@expand(
[
(
"h1xhd4",
torch.tensor([[[[1.0, 2.0, 3.0, 4.0]]]], dtype=torch.float32),
torch.tensor([[0.0, 0.0]], dtype=torch.float32),
torch.tensor([[1.0, 1.0]], dtype=torch.float32),
torch.tensor([[[[1.0, 3.0, 2.0, 4.0]]]], dtype=torch.float32),
),
(
"h2xhd4",
torch.tensor(
[[[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]]],
dtype=torch.float32,
),
torch.tensor([[0.0, 1.0]], dtype=torch.float32),
torch.tensor([[1.0, 0.0]], dtype=torch.float32),
torch.tensor(
[[[[1.0, -4.0, 2.0, 3.0], [5, -8.0, 6.0, 7.0]]]],
dtype=torch.float32,
),
),
(
"s2xh2xhd4",
torch.tensor(
[
[
[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],
[[9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]],
]
],
dtype=torch.float32,
),
torch.tensor([[0.0, 1.0], [0.0, 1.0]], dtype=torch.float32),
torch.tensor([[1.0, 0.0], [1.0, 0.0]], dtype=torch.float32),
torch.tensor(
[
[
[[1.0, -4.0, 2.0, 3.0], [5.0, -8.0, 6.0, 7.0]],
[[9.0, -12.0, 10.0, 11.0], [13.0, -16.0, 14.0, 15.0]],
]
],
dtype=torch.float32,
),
),
(
"pos_not_none",
torch.tensor(
[
[
[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],
[[9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]],
]
],
dtype=torch.float32,
),
torch.tensor([[1.0, 0.0], [0.0, 1.0]], dtype=torch.float32),
torch.tensor([[0.0, 1.0], [1.0, 0.0]], dtype=torch.float32),
torch.tensor(
[
[
[[1.0, -4.0, 2.0, 3.0], [5.0, -8.0, 6.0, 7.0]],
[[-10.0, 11.0, 9.0, 12.0], [-14.0, 15.0, 13.0, 16.0]],
]
],
dtype=torch.float32,
),
torch.tensor([1, 0]),
),
]
)
def test_rope(
self,
name: str,
input_tensor: torch.Tensor,
sin_tensor: torch.Tensor,
cos_tensor: torch.Tensor,
expected_output: torch.Tensor,
pos: torch.Tensor | None = None,
) -> None:
output = torch.ops.cadence.rope(input_tensor, sin_tensor, cos_tensor, pos)

# Verify output properties
self.assertEqual(
output.dtype,
input_tensor.dtype,
f"Output dtype should match input dtype in {name}",
)
self.assertEqual(
output.shape,
input_tensor.shape,
f"Output shape should match input shape in {name}",
)

# Verify output matches expected values
self.assertTrue(
torch.allclose(output, expected_output, rtol=1e-4, atol=1e-4),
f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
)
Loading