Merged
18 changes: 13 additions & 5 deletions examples/models/llama/attention.py
@@ -150,6 +150,16 @@ def forward(
return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)


def _create_causal_mask_for_ring_buffer(
cache_positions, window_size, start_pos, seq_len
):
pos_q = start_pos + torch.arange(seq_len, dtype=torch.long).view(-1, 1)
delta = pos_q - cache_positions
attn_mask = (cache_positions >= 0) & (delta >= 0) & (delta < window_size)
attn_mask = torch.where(attn_mask == True, 0, float("-inf")) # noqa E712
return attn_mask


class CacheUpdateStrategy(Enum):
RING_BUFFER = "RingBuffer"
INVALID = "Invalid"
@@ -283,12 +293,10 @@ def __init__(
self.is_ring_buffer = True

def create_causal_mask_for_ring_buffer(self, start_pos, seq_len):
    cache_positions = self.cache_positions_manager.cache_positions
    return _create_causal_mask_for_ring_buffer(
        cache_positions, self.window_size, start_pos, seq_len
    )

def update(
self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
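For illustration, a minimal standalone sketch of the new mask helper (it mirrors the logic of _create_causal_mask_for_ring_buffer rather than importing it, and the buffer contents below are made up):

import torch

def causal_mask_for_ring_buffer(cache_positions, window_size, start_pos, seq_len):
    # A query at absolute position p may attend to a cache slot holding absolute
    # position c iff the slot is filled (c >= 0) and 0 <= p - c < window_size.
    pos_q = start_pos + torch.arange(seq_len, dtype=torch.long).view(-1, 1)
    delta = pos_q - cache_positions
    attn_mask = (cache_positions >= 0) & (delta >= 0) & (delta < window_size)
    return torch.where(attn_mask, 0.0, float("-inf"))

# Toy ring buffer with 6 slots that has wrapped around: slots 0 and 1 now hold
# absolute positions 6 and 7 (a negative entry would mean the slot was never written).
cache_positions = torch.tensor([6, 7, 2, 3, 4, 5])
mask = causal_mask_for_ring_buffer(cache_positions, window_size=4, start_pos=8, seq_len=2)
print(mask)  # rows = query positions 8 and 9, columns = cache slots

With window_size=4, the query at position 8 can attend only to the slots holding positions 5, 6 and 7; everything older is masked to -inf.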
191 changes: 190 additions & 1 deletion examples/models/llama/source_transformation/custom_kv_cache.py
@@ -10,7 +10,12 @@

import torch
import torch.nn as nn
from executorch.examples.models.llama.attention import KVCache
from executorch.examples.models.llama.attention import (
_create_causal_mask_for_ring_buffer,
CachePositionsManager,
KVCache,
RingKVCache,
)

from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401

@@ -75,6 +80,7 @@ def __init__(
self.register_buffer(
"v_cache_zero_points", torch.ones(scale_shape, dtype=torch.int8)
)
self.cache_type = cache_type

def _quantize(self, value):
(
@@ -181,6 +187,7 @@ def update(self, input_pos, k_val, v_val, indices=None):
However, the storage is [B, S, H, D], so we incur a transpose in and a transpose out.
This will be removed by a subsequent post-export graph pass.
"""

k_val = k_val.transpose(1, 2)
v_val = v_val.transpose(1, 2)

@@ -346,3 +353,185 @@ def _replace_kv_cache_with_custom_kv_cache(module):
else:
_replace_kv_cache_with_custom_kv_cache(child)
return module


class QuantizedRingKVCache(QuantizedKVCache):
def __init__(
self,
max_batch_size,
max_context_length,
n_heads,
head_dim,
cache_type: QuantizedCacheType = QuantizedCacheType.AffineSymmetric,
use_custom_update_cache_op: bool = False,
):
# See attention.py for an explanation of why we use max_context_length * 2
super().__init__(
max_batch_size,
max_context_length * 2,
n_heads,
head_dim,
cache_type,
use_custom_update_cache_op,
)
self.cache_positions_manager = CachePositionsManager(self.max_context_length)
self.is_ring_buffer = True
self.window_size = max_context_length

def create_causal_mask_for_ring_buffer(self, start_pos, seq_len):
cache_positions = self.cache_positions_manager.cache_positions
return _create_causal_mask_for_ring_buffer(
cache_positions, self.window_size, start_pos, seq_len
)

def update(self, input_pos, k_val, v_val):
"""
k_val, v_val: [B, H, S, D]
return: [B, H, S, D]
However, the storage is [B, S, H, D], so we incur a transpose in and a transpose out.
This will be removed by a subsequent post-export graph pass.
"""
# Need to transpose for two reasons
# 1. kv cache is stored as [B, S, H, D]
# 2. If seq_len = k_val.size(2), we won't be able to optimize
# away the transpose at the output of the k, v projections
seq_len = k_val.transpose(1, 2).size(1)
assert seq_len <= self.k_cache.size(
    1
), f"Update sequence length ({seq_len}) for kv cache must not exceed the cache size ({self.k_cache.size(1)})"
indices = self.cache_positions_manager.calculate_positions_and_update_indices(
input_pos, seq_len
)
indices = indices.unsqueeze(0)

return super().update(input_pos, k_val, v_val, indices)

@classmethod
def from_quantized_kv_cache(
cls,
kv_cache,
sliding_window_size,
):
assert isinstance(
kv_cache, QuantizedKVCache
), "For QuantizedRingKVCache expect QuantizedKVCache as input kv_cache"
max_batch_size, _, n_heads, head_dim = kv_cache.k_cache.shape
return cls(
max_batch_size,
sliding_window_size,
n_heads,
head_dim,
kv_cache.cache_type,
kv_cache.use_custom_update_cache_op,
)
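A shape-only aside on the update() bookkeeping above (this uses plain tensors, not the ExecuTorch cache op, and the sizes are made up):

import torch

# The k/v projections hand update() tensors shaped [B, H, S, D], while the
# cache buffers are laid out as [B, S, H, D]; hence the transpose in and out.
B, H, S, D = 1, 8, 3, 64
k_val = torch.randn(B, H, S, D)

# Both expressions give the sequence length; reading it off the transposed view
# keeps the transpose visible in the exported graph so a later pass can remove it.
assert k_val.transpose(1, 2).size(1) == k_val.size(2) == S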


class CustomRingKVCache(CustomKVCache):
def __init__(
self,
max_batch_size,
max_context_length,
n_heads,
head_dim,
dtype=torch.float32,
):
# See attention.py for an explanation of why we use max_context_length * 2
super().__init__(
max_batch_size, max_context_length * 2, n_heads, head_dim, dtype
)
self.cache_positions_manager = CachePositionsManager(self.max_context_length)
self.is_ring_buffer = True
self.window_size = max_context_length

def create_causal_mask_for_ring_buffer(self, start_pos, seq_len):
cache_positions = self.cache_positions_manager.cache_positions
return _create_causal_mask_for_ring_buffer(
cache_positions, self.window_size, start_pos, seq_len
)

def update(self, input_pos, k_val, v_val):
"""
k_val, v_val: [B, H, S, D]
return: [B, H, S, D]
However, the storage is [B, S, H, D], so we incur a transpose in and a transpose out.
This will be removed by a subsequent post-export graph pass.
"""
# Need to transpose for two reasons
# 1. kv cache is stored as [B, S, H, D]
# 2. If seq_len = k_val.size(2), we won't be able to optimize
# away the transpose at the output of the k, v projections
seq_len = k_val.transpose(1, 2).size(1)
assert seq_len <= self.k_cache.size(
    1
), f"Update sequence length ({seq_len}) for kv cache must not exceed the cache size ({self.k_cache.size(1)})"
indices = self.cache_positions_manager.calculate_positions_and_update_indices(
input_pos, seq_len
)
indices = indices.unsqueeze(0)

return super().update(input_pos, k_val, v_val, indices)

@classmethod
def from_custom_kv_cache(
cls,
kv_cache,
sliding_window_size,
):
max_batch_size, n_heads, _, head_dim = kv_cache.k_cache.shape
if isinstance(kv_cache, CustomKVCache):
# If replacing custom kv cache, then the shape is [B, S, H, D]
max_batch_size, _, n_heads, head_dim = kv_cache.k_cache.shape
return cls(
max_batch_size,
sliding_window_size,
n_heads,
head_dim,
dtype=kv_cache.k_cache.dtype,
)


def _replace_kv_cache_with_ring_kv_cache(attention, layer_size):
sliding_window_size = layer_size
assert (
getattr(attention, "kv_cache", None) is not None
), "Attention module must have kv_cache module"
kv_cache = attention.kv_cache
if isinstance(kv_cache, KVCache):
attention.kv_cache = RingKVCache(
kv_cache.max_batch_size,
sliding_window_size,
kv_cache.n_heads,
kv_cache.head_dim,
kv_cache.enable_dynamic_shape,
kv_cache.k_cache.dtype,
)
elif isinstance(kv_cache, CustomKVCache):
attention.kv_cache = CustomRingKVCache.from_custom_kv_cache(
kv_cache, layer_size
)
elif isinstance(kv_cache, QuantizedKVCache):
attention.kv_cache = QuantizedRingKVCache.from_quantized_kv_cache(
kv_cache, layer_size
)


def replace_kv_cache_with_ring_kv_cache(module, layer_sizes):
# This is needed to ensure that custom ops are registered
from executorch.extension.llm.custom_ops import custom_ops # noqa: F401

logging.info(
"Replacing kv cache with ring kv cache. This modifies the model in place."
)
assert len(layer_sizes) == len(
module.layers
), f"Length of layer sizes {len(layer_sizes)} must match the number of layers in the module {len(module.layers)}."
for i, transformer_block in enumerate(module.layers):
sliding_window_size = layer_sizes[i]
if sliding_window_size == 0:
continue
assert (
getattr(transformer_block, "attention", None) is not None
), f"Transfomer block must have attention module. Transformer block {transformer_block}"
attention = transformer_block.attention
_replace_kv_cache_with_ring_kv_cache(attention, sliding_window_size)
return module
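A hedged usage sketch for the new entry point (the wrapper function and model handling here are hypothetical; only replace_kv_cache_with_ring_kv_cache and its layer_sizes contract come from this diff):

from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
    replace_kv_cache_with_ring_kv_cache,
)

def convert_to_ring_caches(model, window_size=512):
    # Hypothetical helper: `model` is assumed to be a llama Transformer whose
    # layers each carry an attention.kv_cache that is a KVCache, CustomKVCache,
    # or QuantizedKVCache instance.
    # One entry per layer: a sliding-window size for layers that should switch
    # to a ring-buffer cache, 0 to leave a layer's cache untouched.
    layer_sizes = [window_size if i % 2 == 0 else 0 for i in range(len(model.layers))]
    return replace_kv_cache_with_ring_kv_cache(model, layer_sizes)

The replacement happens in place; layers given a window size of 0 keep their existing cache, e.g. for models that alternate sliding-window and global-attention layers.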
25 changes: 25 additions & 0 deletions examples/models/llama/tests/TARGETS
@@ -55,8 +55,33 @@ python_unittest(
srcs = [
"test_ring_attention.py",
],
preload_deps = [
"//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
"//executorch/kernels/quantized:aot_lib",
],
deps = [
"//caffe2:torch",
"//executorch/examples/models/llama:export_library",
"//executorch/examples/models/llama:llama_transformer",
"//executorch/examples/models/llama:custom_kv_cache",
"//executorch/examples/models/llama:sdpa",
],
)

python_unittest(
name = "test_replace_kv_cache",
srcs = [
"test_replace_kv_cache.py",
],
preload_deps = [
"//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
"//executorch/kernels/quantized:aot_lib",
],
deps = [
"//caffe2:torch",
"//executorch/examples/models/llama:export_library",
"//executorch/examples/models/llama:llama_transformer",
"//executorch/examples/models/llama:custom_kv_cache",
"//executorch/examples/models/llama:sdpa",
],
)