From c95f97e9bbf5ff3871b5b7ca6bf65e1d97ed60b7 Mon Sep 17 00:00:00 2001 From: Lunwen He Date: Tue, 26 Nov 2024 17:13:08 -0800 Subject: [PATCH 1/2] move rope related logic together Pull Request resolved: https://github.com/pytorch/executorch/pull/6560 Right now, rope related code scatters around a few different places in `llama_transformer`. It makes it hard to make changes to rope related things. This PR moves all rope related logic into its own module. ghstack-source-id: 255543205 Differential Revision: [D65173598](https://our.internmc.facebook.com/intern/diff/D65173598/) --- examples/models/llama/llama_transformer.py | 139 +++++++++++------- .../llama/source_transformation/rope.py | 28 ++-- 2 files changed, 101 insertions(+), 66 deletions(-) diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py index 3f8b8dd6547..10d660d37a6 100644 --- a/examples/models/llama/llama_transformer.py +++ b/examples/models/llama/llama_transformer.py @@ -147,6 +147,81 @@ def __post_init__(self): self.head_dim = self.dim // self.n_heads +class Rope(torch.nn.Module): + def __init__(self, params: ModelArgs): + super().__init__() + self.params = params + if self.params.use_hf_rope: + self.precompute_freqs_cis = hf_precompute_freqs_cis + else: + self.precompute_freqs_cis = partial( + precompute_freqs_cis, use_scaled=self.params.use_scaled_rope + ) + freqs_cos, freqs_sin = self.precompute_freqs_cis( + self.params.head_dim, + ( + self.params.max_seq_len # Normal llama2. + if self.params.ffn_dim_multiplier is None + else self.params.max_seq_len * 2 # Sharded checkpoint. + ), + self.params.rope_freq_base, + ) + self.register_buffer("freqs_cos", freqs_cos, persistent=False) + self.register_buffer("freqs_sin", freqs_sin, persistent=False) + if self.params.use_hf_rope: + self.apply_rotary_emb = hf_apply_rotary_emb + else: + self.apply_rotary_emb = RotaryEmbedding() + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + ): + return self.apply_rotary_emb(q, k, freqs_cos, freqs_sin) + + def get_freqs(self, input_pos: Optional[torch.Tensor], seq_len: int): + """ + Get the precomputed frequencies for the given input position and sequence length. + + Args: + input_pos (torch.Tensor): The input position tensor. + seq_len (int): The sequence length. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The precomputed frequencies for the given input position and sequence length. + """ + if self.params.use_kv_cache: + assert ( + input_pos is not None + ), "input_pos must be provided when use_kv_cache is True" + + if self.params.enable_dynamic_shape: + # when KV cache is used, seqlen is most likely 1. We want to slice from the start_pos. + input_pos_item = input_pos[-1].item() + torch._check_is_size(input_pos_item) + torch._check(input_pos_item < self.params.max_seq_len) + # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor + freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seq_len) + # pyre-ignore: Incompatible parameter type [6] + freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seq_len) + else: + # When not using dynamic shape, use of the .item results in + # symints, due to querying the data from tensor. + # this path avoids that for mps backend, although probably mps backend + # can support dynamic shape? 
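+                # Plain tensor indexing here; assuming input_pos holds one position
+                # per token, this yields freqs of shape (seq_len, head_dim // 2),
+                # matching the narrow() path above.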
+ freqs_cos = self.freqs_cos[input_pos] + freqs_sin = self.freqs_sin[input_pos] + + else: + assert input_pos is None, "input_pos is unused when use_kv_cache is False" + freqs_cos = self.freqs_cos[:seq_len] + freqs_sin = self.freqs_sin[:seq_len] + return freqs_cos, freqs_sin + + class KVCache(nn.Module): def __init__( self, @@ -266,7 +341,7 @@ def forward( class Attention(nn.Module): - def __init__(self, args: ModelArgs, layer_id: int): + def __init__(self, args: ModelArgs, layer_id: int, rope: Rope): super().__init__() self.use_kv_cache = args.use_kv_cache self.n_heads = args.n_heads @@ -287,6 +362,8 @@ def __init__(self, args: ModelArgs, layer_id: int): self.layer_id = layer_id + self.rope = rope + causal_mask = torch.tril( torch.ones( self.max_seq_len, @@ -303,7 +380,7 @@ def __init__(self, args: ModelArgs, layer_id: int): args.max_seq_len, self.n_kv_heads, self.head_dim, - not args.use_sdpa_with_kv_cache_op, # if we are using the custom op dont transpose the cache. Expect untransposed q k v + not args.use_sdpa_with_kv_cache_op, # if we are using the custom op don't transpose the cache. Expect untransposed q k v args.enable_dynamic_shape, ) self.SDPA = SDPA( @@ -314,10 +391,6 @@ def __init__(self, args: ModelArgs, layer_id: int): max_seq_len=self.max_seq_len, enable_dynamic_shape=args.enable_dynamic_shape, ) - if args.use_hf_rope: - self.apply_rotary_emb = hf_apply_rotary_emb - else: - self.apply_rotary_emb = RotaryEmbedding() def forward( self, @@ -336,7 +409,7 @@ def forward( v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) # RoPE relative positional embeddings - q, k = self.apply_rotary_emb(q, k, freqs_cos, freqs_sin) + q, k = self.rope.forward(q, k, freqs_cos, freqs_sin) if self.use_kv_cache: assert input_pos is not None @@ -424,13 +497,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class TransformerBlock(nn.Module): - def __init__(self, layer_id: int, args: ModelArgs): + def __init__(self, layer_id: int, args: ModelArgs, rope: Rope): super().__init__() self.use_kv_cache = args.use_kv_cache self.n_heads = args.n_heads self.dim = args.dim self.head_dim = args.head_dim - self.attention = Attention(args, layer_id) + self.attention = Attention(args, layer_id, rope) if args.moe: self.block_sparse_moe = MOEFeedForward(args) else: @@ -459,9 +532,10 @@ def __init__(self, params: ModelArgs): self.n_layers = params.n_layers self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim) + self.rope = Rope(params) self.layers = torch.nn.ModuleList() for layer_id in range(params.n_layers): - self.layers.append(TransformerBlock(layer_id, params)) + self.layers.append(TransformerBlock(layer_id, params, self.rope)) self.norm = RMSNorm(params.dim, eps=params.norm_eps) self.output = nn.Linear(params.dim, params.vocab_size, bias=False) self.use_kv_cache = params.use_kv_cache @@ -469,23 +543,6 @@ def __init__(self, params: ModelArgs): self.max_seq_len = params.max_seq_len self.input_prune_map = params.input_prune_map self.output_prune_map = params.output_prune_map - if params.use_hf_rope: - self.precompute_freqs_cis = hf_precompute_freqs_cis - else: - self.precompute_freqs_cis = partial( - precompute_freqs_cis, use_scaled=params.use_scaled_rope - ) - freqs_cos, freqs_sin = self.precompute_freqs_cis( - params.head_dim, - ( - params.max_seq_len # Normal llama2. - if params.ffn_dim_multiplier is None - else params.max_seq_len * 2 # Sharded checkpoint. 
- ), - params.rope_freq_base, - ) - self.register_buffer("freqs_cos", freqs_cos, persistent=False) - self.register_buffer("freqs_sin", freqs_sin, persistent=False) def forward( self, @@ -502,33 +559,7 @@ def forward( if tokens is not None and h is None: h = self.tok_embeddings(tokens) seqlen = h.shape[1] - - if self.use_kv_cache: - assert ( - input_pos is not None - ), "input_pos must be provided when use_kv_cache is True" - - if self.params.enable_dynamic_shape: - # when KV cache is used, seqlen is most likely 1. We want to slice from the start_pos. - input_pos_item = input_pos[-1].item() - torch._check_is_size(input_pos_item) - torch._check(input_pos_item < self.params.max_seq_len) - # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor - freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seqlen) - # pyre-ignore: Incompatible parameter type [6] - freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seqlen) - else: - # When not using dynamic shape, use of the .item results in - # symints, due to querying the data from tensor. - # this path avoids that for mps backend, although probably mps backend - # can support dynamic shape? - freqs_cos = self.freqs_cos[input_pos] - freqs_sin = self.freqs_sin[input_pos] - - else: - assert input_pos is None, "input_pos is unused when use_kv_cache is False" - freqs_cos = self.freqs_cos[:seqlen] - freqs_sin = self.freqs_sin[:seqlen] + freqs_cos, freqs_sin = self.rope.get_freqs(input_pos, seqlen) for layer in self.layers: h = layer( diff --git a/examples/models/llama/source_transformation/rope.py b/examples/models/llama/source_transformation/rope.py index a2a2264b247..79fb2399669 100644 --- a/examples/models/llama/source_transformation/rope.py +++ b/examples/models/llama/source_transformation/rope.py @@ -13,23 +13,27 @@ def materialze_broadcast_of_rope_freq_cis( module: torch.nn.Module, ): assert isinstance(module, Transformer) - assert module.freqs_cos.dim() == 2 - dim0 = module.freqs_cos.size(0) - dim1 = module.freqs_cos.size(1) + assert module.rope.freqs_cos.dim() == 2 + dim0 = module.rope.freqs_cos.size(0) + dim1 = module.rope.freqs_cos.size(1) module_attention = module.layers[0].attention assert ( module_attention.n_local_kv_heads == module_attention.n_local_heads ), f"For rope freqs to be materialized for broadcast, q, k, v num heads must match. For q got {module_attention.n_kv_heads} for k got {module_attention.n_local_heads} and v got {module_attention.n_local_kv_heads}" num_heads = module_attention.n_local_heads - module.freqs_cos = module.freqs_cos.view(dim0, 1, dim1) - module.freqs_cos = module.freqs_cos.expand(dim0, num_heads, dim1).contiguous() - assert module.freqs_sin.dim() == 2 - assert dim0 == module.freqs_sin.size( + module.rope.freqs_cos = module.rope.freqs_cos.view(dim0, 1, dim1) + module.rope.freqs_cos = module.rope.freqs_cos.expand( + dim0, num_heads, dim1 + ).contiguous() + assert module.rope.freqs_sin.dim() == 2 + assert dim0 == module.rope.freqs_sin.size( 0 - ), f"sin and cos freq table sizes must match. Mismatch found at dim 0: {dim0} vs {module.freqs_sin.size(0)}" - assert dim1 == module.freqs_sin.size( + ), f"sin and cos freq table sizes must match. Mismatch found at dim 0: {dim0} vs {module.rope.freqs_sin.size(0)}" + assert dim1 == module.rope.freqs_sin.size( 1 - ), f"sin and cos freq table sizes must match. 
Mismatch found at dim 1: {dim1} vs {module.freqs_sin.size(1)}" - module.freqs_sin = module.freqs_sin.view(dim0, 1, dim1) - module.freqs_sin = module.freqs_sin.expand(dim0, num_heads, dim1).contiguous() + ), f"sin and cos freq table sizes must match. Mismatch found at dim 1: {dim1} vs {module.rope.freqs_sin.size(1)}" + module.rope.freqs_sin = module.rope.freqs_sin.view(dim0, 1, dim1) + module.rope.freqs_sin = module.rope.freqs_sin.expand( + dim0, num_heads, dim1 + ).contiguous() return module From 923e31e8a5f75b693bd2d8732070a6fe2cad7392 Mon Sep 17 00:00:00 2001 From: Lunwen He Date: Tue, 26 Nov 2024 17:13:09 -0800 Subject: [PATCH 2/2] implement position encoding for shifted tokens Pull Request resolved: https://github.com/pytorch/executorch/pull/6646 In AttentionSink, it uses tokens' positions in the KVCache instead of the actual text. When tokens get shifted in KVCache, it needs to update q and k's position embedding. In the original [implementation](https://github.com/mit-han-lab/streaming-llm) of AttentionSink with Rope, it caches the original q and k in KVCache and apply position embedding during inference. This PR adds `RopeWithAttentionSink`. It assumes that q and k are already encoded with their original position. When we shift tokens, we reapply the position delta. This has two benefits: - minimize our code since our existing `llama_transformer` applies rope embedding before doing KVCache update - avoid performance regression when tokens are not shifted because we don't need to reapply position encoding in KVCache for them ghstack-source-id: 255579838 Differential Revision: [D65366440](https://our.internmc.facebook.com/intern/diff/D65366440/) --- examples/models/llama/TARGETS | 14 ++++ examples/models/llama/rope.py | 41 +++++++++++ .../source_transformation/attention_sink.py | 62 ++++++++++++++++ .../test_attention_sink.py | 73 +++++++++++++++++++ 4 files changed, 190 insertions(+) create mode 100644 examples/models/llama/source_transformation/attention_sink.py create mode 100644 examples/models/llama/source_transformation/test_attention_sink.py diff --git a/examples/models/llama/TARGETS b/examples/models/llama/TARGETS index cf387bfab24..284520d4d5e 100644 --- a/examples/models/llama/TARGETS +++ b/examples/models/llama/TARGETS @@ -93,6 +93,7 @@ runtime.python_library( "source_transformation/sdpa.py", "source_transformation/spin_quant.py", "source_transformation/vulkan_rope.py", + "source_transformation/attention_sink.py", ], _is_external_target = True, base_module = "executorch.examples.models.llama", @@ -213,3 +214,16 @@ runtime.python_test( "//executorch/examples/models/llama:llama_transformer", ], ) + +runtime.python_test( + name = "attention_sink_test", + srcs = [ + "source_transformation/test_attention_sink.py", + ], + supports_static_listing = False, + deps = [ + "fbsource//third-party/pypi/parameterized:parameterized", + "//caffe2:torch", + ":export_library", + ], +) diff --git a/examples/models/llama/rope.py b/examples/models/llama/rope.py index 0383c798988..1445787f5eb 100644 --- a/examples/models/llama/rope.py +++ b/examples/models/llama/rope.py @@ -92,6 +92,22 @@ def apply_rotary_emb( return xq_out.type_as(xq), xk_out.type_as(xk) +def apply_rotary_emb_to_k( + xk: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor +) -> torch.Tensor: + xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1) + + freqs_cos = reshape_for_broadcast(freqs_cos, xk_r) + freqs_sin = reshape_for_broadcast(freqs_sin, xk_r) + + xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin + 
xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos + + xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3) + + return xk_out.type_as(xk) + + class RotaryEmbedding(torch.nn.Module): def __init__(self): super().__init__() @@ -160,3 +176,28 @@ def hf_apply_rotary_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed + + +def hf_apply_rotary_emb_to_k(k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the key tensors. + + Args: + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of k. Similarly, if k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `torch.Tensor` the key tensor rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + k_embed = (k * cos) + (rotate_half(k) * sin) + return k_embed diff --git a/examples/models/llama/source_transformation/attention_sink.py b/examples/models/llama/source_transformation/attention_sink.py new file mode 100644 index 00000000000..94f5b47871f --- /dev/null +++ b/examples/models/llama/source_transformation/attention_sink.py @@ -0,0 +1,62 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Components for supporting Attention Sink. See +# https://arxiv.org/abs/2309.17453 for more details about Attention Sink. + +import torch + +from executorch.examples.models.llama.llama_transformer import ModelArgs, Rope +from executorch.examples.models.llama.rope import ( + apply_rotary_emb_to_k, + hf_apply_rotary_emb_to_k, +) + + +class RopeWithAttentionSink(Rope): + """ + Rope that helps adjust position encoding when tokens are shifted in KVCache. + For AttentionSink, when tokens are shifted in KVCache, we need to use positions + in KVCache instead of positions in the actual text. + """ + + def __init__(self, params: ModelArgs): + super().__init__(params) + if self.params.use_hf_rope: + self.apply_rotary_emb_to_k = hf_apply_rotary_emb_to_k + else: + self.apply_rotary_emb_to_k = apply_rotary_emb_to_k + + def rerotate_k( + self, + k: torch.Tensor, + original_position: int, + new_position: int, + ): + """ + Rerotate k from original_position to new_position. 
This is done by rerotating + k with (new_position * theta - original_position * theta) with the following matrix: + (cos(delta), -sin(delta) + sin(delta), cos(delta)) + where delta = new_position * theta - original_position * theta + + The shape of k is (batch_size, seq_len, n_local_heads, head_dim) + + Based on https://github.com/huggingface/transformers/blame/main/src/transformers/cache_utils.py#L961 + """ + seq_len = k.shape[1] + original_freqs_cos = self.freqs_cos.narrow(0, original_position, seq_len) + original_freqs_sin = self.freqs_sin.narrow(0, original_position, seq_len) + new_freqs_cos = self.freqs_cos.narrow(0, new_position, seq_len) + new_freqs_sin = self.freqs_sin.narrow(0, new_position, seq_len) + rerotation_cos = ( + new_freqs_cos * original_freqs_cos + new_freqs_sin * original_freqs_sin + ) + rerotation_sin = ( + new_freqs_sin * original_freqs_cos - new_freqs_cos * original_freqs_sin + ) + + return self.apply_rotary_emb_to_k(k, rerotation_cos, rerotation_sin) diff --git a/examples/models/llama/source_transformation/test_attention_sink.py b/examples/models/llama/source_transformation/test_attention_sink.py new file mode 100644 index 00000000000..adb3bff3a58 --- /dev/null +++ b/examples/models/llama/source_transformation/test_attention_sink.py @@ -0,0 +1,73 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +from executorch.examples.models.llama.llama_transformer import ModelArgs + +from executorch.examples.models.llama.source_transformation.attention_sink import ( + RopeWithAttentionSink, +) +from parameterized import parameterized + + +class RopeWithAttentionSinkTest(unittest.TestCase): + + def setUp(self): + torch.manual_seed(42) + self.params = ModelArgs(use_kv_cache=True, enable_dynamic_shape=True) + self.rope_with_attention_sink = RopeWithAttentionSink(params=self.params) + + @parameterized.expand( + [ + [128, 127], # Rotate left + [128, 128], # No rotation + [128, 129], # Rotate right + ] + ) + def test_rotate(self, original_position, new_position): + seq_len = 32 + + q = torch.rand( + 1, seq_len, self.params.n_heads, self.params.head_dim, dtype=torch.float32 + ) + k = torch.rand( + 1, + seq_len, + self.params.n_heads, + self.params.head_dim, + dtype=torch.float32, + ) + freqs_cos, freqs_sin = self.rope_with_attention_sink.get_freqs( + input_pos=torch.tensor([original_position], dtype=torch.int32), + seq_len=seq_len, + ) + _, pre_rotated_k = self.rope_with_attention_sink.forward( + q=q, + k=k, + freqs_cos=freqs_cos, + freqs_sin=freqs_sin, + ) + + rerotated_k = self.rope_with_attention_sink.rerotate_k( + k=pre_rotated_k, + original_position=original_position, + new_position=new_position, + ) + + freqs_cos, freqs_sin = self.rope_with_attention_sink.get_freqs( + input_pos=torch.tensor([new_position], dtype=torch.int32), + seq_len=seq_len, + ) + _, expected_k = self.rope_with_attention_sink.forward( + q=q, + k=k, + freqs_cos=freqs_cos, + freqs_sin=freqs_sin, + ) + + torch.testing.assert_close(rerotated_k, expected_k)
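
A minimal usage sketch of `RopeWithAttentionSink.rerotate_k`, assuming a cached key that was rotated at one position gets moved to a new slot after older cache entries are evicted. The concrete positions, the single-entry shape, and the eviction step itself are illustrative only and not part of this patch; the APIs used (`ModelArgs`, `RopeWithAttentionSink`, `rerotate_k`) come from the code above.

    import torch

    from executorch.examples.models.llama.llama_transformer import ModelArgs
    from executorch.examples.models.llama.source_transformation.attention_sink import (
        RopeWithAttentionSink,
    )

    params = ModelArgs(use_kv_cache=True, enable_dynamic_shape=True)
    rope = RopeWithAttentionSink(params=params)

    # One cached key entry, already rotated when it lived at position 130.
    # Shape: (batch_size, seq_len, n_local_heads, head_dim); values are illustrative.
    k = torch.rand(1, 1, params.n_heads, params.head_dim, dtype=torch.float32)

    # After a hypothetical eviction the entry now lives at position 66, so reapply
    # the position delta (66 - 130) to the already-rotated key.
    k_shifted = rope.rerotate_k(k=k, original_position=130, new_position=66)

Because the rotation is linear, reapplying the delta to an already-rotated key is equivalent to rotating the original key at the new position, which is what the parameterized test above asserts.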