18 changes: 17 additions & 1 deletion examples/models/llama/export_llama_lib.py
@@ -41,6 +41,7 @@
get_qnn_quantizer,
get_vulkan_quantizer,
)
from executorch.extension.llm.modules import replace_mha_with_inference_mha
from executorch.util.activation_memory_profiler import generate_memory_trace

from ..model_factory import EagerModelFactory
@@ -536,7 +537,7 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
else:
dtype_override = None

return (
model_manager = (
_load_llama_model(
args.model,
checkpoint=checkpoint_path,
@@ -563,6 +564,15 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
.set_output_dir(output_dir_path)
.source_transform(_get_source_transforms(args.model, dtype_override, args))
)
if args.model in TORCHTUNE_DEFINED_MODELS:
if args.use_kv_cache:
print("Setting up the KV cache...")
model_manager.model.setup_caches(
batch_size=1,
dtype=dtype_override.to_torch_dtype(),
decoder_max_seq_len=args.max_seq_length,
)
return model_manager
Comment on lines +567 to +575
Contributor Author

setup_caches moved here.



def get_quantizer_and_quant_params(args):
@@ -974,6 +984,10 @@ def _get_source_transforms( # noqa
) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
transforms = []

is_torchtune = modelname in TORCHTUNE_DEFINED_MODELS
if is_torchtune:
transforms.append(replace_mha_with_inference_mha)

if args.use_spin_quant:
if args.use_spin_quant == "cuda":
from .source_transformation.spin_quant import (
@@ -1075,4 +1089,6 @@ def _get_source_transforms( # noqa
if args.vulkan:
transforms.append(replace_with_vulkan_rotary_emb)

print(f"Source transformations: {[t.__name__ for t in transforms]}")

return transforms
examples/models/llama/source_transformation/quantized_kv_cache.py
@@ -39,6 +39,11 @@ def __init__(
enable_dynamic_shape=False,
):
super().__init__()
self.max_batch_size = max_batch_size
self.max_seq_length = max_seq_length
self.n_heads = n_heads
self.head_dim = head_dim
self.cache_type = cache_type
if cache_type not in (
QuantizedCacheType.AffineSymmetric,
QuantizedCacheType.AffineAsymmetric,
@@ -65,6 +70,9 @@
self.register_buffer(
"v_cache", torch.zeros(cache_shape, dtype=self.quantized_cache_dtype)
)
self.register_buffer(
"cache_pos", torch.arange(0, max_seq_length), persistent=False
)
self.register_buffer(
"k_cache_scales", torch.ones(scale_shape, dtype=torch.float64)
)
@@ -95,7 +103,7 @@ def _quantize(self, value):
)
return quantized_value, scales, zero_points

def update(self, input_pos, k_val, v_val):
def update(self, k_val, v_val):
# quantize current k_val and store it in the cache
quantized_k_val, k_scales, k_zero_points = self._quantize(k_val)

@@ -110,7 +118,7 @@ def update(self, input_pos, k_val, v_val):
# for lowering pains of backends that work better
# with index_put op.
if self.enable_dynamic_shape:
start_pos = input_pos[0].item()
start_pos = self.cache_pos[0].item()
torch._check_is_size(start_pos)
dim_to_slice = 2 if self.is_transposed else 1
torch._check(start_pos < self.k_cache.size(dim_to_slice))
@@ -136,12 +144,12 @@ def update(self, input_pos, k_val, v_val):
narrowed_v_scales.copy_(v_scales)
narrowed_v_zp.copy_(v_zero_points)
else:
self.k_cache[:, :, input_pos] = quantized_k_val
self.k_cache_scales[:, :, input_pos] = k_scales
self.k_cache_zero_points[:, :, input_pos] = k_zero_points
self.v_cache[:, :, input_pos] = quantized_v_val
self.v_cache_scales[:, :, input_pos] = v_scales
self.v_cache_zero_points[:, :, input_pos] = v_zero_points
self.k_cache[:, :, self.cache_pos] = quantized_k_val
self.k_cache_scales[:, :, self.cache_pos] = k_scales
self.k_cache_zero_points[:, :, self.cache_pos] = k_zero_points
self.v_cache[:, :, self.cache_pos] = quantized_v_val
self.v_cache_scales[:, :, self.cache_pos] = v_scales
self.v_cache_zero_points[:, :, self.cache_pos] = v_zero_points
else:
# Right now using custom ops on this path.
# In future we can update custom op to handle transposed cache
@@ -150,7 +158,7 @@
# backends such as QNN want to use quantized cache, with dynamic shape,
# instead of quantizing on their own.
# But until this opting for code simplicity
start_pos = input_pos[0].item()
start_pos = self.cache_pos[0].item()
_ = torch.ops.llama.update_quantized_cache(
quantized_k_val, self.k_cache, start_pos
)
@@ -207,6 +215,31 @@ def from_float(cls, kv_cache, cache_type: QuantizedCacheType):
kv_cache.enable_dynamic_shape,
)

def clone(self) -> "QuantizedKVCache":
"""Create a clone of the KVCache."""
if self.is_transposed:
num_kv_heads = self.k_cache.shape[1]
else:
num_kv_heads = self.k_cache.shape[2]
clone = QuantizedKVCache(
max_batch_size=self.max_batch_size,
max_seq_length=self.max_seq_length,
n_heads=num_kv_heads,
head_dim=self.k_cache.shape[3],
cache_type=self.cache_type,
tranposed=self.is_transposed,
enable_dynamic_shape=self.enable_dynamic_shape,
)
clone.k_cache.copy_(self.k_cache)
clone.v_cache.copy_(self.v_cache)
clone.cache_pos.copy_(self.cache_pos)
clone.k_cache_scales.copy_(self.k_cache_scales)
clone.v_cache_scales.copy_(self.v_cache_scales)
if clone.cache_type == QuantizedCacheType.AffineAsymmetric:
clone.k_cache_zero_points.copy_(self.k_cache_zero_points)
clone.v_cache_zero_points.copy_(self.v_cache_zero_points)
return clone


def replace_kv_cache_with_quantized_kv_cache(module):
logging.warning(
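As a side note for readers of this hunk, below is a small standalone sketch (a toy class, not the ExecuTorch QuantizedKVCache; shapes and sizes are chosen purely for illustration) of the cache_pos idea introduced above: the cache registers its own non-persistent position buffer, so update() no longer needs an input_pos argument.

import torch

class TinyCache(torch.nn.Module):
    """Toy illustration of a KV cache that tracks positions internally."""

    def __init__(self, max_seq_len: int, n_heads: int, head_dim: int):
        super().__init__()
        self.register_buffer("k_cache", torch.zeros(1, n_heads, max_seq_len, head_dim))
        # Non-persistent position buffer, analogous to cache_pos above; the
        # surrounding model is expected to advance it between decode steps.
        self.register_buffer("cache_pos", torch.arange(0, max_seq_len), persistent=False)

    def update(self, k_val: torch.Tensor) -> torch.Tensor:
        seq_len = k_val.size(2)
        # Write the new keys at the positions currently held in cache_pos.
        self.k_cache[:, :, self.cache_pos[:seq_len]] = k_val
        return self.k_cache

cache = TinyCache(max_seq_len=16, n_heads=2, head_dim=4)
k_out = cache.update(torch.randn(1, 2, 3, 4))  # note: no input_pos argument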
29 changes: 12 additions & 17 deletions examples/models/llama/source_transformation/sdpa.py
@@ -23,23 +23,20 @@ class SDPACustom(torch.nn.Module):
def __init__(
self,
kv_cache: Union[KVCache, QuantizedKVCache],
dim: int,
):
super().__init__()
# Custom op only supports float32 currently. Converting to/from float32 is
# faster than not having the op.
self.kv_cache = kv_cache
if not isinstance(kv_cache, QuantizedKVCache):
self.kv_cache = kv_cache.to(torch.float)
else:
assert (
kv_cache.cache_fp_type == torch.float32
), "Only float32 is supported for custom SDPA"
self.dim = dim
# if not isinstance(kv_cache, QuantizedKVCache):
# self.kv_cache = kv_cache.to(torch.float)
# else:
# assert (
# kv_cache.cache_fp_type == torch.float32
# ), "Only float32 is supported for custom SDPA"

def forward(
self,
input_pos: torch.Tensor,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
@@ -60,12 +57,12 @@ def forward(
# updated quantize cache, scale and zero points
# returns dequantized kv cache
# Not most optimal. Optimizations to follow next
k_cache, v_cache = self.kv_cache.update(input_pos, k, v)
k_cache, v_cache = self.kv_cache.update(self.kv_cache.cache_pos, k, v)
output = torch.ops.llama.custom_sdpa(
q,
k_cache,
v_cache,
input_pos[0].item(),
self.kv_cache.cache_pos[0].item(),
None, # Attention mask
0, # dropout probability. Ignored by the code
True, # is_causal
@@ -77,13 +74,13 @@ def forward(
v,
k_cache,
v_cache,
input_pos[0].item(),
self.kv_cache.cache_pos[0].item(),
seqlen,
None, # Attention mask
0, # dropout probability. Ignored by the code
True, # is_causal
)
return output.view(bsz, seqlen, self.dim).to(dtype=input_dtype)
return output.view(bsz, seqlen, -1).to(dtype=input_dtype)


def _replace_sdpa_with_custom_op(module: torch.nn.Module):
@@ -106,7 +103,6 @@ def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:


class SDPASimple(torch.nn.Module):

def __init__(
self,
kv_cache: KVCache,
@@ -122,7 +118,6 @@ def __init__(

def forward(
self,
input_pos: torch.Tensor,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
@@ -134,8 +129,8 @@ def forward(
k = k.transpose(1, 2)
v = v.transpose(1, 2)

k, v = self.kv_cache.update(input_pos, k, v)
attn_mask = mask[None, None, input_pos]
k, v = self.kv_cache.update(k, v)
attn_mask = mask[None, None, self.kv_cache.cache_pos]

k = k.repeat_interleave(self.n_rep, dim=1)
v = v.repeat_interleave(self.n_rep, dim=1)
14 changes: 7 additions & 7 deletions examples/models/llama3_2_vision/text_decoder/model.py
@@ -142,13 +142,13 @@ def __init__(self, **kwargs):

self.model_ = prune_output_vocab(self.model_, output_prune_map)

if self.use_kv_cache:
print("Setting up KV cache on the model...")
self.model_.setup_caches(
batch_size=1,
dtype=self.dtype,
decoder_max_seq_len=self.max_seq_len,
)
# if self.use_kv_cache:
# print("Setting up KV cache on the model...")
# self.model_.setup_caches(
# batch_size=1,
# dtype=self.dtype,
# decoder_max_seq_len=self.max_seq_len,
# )
Comment on lines +145 to +151
Contributor Author

We need to do this because the source transform happens after the model is set up, and we need to call the new swapped-in ET attention's setup_cache function. So the setup_caches call moves to after the source transform; a minimal sketch of the ordering follows.
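A minimal sketch of that ordering constraint, using only names that appear in this diff (replace_mha_with_inference_mha, setup_caches); the wrapper function itself is hypothetical:

import torch
from executorch.extension.llm.modules import replace_mha_with_inference_mha

def prepare_torchtune_decoder(
    model: torch.nn.Module, dtype: torch.dtype, max_seq_len: int
) -> torch.nn.Module:
    # Calling model.setup_caches(...) at this point would allocate KV caches on
    # the original TorchTune attention modules, which are about to be replaced.

    # Source transform: swap TorchTune MHA for the ExecuTorch inference MHA.
    model = replace_mha_with_inference_mha(model)

    # Now setup_caches reaches the swapped-in attention's setup_cache(), so the
    # cache that gets allocated is the one the exported model actually uses.
    model.setup_caches(
        batch_size=1,
        dtype=dtype,
        decoder_max_seq_len=max_seq_len,
    )
    return model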


def get_eager_model(self) -> torch.nn.Module:
if self.dtype:
Expand Down
51 changes: 38 additions & 13 deletions extension/llm/modules/attention.py
Original file line number Diff line number Diff line change
@@ -9,6 +9,13 @@

import torch
import torchtune.modules.attention as TorchTuneAttention
from executorch.examples.models.llama.source_transformation.quantized_kv_cache import (
QuantizedKVCache,
)
from executorch.examples.models.llama.source_transformation.sdpa import (
SDPACustom,
SDPASimple,
)
from executorch.extension.llm.modules.kv_cache import KVCache as InferenceKVCache
from torch import nn
from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention
@@ -145,16 +152,27 @@ def __init__(

# Use flex attention if supported and we are sample packing
self._attention_call = _sdpa_or_flex_attention()
self._sdpa = SDPA(
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
head_dim=self.head_dim,
attn_dropout=self.attn_dropout if self.training else 0.0,
is_causal=self.is_causal,
attention_fn=self._attention_call,
# self._sdpa = SDPA(
# num_kv_heads=self.num_kv_heads,
# num_heads=self.num_heads,
# head_dim=self.head_dim,
# attn_dropout=self.attn_dropout if self.training else 0.0,
# is_causal=self.is_causal,
# attention_fn=self._attention_call,
# kv_cache=self.kv_cache,
# )

self._sdpa = SDPACustom(
kv_cache=self.kv_cache,
)

# self._sdpa = SDPASimple(
# kv_cache=self.kv_cache,
# dim=self.embed_dim,
# head_dim=self.head_dim,
# n_rep=self.num_heads // self.num_kv_heads
# )

# this flag indicates whether to update the kv-cache during forward
# passes. when disabled, we can have the cache setup but still
# perform normal forward passes
@@ -177,13 +195,20 @@ def setup_cache(
"Key value caches are already setup. You cannot call ``setup_caches()`` twice. Skipping."
)
else:
self.kv_cache = InferenceKVCache(
batch_size=batch_size,
max_seq_len=max_seq_len,
num_kv_heads=self.num_kv_heads,
# self.kv_cache = InferenceKVCache(
Contributor

can you try adding from executorch.extension.llm.custom_ops import * here and see if that works?
(A sketch of this suggestion appears after this hunk.)

# batch_size=batch_size,
# max_seq_len=max_seq_len,
# num_kv_heads=self.num_kv_heads,
# head_dim=self.head_dim,
# dtype=dtype,
# transpose_cache=False,
# )
self.kv_cache = QuantizedKVCache(
max_batch_size=batch_size,
max_seq_length=max_seq_len,
n_heads=self.num_kv_heads,
head_dim=self.head_dim,
dtype=dtype,
transpose_cache=False,
# dtype needs to be float32 atm,
)
self._sdpa.kv_cache = self.kv_cache
self.cache_enabled = True
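Picking up the review suggestion above, here is a hedged sketch of placing the custom-op import at the top of extension/llm/modules/attention.py. The star import is quoted verbatim from the review comment, and it is an assumption, not verified here, that it registers the torch.ops.llama kernels (custom_sdpa, update_quantized_cache) that QuantizedKVCache and SDPACustom rely on.

# Top of extension/llm/modules/attention.py -- sketch only.
# The star import is taken verbatim from the review comment; it is assumed
# (not verified here) to register the torch.ops.llama custom kernels used by
# QuantizedKVCache.update and SDPACustom.forward.
from executorch.extension.llm.custom_ops import *  # noqa: F401,F403

from executorch.examples.models.llama.source_transformation.quantized_kv_cache import (
    QuantizedKVCache,
)
from executorch.examples.models.llama.source_transformation.sdpa import SDPACustom

# If the registration works, torch.ops.llama.custom_sdpa(...) inside
# SDPACustom.forward should resolve instead of raising a missing-op error.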