Skip to content

Commit b3a9388

Browse files
nitish2112facebook-github-bot
authored andcommitted
: Adds DSP op for RoPE
Reviewed By: hsharma35 Differential Revision: D75605145
1 parent 851f5fc commit b3a9388

File tree

1 file changed

+25
-0
lines changed

1 file changed

+25
-0
lines changed

backends/cadence/aot/ops_registrations.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
# pyre-strict
88

99
from math import prod
10+
import numpy as np
1011
from typing import Optional, Tuple
1112

1213
import torch
@@ -167,6 +168,11 @@
167168
"where_Scalar.out(Tensor condition, float self, float other, *, Tensor(a!) out) -> Tensor(a!)"
168169
)
169170

171+
lib.define("rope(Tensor input, Tensor sin_tensor, Tensor cos_tensor, Tensor? pos) -> (Tensor out)")
172+
lib.define(
173+
"rope.out(Tensor input, Tensor sin_tensor, Tensor cos_tensor, Tensor? pos, *, Tensor(a!) out) -> Tensor(a!)"
174+
)
175+
170176
# ------------------------------------ #
171177
# Migrated from custom_ops.yaml #
172178
# ------------------------------------ #
@@ -954,3 +960,22 @@ def where_Scalar_meta(
954960
other: float,
955961
) -> torch.Tensor:
956962
return condition.new_empty(condition.size(), dtype=torch.float32)
963+
964+
@register_fake("cadence::rope")
965+
def rope_meta(
966+
input: torch.Tensor,
967+
sin_tensor: torch.Tensor,
968+
cos_tensor: torch.Tensor,
969+
pos: Optional[torch.Tensor],
970+
) -> torch.Tensor:
971+
input_shape = list(input.shape)
972+
assert len(input_shape) in (4, 5) and input_shape[0] == 1, f"input shape {input_shape} must be (1, seq, h, hd) or (1, seq, h, hd / 2, 2)"
973+
seq = input_shape[1]
974+
h = input_shape[2]
975+
hd = np.prod(input_shape) / (seq * h)
976+
sin_shape = list(sin_tensor.shape)
977+
cos_shape = list(cos_tensor.shape)
978+
assert sin_shape == cos_shape, f"{sin_shape=} must be same as {cos_shape}"
979+
assert len(sin_shape) == 2 and sin_shape[-1] == hd//2, f"{sin_shape=} must be [seq, hd/2]"
980+
assert pos is None or len(pos.shape) == 1 and pos.shape[0] == seq, f"{pos.shape} must be [{seq}]"
981+
return input.new_empty(input.shape, dtype=input.dtype)

0 commit comments

Comments
 (0)