|
@@ -7,6 +7,7 @@
 # pyre-strict

 from math import prod
+import numpy as np
 from typing import Optional, Tuple

 import torch
@@ -167,6 +168,13 @@
     "where_Scalar.out(Tensor condition, float self, float other, *, Tensor(a!) out) -> Tensor(a!)"
 )

+lib.define(
+    "rope(Tensor input, Tensor sin_tensor, Tensor cos_tensor, Tensor? pos) -> (Tensor out)"
+)
+lib.define(
+    "rope.out(Tensor input, Tensor sin_tensor, Tensor cos_tensor, Tensor? pos, *, Tensor(a!) out) -> Tensor(a!)"
+)
+
 # ------------------------------------ #
 #   Migrated from custom_ops.yaml      #
 # ------------------------------------ #
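
For reference, a minimal sketch of how the two new variants could be exercised. The tensor shapes are illustrative only (they follow the fake kernel added later in this diff), and the eager calls assume a concrete kernel for `cadence::rope` is registered elsewhere in the extension; `lib.define` alone only declares the schema.

```python
import torch

# Hypothetical shapes for illustration: batch 1, seq 128, 8 heads, head dim 64.
x = torch.randn(1, 128, 8, 64)
sin = torch.randn(128, 32)  # (seq, hd // 2)
cos = torch.randn(128, 32)

# Functional variant: allocates and returns the output.
y = torch.ops.cadence.rope(x, sin, cos, None)

# Out variant: writes into a preallocated buffer of the same shape.
out = torch.empty_like(x)
torch.ops.cadence.rope.out(x, sin, cos, None, out=out)
```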
@@ -954,3 +960,22 @@ def where_Scalar_meta( |
     other: float,
 ) -> torch.Tensor:
     return condition.new_empty(condition.size(), dtype=torch.float32)
+
+@register_fake("cadence::rope")
+def rope_meta(
+    input: torch.Tensor,
+    sin_tensor: torch.Tensor,
+    cos_tensor: torch.Tensor,
+    pos: Optional[torch.Tensor],
+) -> torch.Tensor:
+    input_shape = list(input.shape)
+    assert len(input_shape) in (4, 5) and input_shape[0] == 1, f"input shape {input_shape} must be (1, seq, h, hd) or (1, seq, h, hd / 2, 2)"
+    seq = input_shape[1]
+    h = input_shape[2]
+    hd = np.prod(input_shape) // (seq * h)
+    sin_shape = list(sin_tensor.shape)
+    cos_shape = list(cos_tensor.shape)
+    assert sin_shape == cos_shape, f"{sin_shape=} must match {cos_shape=}"
+    assert len(sin_shape) == 2 and sin_shape[-1] == hd // 2, f"{sin_shape=} must be [seq, hd/2]"
+    assert pos is None or (len(pos.shape) == 1 and pos.shape[0] == seq), f"{pos.shape=} must be [{seq}]"
+    return input.new_empty(input.shape, dtype=input.dtype)
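
A quick way to sanity-check the fake kernel is to call the op under `FakeTensorMode`, which dispatches to `rope_meta` and propagates shapes without running a real kernel. A minimal sketch, assuming the module above has been imported so the registrations are live; the shapes are made up for illustration:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    x = torch.empty(1, 128, 8, 64)   # (1, seq, h, hd)
    sin = torch.empty(128, 32)       # (seq, hd // 2)
    cos = torch.empty(128, 32)
    pos = torch.empty(128, dtype=torch.long)  # optional per-token positions

    out = torch.ops.cadence.rope(x, sin, cos, pos)
    assert out.shape == x.shape  # rope is shape- and dtype-preserving
```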