
Commit 4d37069

Adds DSP op for RoPE (#11264)

nitish2112 authored and facebook-github-bot committed
Summary: Pull Request resolved: #11264
Reviewed By: hsharma35
Differential Revision: D75605145
1 parent 851f5fc commit 4d37069

File tree: 1 file changed, +32 -0 lines


backends/cadence/aot/ops_registrations.py

Lines changed: 32 additions & 0 deletions
@@ -167,6 +167,13 @@
     "where_Scalar.out(Tensor condition, float self, float other, *, Tensor(a!) out) -> Tensor(a!)"
 )
 
+lib.define(
+    "rope(Tensor input, Tensor sin_tensor, Tensor cos_tensor, Tensor? pos) -> (Tensor out)"
+)
+lib.define(
+    "rope.out(Tensor input, Tensor sin_tensor, Tensor cos_tensor, Tensor? pos, *, Tensor(a!) out) -> Tensor(a!)"
+)
+
 # ------------------------------------ #
 # Migrated from custom_ops.yaml        #
 # ------------------------------------ #
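For reference, the schema above takes an activation tensor, precomputed sine/cosine tables, and an optional position tensor. Below is a minimal eager-mode sketch of a rotary position embedding with matching shapes; it is illustrative only, since the diff does not show the Cadence kernel itself. The half-split pairing, the `rope_reference` name, and the use of `pos` as an index into the tables are assumptions, and only the 4-D (1, seq, h, hd) layout is handled.

# Illustrative sketch only: a plain PyTorch rotary embedding with the shapes
# implied by the schema. The Cadence kernel's exact element pairing may differ;
# the half-split convention and the `pos` indexing here are assumptions.
from typing import Optional

import torch


def rope_reference(
    input: torch.Tensor,  # (1, seq, h, hd)
    sin_tensor: torch.Tensor,  # (seq, hd // 2)
    cos_tensor: torch.Tensor,  # (seq, hd // 2)
    pos: Optional[torch.Tensor] = None,  # (seq,) indices into the sin/cos tables
) -> torch.Tensor:
    if pos is not None:
        sin_tensor = sin_tensor[pos]
        cos_tensor = cos_tensor[pos]
    sin = sin_tensor.unsqueeze(1)  # (seq, 1, hd // 2), broadcasts over heads
    cos = cos_tensor.unsqueeze(1)
    x1, x2 = input.chunk(2, dim=-1)  # split the head dimension in half
    return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)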
@@ -954,3 +961,28 @@ def where_Scalar_meta(
     other: float,
 ) -> torch.Tensor:
     return condition.new_empty(condition.size(), dtype=torch.float32)
+
+@register_fake("cadence::rope")
+def rope_meta(
+    input: torch.Tensor,
+    sin_tensor: torch.Tensor,
+    cos_tensor: torch.Tensor,
+    pos: Optional[torch.Tensor],
+) -> torch.Tensor:
+    input_shape = list(input.shape)
+    assert (
+        len(input_shape) in (4, 5) and input_shape[0] == 1
+    ), f"input shape {input_shape} must be (1, seq, h, hd) or (1, seq, h, hd / 2, 2)"
+    seq = input_shape[1]
+    h = input_shape[2]
+    hd = prod(input_shape) / (seq * h)
+    sin_shape = list(sin_tensor.shape)
+    cos_shape = list(cos_tensor.shape)
+    assert sin_shape == cos_shape, f"{sin_shape=} must be same as {cos_shape}"
+    assert (
+        len(sin_shape) == 2 and sin_shape[-1] == hd//2
+    ), f"{sin_shape=} must be [seq, hd/2]"
+    assert (
+        pos is None or len(pos.shape) == 1 and pos.shape[0] == seq
+    ), f"{pos.shape} must be [{seq}]"
+    return input.new_empty(input.shape, dtype=input.dtype)
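The fake kernel performs only shape checking and propagation: it validates the input, sin/cos, and pos shapes and returns an empty tensor with the input's shape and dtype. A small sketch of exercising it directly with meta-device tensors is shown below; the sizes are arbitrary examples, not from the PR, and it assumes rope_meta (and its module-level imports such as prod and Optional) is importable from ops_registrations.

# Illustrative only: call the fake kernel directly with meta tensors to see the
# shape checks and the propagated output shape. Sizes are arbitrary examples.
import torch

seq, h, hd = 16, 8, 64
out = rope_meta(
    torch.empty(1, seq, h, hd, device="meta"),  # input: (1, seq, h, hd)
    torch.empty(seq, hd // 2, device="meta"),   # sin table: (seq, hd / 2)
    torch.empty(seq, hd // 2, device="meta"),   # cos table: (seq, hd / 2)
    None,                                       # pos is optional
)
assert list(out.shape) == [1, seq, h, hd]  # output mirrors the input shape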
