diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py index cdaca41569f..4a6edf03c0e 100644 --- a/backends/cadence/aot/ops_registrations.py +++ b/backends/cadence/aot/ops_registrations.py @@ -167,6 +167,13 @@ "where_Scalar.out(Tensor condition, float self, float other, *, Tensor(a!) out) -> Tensor(a!)" ) +lib.define( + "rope(Tensor input, Tensor sin_tensor, Tensor cos_tensor, Tensor? pos) -> (Tensor out)" +) +lib.define( + "rope.out(Tensor input, Tensor sin_tensor, Tensor cos_tensor, Tensor? pos, *, Tensor(a!) out) -> Tensor(a!)" +) + # ------------------------------------ # # Migrated from custom_ops.yaml # # ------------------------------------ # @@ -954,3 +961,29 @@ def where_Scalar_meta( other: float, ) -> torch.Tensor: return condition.new_empty(condition.size(), dtype=torch.float32) + + +@register_fake("cadence::rope") +def rope_meta( + input: torch.Tensor, + sin_tensor: torch.Tensor, + cos_tensor: torch.Tensor, + pos: Optional[torch.Tensor], +) -> torch.Tensor: + input_shape = list(input.shape) + assert ( + len(input_shape) in (4, 5) and input_shape[0] == 1 + ), f"input shape {input_shape} must be (1, seq, h, hd) or (1, seq, h, hd / 2, 2)" + seq = input_shape[1] + h = input_shape[2] + hd = prod(input_shape) / (seq * h) + sin_shape = list(sin_tensor.shape) + cos_shape = list(cos_tensor.shape) + assert sin_shape == cos_shape, f"{sin_shape=} must be same as {cos_shape}" + assert ( + len(sin_shape) == 2 and sin_shape[-1] == hd // 2 + ), f"{sin_shape=} must be [seq, hd/2]" + assert ( + pos is None or len(pos.shape) == 1 and pos.shape[0] == seq + ), f"{pos.shape} must be [{seq}]" + return input.new_empty(input.shape, dtype=input.dtype)