14 changes: 14 additions & 0 deletions examples/models/llama/TARGETS
@@ -93,6 +93,7 @@ runtime.python_library(
"source_transformation/sdpa.py",
"source_transformation/spin_quant.py",
"source_transformation/vulkan_rope.py",
"source_transformation/attention_sink.py",
],
_is_external_target = True,
base_module = "executorch.examples.models.llama",
@@ -213,3 +214,16 @@ runtime.python_test(
"//executorch/examples/models/llama:llama_transformer",
],
)

runtime.python_test(
name = "attention_sink_test",
srcs = [
"source_transformation/test_attention_sink.py",
],
supports_static_listing = False,
deps = [
"fbsource//third-party/pypi/parameterized:parameterized",
"//caffe2:torch",
":export_library",
],
)
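With this target in place, the test presumably runs the same way as the other llama python_test rules here, e.g. something like buck2 test //executorch/examples/models/llama:attention_sink_test (the exact cell prefix depends on the checkout), or directly via pytest on source_transformation/test_attention_sink.py; both invocations are assumptions about the local setup, not part of this change.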
41 changes: 41 additions & 0 deletions examples/models/llama/rope.py
@@ -92,6 +92,22 @@ def apply_rotary_emb(
return xq_out.type_as(xq), xk_out.type_as(xk)


def apply_rotary_emb_to_k(
xk: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
) -> torch.Tensor:
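"""Apply the rotary position embedding to the key tensor only (same math as apply_rotary_emb, minus the query)."""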
xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)

freqs_cos = reshape_for_broadcast(freqs_cos, xk_r)
freqs_sin = reshape_for_broadcast(freqs_sin, xk_r)

xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos

xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)

return xk_out.type_as(xk)


class RotaryEmbedding(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -160,3 +176,28 @@ def hf_apply_rotary_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed


def hf_apply_rotary_emb_to_k(k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the key tensors.

Args:
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcast to the dimensions of k. For example,
cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. If k has
the shape [batch_size, heads, seq_len, head_dim], setting unsqueeze_dim=1 makes cos[position_ids]
and sin[position_ids] broadcastable to the shape of k. Similarly, if k has the shape
[batch_size, seq_len, heads, head_dim], set unsqueeze_dim=2.
Returns:
`torch.Tensor`: The key tensor rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
k_embed = (k * cos) + (rotate_half(k) * sin)
return k_embed
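A minimal usage sketch (not part of the diff) tying the new key-only helper back to the existing apply_rotary_emb: rotating the keys alone should match the key half of the paired rotation. The shapes and the hand-built frequency tensors of shape (seq_len, head_dim // 2) are illustrative assumptions.

import torch
from executorch.examples.models.llama.rope import apply_rotary_emb, apply_rotary_emb_to_k

bsz, seq_len, n_heads, head_dim = 1, 8, 4, 16
q = torch.rand(bsz, seq_len, n_heads, head_dim)
k = torch.rand(bsz, seq_len, n_heads, head_dim)

# Illustrative llama2.c-style frequencies for positions [0, seq_len), shape (seq_len, head_dim // 2).
inv_freq = 10000.0 ** (-torch.arange(0, head_dim, 2).float() / head_dim)
angles = torch.arange(seq_len).float().unsqueeze(1) * inv_freq.unsqueeze(0)
freqs_cos, freqs_sin = torch.cos(angles), torch.sin(angles)

_, k_ref = apply_rotary_emb(q, k, freqs_cos, freqs_sin)  # rotates q and k together
k_only = apply_rotary_emb_to_k(k, freqs_cos, freqs_sin)  # rotates only k
torch.testing.assert_close(k_only, k_ref)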
62 changes: 62 additions & 0 deletions examples/models/llama/source_transformation/attention_sink.py
@@ -0,0 +1,62 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Components for supporting Attention Sink. See
# https://arxiv.org/abs/2309.17453 for more details about Attention Sink.

import torch

from executorch.examples.models.llama.llama_transformer import ModelArgs, Rope
from executorch.examples.models.llama.rope import (
apply_rotary_emb_to_k,
hf_apply_rotary_emb_to_k,
)


class RopeWithAttentionSink(Rope):
"""
Rope that adjusts the position encoding when tokens are shifted in the KV cache.
With attention sink, shifted tokens must be encoded with their position in the
KV cache rather than their position in the original text.
"""

def __init__(self, params: ModelArgs):
super().__init__(params)
if self.params.use_hf_rope:
self.apply_rotary_emb_to_k = hf_apply_rotary_emb_to_k
else:
self.apply_rotary_emb_to_k = apply_rotary_emb_to_k

def rerotate_k(
self,
k: torch.Tensor,
original_position: int,
new_position: int,
) -> torch.Tensor:
"""
Rerotate k from original_position to new_position. This is done by rotating k
by delta = new_position * theta - original_position * theta, i.e. by applying the matrix
(cos(delta), -sin(delta)
 sin(delta),  cos(delta))

The shape of k is (batch_size, seq_len, n_local_heads, head_dim)

Based on https://github.com/huggingface/transformers/blame/main/src/transformers/cache_utils.py#L961
"""
seq_len = k.shape[1]
original_freqs_cos = self.freqs_cos.narrow(0, original_position, seq_len)
original_freqs_sin = self.freqs_sin.narrow(0, original_position, seq_len)
new_freqs_cos = self.freqs_cos.narrow(0, new_position, seq_len)
new_freqs_sin = self.freqs_sin.narrow(0, new_position, seq_len)
rerotation_cos = (
new_freqs_cos * original_freqs_cos + new_freqs_sin * original_freqs_sin
)
rerotation_sin = (
new_freqs_sin * original_freqs_cos - new_freqs_cos * original_freqs_sin
)

return self.apply_rotary_emb_to_k(k, rerotation_cos, rerotation_sin)
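As a quick sanity check of the trig behind rerotate_k (not part of the diff, with illustrative positions and frequencies): the combined cos/sin terms built from the original and new frequencies equal cos(delta) and sin(delta) by the angle-difference identities, so re-applying RoPE with them shifts k by exactly delta.

import torch

original_position, new_position = 10, 12  # illustrative positions
inv_freq = 10000.0 ** (-torch.arange(0, 64, 2).float() / 64)  # illustrative per-dim frequencies
orig_cos, orig_sin = torch.cos(original_position * inv_freq), torch.sin(original_position * inv_freq)
new_cos, new_sin = torch.cos(new_position * inv_freq), torch.sin(new_position * inv_freq)

# Same combination as in rerotate_k above.
rerotation_cos = new_cos * orig_cos + new_sin * orig_sin
rerotation_sin = new_sin * orig_cos - new_cos * orig_sin

delta = (new_position - original_position) * inv_freq
torch.testing.assert_close(rerotation_cos, torch.cos(delta))
torch.testing.assert_close(rerotation_sin, torch.sin(delta))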
73 changes: 73 additions & 0 deletions examples/models/llama/source_transformation/test_attention_sink.py
@@ -0,0 +1,73 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.examples.models.llama.llama_transformer import ModelArgs

from executorch.examples.models.llama.source_transformation.attention_sink import (
RopeWithAttentionSink,
)
from parameterized import parameterized


class RopeWithAttentionSinkTest(unittest.TestCase):

def setUp(self):
torch.manual_seed(42)
self.params = ModelArgs(use_kv_cache=True, enable_dynamic_shape=True)
self.rope_with_attention_sink = RopeWithAttentionSink(params=self.params)

@parameterized.expand(
[
[128, 127], # Rotate left
[128, 128], # No rotation
[128, 129], # Rotate right
]
)
def test_rotate(self, original_position, new_position):
seq_len = 32

q = torch.rand(
1, seq_len, self.params.n_heads, self.params.head_dim, dtype=torch.float32
)
k = torch.rand(
1,
seq_len,
self.params.n_heads,
self.params.head_dim,
dtype=torch.float32,
)
freqs_cos, freqs_sin = self.rope_with_attention_sink.get_freqs(
input_pos=torch.tensor([original_position], dtype=torch.int32),
seq_len=seq_len,
)
_, pre_rotated_k = self.rope_with_attention_sink.forward(
q=q,
k=k,
freqs_cos=freqs_cos,
freqs_sin=freqs_sin,
)

rerotated_k = self.rope_with_attention_sink.rerotate_k(
k=pre_rotated_k,
original_position=original_position,
new_position=new_position,
)

freqs_cos, freqs_sin = self.rope_with_attention_sink.get_freqs(
input_pos=torch.tensor([new_position], dtype=torch.int32),
seq_len=seq_len,
)
_, expected_k = self.rope_with_attention_sink.forward(
q=q,
k=k,
freqs_cos=freqs_cos,
freqs_sin=freqs_sin,
)

torch.testing.assert_close(rerotated_k, expected_k)