2 changes: 2 additions & 0 deletions examples/models/llama/export_llama_lib.py
@@ -55,6 +55,7 @@
get_quant_weight_transform,
)
from .source_transformation.quantized_kv_cache import (
replace_kv_cache_with_custom_kv_cache,
replace_kv_cache_with_quantized_kv_cache,
)
from .source_transformation.rms_norm import replace_rms_norm_with_native_rms_norm
@@ -1052,6 +1053,7 @@ def _get_source_transforms(  # noqa
transforms.append(materialze_broadcast_of_rope_freq_cis)

if args.use_sdpa_with_kv_cache:
transforms.append(replace_kv_cache_with_custom_kv_cache)
transforms.append(replace_sdpa_with_custom_op)

if args.quantize_kv_cache:
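For context, the transforms collected in `_get_source_transforms` are applied to the eager model in sequence before export. The snippet below is a minimal sketch of that pattern, assuming illustrative names (`apply_source_transforms`, `model`, `transforms`) rather than the actual export code:

```python
# Hypothetical sketch: each source transform takes an nn.Module and returns
# the (possibly mutated) module, so the collected list composes by iteration.
def apply_source_transforms(model, transforms):
    for transform in transforms:
        model = transform(model)
    return model

# With --use_sdpa_with_kv_cache set, the list would contain
# replace_kv_cache_with_custom_kv_cache followed by replace_sdpa_with_custom_op,
# so the KV caches are swapped before the SDPA modules that consume them.
```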
examples/models/llama/source_transformation/quantized_kv_cache.py
@@ -6,6 +6,7 @@

import logging
from enum import Enum
from typing import Tuple

import torch
import torch.nn as nn
@@ -44,7 +45,6 @@ def __init__(
QuantizedCacheType.AffineSymmetric,
QuantizedCacheType.AffineAsymmetric,
):

raise ValueError(
f"Only affine symmetric and asymmetric cache types are supported: got {cache_type}"
)
@@ -81,10 +81,11 @@ def __init__(
)

def _quantize(self, value):
scales, zero_points = (
torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(
value, self.quantized_cache_dtype
)
(
scales,
zero_points,
) = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(
value, self.quantized_cache_dtype
)
quantized_value = torch.ops.quantized_decomposed.quantize_per_token(
value,
@@ -262,3 +263,71 @@ def replace_kv_cache_with_quantized_kv_cache(module):
else:
replace_kv_cache_with_quantized_kv_cache(child)
return module


class CustomKVCache(nn.Module):
def __init__(
self,
max_batch_size: int,
max_seq_length: int,
n_heads: int,
head_dim: int,
dtype=torch.float32,
):
super().__init__()
self.max_seq_length = max_seq_length
cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)

self.max_batch_size = max_batch_size
self.n_heads = n_heads
self.head_dim = head_dim
self.register_buffer(
"k_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
)
self.register_buffer(
"v_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
)

def update(
self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
# input_pos: [S], k_val: [B, S, H, D]
start_pos = input_pos[0].item()
_ = torch.ops.llama.update_cache(k_val, self.k_cache, start_pos)
_ = torch.ops.llama.update_cache(v_val, self.v_cache, start_pos)
return self.k_cache, self.v_cache


def replace_kv_cache_with_custom_kv_cache(module):
r"""
Replace KVCache with CustomKVCache. This modifies the model in place.
At the moment custom kv cache only supports cache with shape
[B, S, H, D] as opposed to [B, H, S, D]
This is because the custom op treats second dim as sequence dim.
Future work: support [B, H, S, D]
"""
logging.warning(
"Replacing KVCache with CustomKVCache. This modifies the model in place."
)
for name, child in module.named_children():
if isinstance(child, KVCache):
cache_shape = child.k_cache.shape
cache_dtype = child.k_cache.dtype
assert (
child.is_transposed is False
), "CustomKVCache does not support transposed cache"
max_batch_size, max_seq_length, n_heads, head_dim = cache_shape
setattr(
module,
name,
CustomKVCache(
max_batch_size,
max_seq_length,
n_heads,
head_dim,
dtype=cache_dtype,
),
)
else:
replace_kv_cache_with_custom_kv_cache(child)
return module
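A minimal usage sketch of the `CustomKVCache` added above. It assumes the ExecuTorch llama custom ops (which provide `torch.ops.llama.update_cache`) are already registered, and the batch/sequence/head sizes are illustrative:

```python
import torch

# Hypothetical sketch; requires torch.ops.llama.update_cache to be registered
# (provided by the ExecuTorch llama custom-op library).
cache = CustomKVCache(max_batch_size=1, max_seq_length=128, n_heads=8, head_dim=64)

# One decode step at position 5; k_val / v_val use the [B, S, H, D] layout
# that CustomKVCache expects, with S == 1 for a single new token.
input_pos = torch.tensor([5], dtype=torch.long)
k_val = torch.randn(1, 1, 8, 64)
v_val = torch.randn(1, 1, 8, 64)

# update() writes the new entries in place and returns the full cache buffers.
k_cache, v_cache = cache.update(input_pos, k_val, v_val)
assert k_cache.shape == (1, 128, 8, 64)
```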
39 changes: 10 additions & 29 deletions examples/models/llama/source_transformation/sdpa.py
@@ -56,33 +56,16 @@ def forward(

k_cache = self.kv_cache.k_cache
v_cache = self.kv_cache.v_cache
if hasattr(self.kv_cache, "quantized_cache_dtype"):
# updated quantize cache, scale and zero points
# returns dequantized kv cache
# Not most optimal. Optimizations to follow next
k_cache, v_cache = self.kv_cache.update(input_pos, k, v)
output = torch.ops.llama.custom_sdpa(
q,
k_cache,
v_cache,
input_pos[0].item(),
None, # Attention mask
0, # dropout probability. Ignored by the code
True, # is_causal
)
else:
output = torch.ops.llama.sdpa_with_kv_cache(
q,
k,
v,
k_cache,
v_cache,
input_pos[0].item(),
seqlen,
None, # Attention mask
0, # dropout probability. Ignored by the code
True, # is_causal
)
k_cache, v_cache = self.kv_cache.update(input_pos, k, v)
output = torch.ops.llama.custom_sdpa(
q,
k_cache,
v_cache,
input_pos[0].item(),
None, # Attention mask
0, # dropout probability. Ignored by the code
True, # is_causal
)
return output.view(bsz, seqlen, self.dim).to(dtype=input_dtype)


@@ -106,7 +89,6 @@ def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:


class SDPASimple(torch.nn.Module):

def __init__(
self,
kv_cache: KVCache,
@@ -166,7 +148,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:


class SDPAFlex(torch.nn.Module):

def __init__(
self,
kv_cache: KVCache,
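The `forward()` refactor at the top of this file relies on every cache variant (`CustomKVCache`, `QuantizedKVCache`) exposing the same `update(input_pos, k_val, v_val) -> (k_cache, v_cache)` contract, returning full cache buffers ready for `torch.ops.llama.custom_sdpa` (dequantized in the quantized case). The sketch below expresses that assumed interface as an illustrative `Protocol`; it is not part of the PR:

```python
from typing import Protocol, Tuple

import torch

class SDPACompatibleCache(Protocol):
    """Illustrative contract assumed by the unified custom_sdpa path."""

    def update(
        self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Write k_val/v_val at input_pos and return the full k/v cache buffers."""
        ...
```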