Skip to content

Commit 7aaede0

Browse files
Andrew Grebenisan authored and facebook-github-bot committed
Rope custom op (#14399)
Rope custom op (#14399)
Summary: Continued support of Cadence custom ops. Reviewed By: hsharma35. Differential Revision: D82702247
1 parent 181ed4d commit 7aaede0

File tree

2 files changed

+136
-0
lines changed

2 files changed

+136
-0
lines changed

backends/cadence/aot/ref_implementations.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1099,3 +1099,43 @@ def where_Scalar(
10991099
raise ValueError("condition must be a bool tensor")
11001100

11011101
return torch.where(condition, if_true, if_false)
1102+
1103+
1104+
@impl(m, "rope")
def rope(
    input_tensor: torch.Tensor,
    sin_tensor: torch.Tensor,
    cos_tensor: torch.Tensor,
    pos: torch.Tensor | None,
) -> torch.Tensor:
    """Reference implementation of the cadence ``rope`` custom op.

    Applies a rotary positional embedding to ``input_tensor``: the even- and
    odd-indexed elements along the last axis are treated as (x0, x1) pairs and
    rotated by the per-position angles encoded in ``sin_tensor``/``cos_tensor``.
    The rotated halves are concatenated (x0-part first, x1-part second) along
    the last axis, then reshaped back to the caller's original shape.

    Args:
        input_tensor: shape (1, seq, heads, head_dim) or a 5D tensor whose
            trailing two axes are folded into one head dimension.
        sin_tensor: shape (seq, head_dim // 2) rotation sines.
        cos_tensor: shape (seq, head_dim // 2) rotation cosines.
        pos: optional index tensor of shape (seq,) used to gather rows from
            ``sin_tensor``/``cos_tensor`` before applying the rotation.

    Returns:
        Tensor with the same shape and dtype as ``input_tensor``.

    Raises:
        ValueError: on unsupported rank, batch size != 1, odd head dimension,
            or mismatched sin/cos/pos shapes.
    """
    original_shape = input_tensor.shape
    rank = len(original_shape)

    # Guard clauses: only batch-1, 4D/5D inputs are supported.
    if rank not in (4, 5):
        raise ValueError(f"Input tensor must be 4D or 5D. Got {rank}D tensor")
    if original_shape[0] != 1:
        raise ValueError("Input tensor must have batch size 1")

    # Fold the trailing two axes of a 5D input into a single head dimension.
    if rank == 5:
        input_tensor = input_tensor.view(
            original_shape[0], original_shape[1], original_shape[2], -1
        )

    _, seq_len, _, head_dim = input_tensor.shape

    if head_dim % 2 != 0:
        raise ValueError("Hidden dimension must be divisible by 2")

    half = head_dim // 2
    if sin_tensor.shape != (seq_len, half) or cos_tensor.shape != (seq_len, half):
        raise ValueError(
            f"sin_tensor and cos_tensor must have shape {seq_len, half}. "
            f"Got {sin_tensor.shape} and {cos_tensor.shape}"
        )

    # Optional gather: reorder the sin/cos rows by explicit positions.
    if pos is not None:
        if pos.shape != (seq_len,):
            raise ValueError(f"pos must have shape {seq_len}. Got {pos.shape}")
        sin_tensor = sin_tensor[pos]
        cos_tensor = cos_tensor[pos]

    # Insert a singleton head axis so (seq, half) broadcasts over heads.
    sin_b = sin_tensor[:, None, :]
    cos_b = cos_tensor[:, None, :]

    # Rotate each (even, odd) pair and concatenate the two halves.
    even = input_tensor[..., 0::2]
    odd = input_tensor[..., 1::2]
    rotated = torch.cat(
        (even * cos_b - odd * sin_b, even * sin_b + odd * cos_b), dim=-1
    )
    return rotated.view(original_shape)

backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1156,3 +1156,99 @@ def test_where_Scalar(self) -> None:
11561156
torch.ops.cadence.where_Scalar(input_tensor, 1.0, 0.0)
11571157

11581158
self.assertIn("condition must be a bool tensor", str(context.exception))
1159+
1160+
@expand(
    [
        # Single position, one head, head_dim=4; cos=1/sin=0 is the identity
        # rotation, so the output is just the even/odd de-interleave of the input.
        (
            "h1xhd4",
            torch.tensor(
                [[[[1.0, 2.0, 3.0, 4.0]]]], dtype=torch.float32
            ),
            torch.tensor(
                [[0.0, 0.0]], dtype=torch.float32
            ),
            torch.tensor(
                [[1.0, 1.0]], dtype=torch.float32
            ),
            torch.tensor(
                [[[[1.0, 3.0, 2.0, 4.0]]]], dtype=torch.float32
            ),
        ),
        # Two heads sharing one position row; second pair uses cos=0/sin=1
        # (a quarter-turn), exercising the cross terms of the rotation.
        (
            "h2xhd4",
            torch.tensor(
                [[[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]]], dtype=torch.float32
            ),
            torch.tensor(
                [[0.0, 1.0]], dtype=torch.float32
            ),
            torch.tensor(
                [[1.0, 0.0]], dtype=torch.float32
            ),
            torch.tensor(
                [[[[1.0, -4.0, 2.0, 3.0], [5, -8.0, 6.0, 7.0]]]], dtype=torch.float32
            ),
        ),
        # Two sequence positions x two heads, each position with its own
        # sin/cos row (identical rows here), no explicit pos gather.
        (
            "s2xh2xhd4",
            torch.tensor(
                [[[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]], [[9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]]]], dtype=torch.float32
            ),
            torch.tensor(
                [[0.0, 1.0], [0.0, 1.0]], dtype=torch.float32
            ),
            torch.tensor(
                [[1.0, 0.0], [1.0, 0.0]], dtype=torch.float32
            ),
            torch.tensor([[[[ 1., -4., 2., 3. ],
                [ 5., -8., 6., 7. ]],
                [[ 9., -12., 10., 11. ],
                [ 13., -16., 14., 15. ]]]], dtype=torch.float32),
        ),
        # Same data but with pos=[1, 0], so the sin/cos rows are applied
        # in swapped order — exercises the optional gather path.
        (
            "pos_not_none",
            torch.tensor(
                [[[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]], [[9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]]]], dtype=torch.float32
            ),
            torch.tensor(
                [[1.0, 0.0], [0.0, 1.0]], dtype=torch.float32
            ),
            torch.tensor(
                [[0.0, 1.0], [1.0, 0.0]], dtype=torch.float32
            ),
            torch.tensor([[[[ 1., -4., 2., 3. ],
                [ 5., -8., 6., 7. ]],
                [[ -10., 11., 9., 12. ],
                [ -14., 15., 13., 16. ]]]], dtype=torch.float32),
            torch.tensor([1, 0])
        ),
    ]
)
def test_rope(
    self,
    name: str,
    input_tensor: torch.Tensor,
    sin_tensor: torch.Tensor,
    cos_tensor: torch.Tensor,
    expected_output: torch.Tensor,
    pos: torch.Tensor | None = None,
) -> None:
    """Check the cadence ``rope`` reference op against hand-computed outputs.

    Each parameterized case supplies an input of shape (1, seq, heads,
    head_dim), per-position sin/cos tables of shape (seq, head_dim // 2),
    the expected rotated tensor, and optionally a ``pos`` index tensor.
    """
    output = torch.ops.cadence.rope(input_tensor, sin_tensor, cos_tensor, pos)

    # Verify output properties
    self.assertEqual(
        output.dtype,
        input_tensor.dtype,
        f"Output dtype should match input dtype in {name}",
    )
    self.assertEqual(
        output.shape,
        input_tensor.shape,
        f"Output shape should match input shape in {name}",
    )

    # Verify output matches expected values
    self.assertTrue(
        torch.allclose(output, expected_output, rtol=1e-4, atol=1e-4),
        f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
    )

0 commit comments

Comments
 (0)