Skip to content

Commit 68b2d3c

Browse files
authored
Add a reference implementation and unit tests for the RoPE (rotary position embedding) custom op
Differential Revision: D82702247 Pull Request resolved: pytorch#14399
1 parent 181ed4d commit 68b2d3c

File tree

2 files changed

+148
-0
lines changed

2 files changed

+148
-0
lines changed

backends/cadence/aot/ref_implementations.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1099,3 +1099,51 @@ def where_Scalar(
10991099
raise ValueError("condition must be a bool tensor")
11001100

11011101
return torch.where(condition, if_true, if_false)
1102+
1103+
1104+
@impl(m, "rope")
def rope(
    input_tensor: torch.Tensor,
    sin_tensor: torch.Tensor,
    cos_tensor: torch.Tensor,
    pos: torch.Tensor | None,
) -> torch.Tensor:
    """Reference implementation of the ``rope`` custom op.

    Treats the even/odd interleaved lanes of the last dimension as
    (x0, x1) rotation pairs, rotates each pair by the per-position
    angles given in ``sin_tensor``/``cos_tensor``, and concatenates the
    two rotated halves along the last dimension (note: the halves are
    concatenated, not re-interleaved) before restoring the caller's
    original shape.

    Args:
        input_tensor: 4D ``(1, seq, heads, head_dim)`` tensor, or a 5D
            tensor whose trailing two dims are flattened into
            ``head_dim``. Batch size must be 1.
        sin_tensor: ``(seq, head_dim // 2)`` rotation sines.
        cos_tensor: ``(seq, head_dim // 2)`` rotation cosines.
        pos: optional ``(seq,)`` index tensor selecting which sin/cos
            row each sequence step uses.

    Returns:
        Tensor with the same shape and dtype as ``input_tensor``.

    Raises:
        ValueError: on unsupported rank, batch size != 1, odd
            ``head_dim``, or mismatched sin/cos/pos shapes.
    """
    original_shape = input_tensor.shape
    rank = len(original_shape)

    if rank not in (4, 5):
        raise ValueError(f"Input tensor must be 4D or 5D. Got {rank}D tensor")
    if original_shape[0] != 1:
        raise ValueError("Input tensor must have batch size 1")

    # Collapse the trailing two dims of a 5D input into a single head_dim axis.
    if rank == 5:
        input_tensor = input_tensor.view(*input_tensor.shape[:3], -1)

    _, seq_len, _num_heads, head_dim = input_tensor.shape

    if head_dim % 2 != 0:
        raise ValueError("Hidden dimension must be divisible by 2")

    half_shape = (seq_len, head_dim // 2)
    if sin_tensor.shape != half_shape or cos_tensor.shape != half_shape:
        raise ValueError(
            f"sin_tensor and cos_tensor must have shape {half_shape}. Got {sin_tensor.shape} and {cos_tensor.shape}"
        )

    if pos is not None:
        if pos.shape != (seq_len,):
            raise ValueError(f"pos must have shape {seq_len}. Got {pos.shape}")
        # Gather the rotation rows selected for each sequence position.
        sin_tensor, cos_tensor = sin_tensor[pos], cos_tensor[pos]

    # Insert a heads axis so (seq, hd/2) broadcasts against (seq, heads, hd/2).
    sin_tensor = sin_tensor.unsqueeze(1)
    cos_tensor = cos_tensor.unsqueeze(1)

    # Even/odd interleaved lanes form the (x0, x1) pairs to rotate.
    x_even = input_tensor[..., 0::2]
    x_odd = input_tensor[..., 1::2]
    rotated_even = x_even * cos_tensor - x_odd * sin_tensor
    rotated_odd = x_even * sin_tensor + x_odd * cos_tensor

    # Halves are concatenated (not re-interleaved), then reshaped back.
    return torch.cat([rotated_even, rotated_odd], dim=-1).view(original_shape)

backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1156,3 +1156,103 @@ def test_where_Scalar(self) -> None:
11561156
torch.ops.cadence.where_Scalar(input_tensor, 1.0, 0.0)
11571157

11581158
self.assertIn("condition must be a bool tensor", str(context.exception))
1159+
1160+
@expand(
    [
        # (name, input_tensor, sin_tensor, cos_tensor, expected_output[, pos])
        (
            "h1xhd4",
            torch.tensor([[[[1.0, 2.0, 3.0, 4.0]]]], dtype=torch.float32),
            torch.tensor([[0.0, 0.0]], dtype=torch.float32),
            torch.tensor([[1.0, 1.0]], dtype=torch.float32),
            # sin=0, cos=1 is the identity rotation; lanes are only de-interleaved.
            torch.tensor([[[[1.0, 3.0, 2.0, 4.0]]]], dtype=torch.float32),
        ),
        (
            "h2xhd4",
            torch.tensor(
                [[[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]]],
                dtype=torch.float32,
            ),
            torch.tensor([[0.0, 1.0]], dtype=torch.float32),
            torch.tensor([[1.0, 0.0]], dtype=torch.float32),
            torch.tensor(
                [[[[1.0, -4.0, 2.0, 3.0], [5.0, -8.0, 6.0, 7.0]]]],
                dtype=torch.float32,
            ),
        ),
        (
            "s2xh2xhd4",
            torch.tensor(
                [
                    [
                        [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],
                        [[9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]],
                    ]
                ],
                dtype=torch.float32,
            ),
            torch.tensor([[0.0, 1.0], [0.0, 1.0]], dtype=torch.float32),
            torch.tensor([[1.0, 0.0], [1.0, 0.0]], dtype=torch.float32),
            torch.tensor(
                [
                    [
                        [[1.0, -4.0, 2.0, 3.0], [5.0, -8.0, 6.0, 7.0]],
                        [[9.0, -12.0, 10.0, 11.0], [13.0, -16.0, 14.0, 15.0]],
                    ]
                ],
                dtype=torch.float32,
            ),
        ),
        (
            "pos_not_none",
            torch.tensor(
                [
                    [
                        [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],
                        [[9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]],
                    ]
                ],
                dtype=torch.float32,
            ),
            torch.tensor([[1.0, 0.0], [0.0, 1.0]], dtype=torch.float32),
            torch.tensor([[0.0, 1.0], [1.0, 0.0]], dtype=torch.float32),
            # pos = [1, 0] swaps which sin/cos row each sequence step uses.
            torch.tensor(
                [
                    [
                        [[1.0, -4.0, 2.0, 3.0], [5.0, -8.0, 6.0, 7.0]],
                        [[-10.0, 11.0, 9.0, 12.0], [-14.0, 15.0, 13.0, 16.0]],
                    ]
                ],
                dtype=torch.float32,
            ),
            torch.tensor([1, 0]),
        ),
    ]
)
def test_rope(
    self,
    name: str,
    input_tensor: torch.Tensor,
    sin_tensor: torch.Tensor,
    cos_tensor: torch.Tensor,
    expected_output: torch.Tensor,
    pos: torch.Tensor | None = None,
) -> None:
    """Check the rope op against hand-computed rotations."""
    actual = torch.ops.cadence.rope(input_tensor, sin_tensor, cos_tensor, pos)

    # The op must preserve the input's dtype...
    self.assertEqual(
        actual.dtype,
        input_tensor.dtype,
        f"Output dtype should match input dtype in {name}",
    )
    # ...and its shape.
    self.assertEqual(
        actual.shape,
        input_tensor.shape,
        f"Output shape should match input shape in {name}",
    )

    # Element-wise comparison against the expected rotation result.
    self.assertTrue(
        torch.allclose(actual, expected_output, rtol=1e-4, atol=1e-4),
        f"Output values don't match expected in {name}. Got {actual}, expected {expected_output}",
    )

0 commit comments

Comments
 (0)