Merged

Commits (39)
b91bf0b: [ExecuTorch][Llama] Split custom sdpa op and kv cache (kimishpatel, Dec 20, 2024)
b981f06: Changes to split kv cache and sdpa (kimishpatel, Dec 20, 2024)
5eb4c6f: Update on "Changes to split kv cache and sdpa" (kimishpatel, Dec 20, 2024)
e105c4c: Update base for Update on " [ExecuTorch][BE] Split kv cache and SDPA … (kimishpatel, Dec 21, 2024)
275144b: Update on " [ExecuTorch][BE] Split kv cache and SDPA for better code … (kimishpatel, Dec 21, 2024)
703f76e: Update base for Update on " [ExecuTorch][BE] Split kv cache and SDPA … (kimishpatel, Jan 7, 2025)
d20dd95: Update on " [ExecuTorch][BE] Split kv cache and SDPA for better code … (kimishpatel, Jan 7, 2025)
df1383b: Update base for Update on " [ExecuTorch][BE] Split kv cache and SDPA … (kimishpatel, Jan 13, 2025)
3a6b545: Update on " [ExecuTorch][BE] Split kv cache and SDPA for better code … (kimishpatel, Jan 13, 2025)
148354d: Update base for Update on " [ExecuTorch][BE] Split kv cache and SDPA … (kimishpatel, Jan 13, 2025)
3468f0c: Update on " [ExecuTorch][BE] Split kv cache and SDPA for better code … (kimishpatel, Jan 13, 2025)
f87afa4: Update base for Update on " [ExecuTorch][BE] Split kv cache and SDPA … (kimishpatel, Jan 14, 2025)
84ef14b: Update on " [ExecuTorch][BE] Split kv cache and SDPA for better code … (kimishpatel, Jan 14, 2025)
f6a87ee: Update base for Update on " [ExecuTorch][BE] Split kv cache and SDPA … (kimishpatel, Jan 15, 2025)
6e8cff5: Update on " [ExecuTorch][BE] Split kv cache and SDPA for better code … (kimishpatel, Jan 15, 2025)
13c7da9: Update base for Update on " [ExecuTorch][BE] Split kv cache and SDPA … (kimishpatel, Jan 15, 2025)
ed78ae3: Update on " [ExecuTorch][BE] Split kv cache and SDPA for better code … (kimishpatel, Jan 15, 2025)
ee290d0: Update base for Update on " [ExecuTorch][BE] Split kv cache and SDPA … (kimishpatel, Jan 16, 2025)
305350d: Update on " [ExecuTorch][BE] Split kv cache and SDPA for better code … (kimishpatel, Jan 16, 2025)
5eb26e8: Update base for Update on " [ExecuTorch][BE] Split kv cache and SDPA … (kimishpatel, Jan 16, 2025)
049b31b: Update on " [ExecuTorch][BE] Split kv cache and SDPA for better code … (kimishpatel, Jan 16, 2025)
75044ad: Update base for Update on " [ExecuTorch][BE] Split kv cache and SDPA … (kimishpatel, Jan 16, 2025)
c13944d: Update on " [ExecuTorch][BE] Split kv cache and SDPA for better code … (kimishpatel, Jan 16, 2025)
14e7cdc: Update base for Update on " [ExecuTorch][BE] Split kv cache and SDPA … (kimishpatel, Jan 16, 2025)
7d1c3f4: Update on " [ExecuTorch][BE] Split kv cache and SDPA for better code … (kimishpatel, Jan 16, 2025)
a466446: Update base for Update on " [ExecuTorch][BE] Split kv cache and SDPA … (kimishpatel, Jan 16, 2025)
6f79856: Update on " [ExecuTorch][BE] Split kv cache and SDPA for better code … (kimishpatel, Jan 16, 2025)
0937059: Update base for Update on " [ExecuTorch][BE] Split kv cache and SDPA … (kimishpatel, Jan 21, 2025)
64ab3f5: Update on " [ExecuTorch][BE] Split kv cache and SDPA for better code … (kimishpatel, Jan 21, 2025)
3c2d80c: Update base for Update on " [ExecuTorch][BE] Split kv cache and SDPA … (kimishpatel, Jan 21, 2025)
5d61242: Update on " [ExecuTorch][BE] Split kv cache and SDPA for better code … (kimishpatel, Jan 21, 2025)
3b84b51: Update base for Update on " [ExecuTorch][BE] Split kv cache and SDPA … (kimishpatel, Jan 21, 2025)
a10e400: Update on " [ExecuTorch][BE] Split kv cache and SDPA for better code … (kimishpatel, Jan 21, 2025)
1f8c183: Update base for Update on " [ExecuTorch][BE] Split kv cache and SDPA … (kimishpatel, Jan 22, 2025)
7b41972: Update on " [ExecuTorch][BE] Split kv cache and SDPA for better code … (kimishpatel, Jan 22, 2025)
fb74693: Update base for Update on " [ExecuTorch][BE] Split kv cache and SDPA … (kimishpatel, Jan 22, 2025)
c0bf723: Update on " [ExecuTorch][BE] Split kv cache and SDPA for better code … (kimishpatel, Jan 22, 2025)
928d08a: Update base for Update on " [ExecuTorch][BE] Split kv cache and SDPA … (kimishpatel, Jan 22, 2025)
b6a4eb5: Update on " [ExecuTorch][BE] Split kv cache and SDPA for better code … (kimishpatel, Jan 22, 2025)
4 changes: 4 additions & 0 deletions examples/models/llama/export_llama_lib.py
@@ -55,6 +55,7 @@
     get_quant_weight_transform,
 )
 from .source_transformation.quantized_kv_cache import (
+    replace_kv_cache_with_custom_kv_cache,
     replace_kv_cache_with_quantized_kv_cache,
 )
 from .source_transformation.rms_norm import replace_rms_norm_with_native_rms_norm
@@ -663,6 +664,8 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
     # export_to_edge
     builder_exported = _prepare_for_llama_export(args).export()
 
+    builder_exported.run_canonical_optimizations()
+
     if args.export_only:
         exit()
 
@@ -1052,6 +1055,7 @@ def _get_source_transforms(  # noqa
         transforms.append(materialze_broadcast_of_rope_freq_cis)
 
     if args.use_sdpa_with_kv_cache:
+        transforms.append(replace_kv_cache_with_custom_kv_cache)
         transforms.append(replace_sdpa_with_custom_op)
 
     if args.quantize_kv_cache:
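
Taken together, the changes in this file append replace_kv_cache_with_custom_kv_cache alongside replace_sdpa_with_custom_op when --use_sdpa_with_kv_cache is set, and run canonical optimizations right after export. Below is a minimal, self-contained sketch of how such source transforms compose; the helper name apply_source_transforms is hypothetical, and the only assumption is that each transform takes and returns an nn.Module, matching how they are appended to the transforms list above.

# Illustration only: how a list of source transforms is applied to the model.
# Assumes each transform takes an nn.Module and returns the rewritten nn.Module.
from typing import Callable, List

import torch.nn as nn


def apply_source_transforms(
    model: nn.Module,
    transforms: List[Callable[[nn.Module], nn.Module]],
) -> nn.Module:
    # With --use_sdpa_with_kv_cache the list would hold the custom KV cache
    # replacement followed by the custom SDPA replacement, in that order, so the
    # SDPA rewrite sees the already-replaced cache modules.
    for transform in transforms:
        model = transform(model)
    return model
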
47 changes: 17 additions & 30 deletions examples/models/llama/llama_transformer.py
@@ -232,22 +232,16 @@ def __init__(
         max_seq_length: int,
         n_heads: int,
         head_dim: int,
-        transpose_cache: bool,
         enable_dynamic_shape: bool,
         dtype=torch.float32,
     ):
         super().__init__()
         self.max_seq_length = max_seq_length
-        self.is_transposed = transpose_cache
-        if transpose_cache:
-            cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
-        else:
-            cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
+        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
 
         self.max_batch_size = max_batch_size
         self.n_heads = n_heads
         self.head_dim = head_dim
-        self.transpose_cache = transpose_cache
         self.enable_dynamic_shape = enable_dynamic_shape
         self.register_buffer(
             "k_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
@@ -259,12 +253,12 @@ def __init__(
     def update(
         self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # input_pos: [S], k_val: [B, H, S, D] or [B, S, H, D] depending on transpose_cache
+        # input_pos: [S], k_val: [B, H, S, D]
         if self.enable_dynamic_shape:
             start_pos = input_pos[0].item()
             torch._check_is_size(start_pos)
             torch._check(start_pos < self.max_seq_length)
-            dim_to_slice = 2 if self.transpose_cache else 1
+            dim_to_slice = 2
             seq_length = k_val.size(dim_to_slice)
             # Replace the entry in the cache for this token
             # The following lines are equivalent to:
@@ -283,28 +277,22 @@ def update(
         else:
             k_out = self.k_cache
             v_out = self.v_cache
-            if self.transpose_cache:
-                k_out[:, :, input_pos] = k_val
-                v_out[:, :, input_pos] = v_val
-            else:
-                k_out[:, input_pos] = k_val
-                v_out[:, input_pos] = v_val
+            k_out[:, :, input_pos] = k_val
+            v_out[:, :, input_pos] = v_val
 
             return k_out, v_out
 
 
 class SDPA(nn.Module):
     def __init__(
         self,
-        kv_cache: KVCache,
         dim: int,
         head_dim: int,
         n_rep: int,
         max_seq_len: int,
         enable_dynamic_shape: bool,
     ):
         super().__init__()
-        self.kv_cache = kv_cache
         self.dim = dim
         self.head_dim = head_dim
         self.n_rep = n_rep
@@ -314,18 +302,16 @@ def __init__(
     def forward(
         self,
         input_pos: torch.Tensor,
-        q: torch.Tensor,  # Already have rotary embeddings. (bs, seqlen, n_local_heads, head_dim)
-        k: torch.Tensor,  # Already have rotary embeddings. (bs, seqlen, n_local_kv_heads, head_dim)
-        v: torch.Tensor,  # (bs, seqlen, n_local_kv_heads, head_dim)
+        q: torch.Tensor,  # Already have rotary embeddings. (bs, n_local_heads, seqlen, head_dim)
+        k: torch.Tensor,  # Already have rotary embeddings. (bs, n_local_kv_heads, seqlen, head_dim)
+        v: torch.Tensor,  # (bs, n_local_kv_heads, seqlen, head_dim)
         bsz,
         seqlen,
         mask: torch.Tensor,
     ) -> torch.Tensor:
-        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-        k = k.transpose(1, 2)
-        v = v.transpose(1, 2)
 
-        k, v = self.kv_cache.update(input_pos, k, v)
+        # TODO(kimishpatel): Move this slicing logic to Attention block so that
+        # SDPA does not have to take input_pos as arg
         if self.enable_dynamic_shape:
             start_pos = input_pos[-1].item()
             torch._check_is_size(start_pos)
@@ -336,6 +322,8 @@ def forward(
         else:
             attn_mask = mask[None, None, input_pos]
 
+        # TODO(kimishpatel): This should not be necessary because scaled_dot_product_attention
+        # can natively support GQA now. But needs enable_gqa=True
         k = k.repeat_interleave(self.n_rep, dim=1)
         v = v.repeat_interleave(self.n_rep, dim=1)
         y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)
@@ -383,11 +371,9 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
                 args.max_seq_len,
                 self.n_kv_heads,
                 self.head_dim,
-                not args.use_sdpa_with_kv_cache_op,  # if we are using the custom op don't transpose the cache. Expect untransposed q k v
                 args.enable_dynamic_shape,
             )
             self.SDPA = SDPA(
-                kv_cache=self.kv_cache,
                 dim=self.n_local_heads * self.head_dim,
                 head_dim=self.head_dim,
                 n_rep=self.n_rep,
@@ -414,15 +400,16 @@ def forward(
         # RoPE relative positional embeddings
         q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)
 
+        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+
         if self.use_kv_cache:
             assert input_pos is not None
+            k, v = self.kv_cache.update(input_pos, k, v)
             output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
             return self.wo(output)
 
-        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-        k = k.transpose(1, 2)
-        v = v.transpose(1, 2)
-
         # grouped multiquery attention: expand out keys and values
         k = k.repeat_interleave(self.n_rep, dim=1)
         v = v.repeat_interleave(self.n_rep, dim=1)
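
To summarize the restructuring in this file: KVCache now always stores (bs, n_heads, max_seq_length, head_dim), SDPA no longer owns or updates the cache, and Attention.forward transposes q/k/v and updates the cache itself before calling SDPA. The sketch below condenses that flow; the names SimpleKVCache and attention_step are hypothetical, the shapes are simplified, and the dynamic-shape path is omitted, so treat it as an illustration of the new split rather than the PR's exact code.

# Condensed sketch of the post-split structure (hypothetical names, float32 only).
# The cache keeps a fixed (bs, n_heads, max_seq_len, head_dim) layout and is updated
# by the caller; the attention math itself stays stateless.
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleKVCache(nn.Module):
    def __init__(self, max_batch_size: int, n_heads: int, max_seq_len: int, head_dim: int):
        super().__init__()
        cache_shape = (max_batch_size, n_heads, max_seq_len, head_dim)
        self.register_buffer("k_cache", torch.zeros(cache_shape))
        self.register_buffer("v_cache", torch.zeros(cache_shape))

    def update(
        self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # k_val / v_val arrive already transposed to (bs, n_heads, seqlen, head_dim).
        self.k_cache[:, :, input_pos] = k_val
        self.v_cache[:, :, input_pos] = v_val
        return self.k_cache, self.v_cache


def attention_step(
    cache: SimpleKVCache,
    input_pos: torch.Tensor,  # positions being written, shape [seqlen]
    q: torch.Tensor,  # (bs, seqlen, n_heads, head_dim), already rotated
    k: torch.Tensor,  # (bs, seqlen, n_kv_heads, head_dim)
    v: torch.Tensor,  # (bs, seqlen, n_kv_heads, head_dim)
    n_rep: int,  # n_heads // n_kv_heads for grouped-query attention
    mask: torch.Tensor,  # (max_seq_len, max_seq_len) causal mask
) -> torch.Tensor:
    # The attention block, not SDPA, now transposes q/k/v and updates the cache.
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))  # -> (bs, heads, seqlen, head_dim)
    k, v = cache.update(input_pos, k, v)
    k = k.repeat_interleave(n_rep, dim=1)  # expand KV heads for GQA
    v = v.repeat_interleave(n_rep, dim=1)
    attn_mask = mask[None, None, input_pos]
    # Returns (bs, n_heads, seqlen, head_dim); a real attention block would transpose
    # back and apply the output projection.
    return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)

With the cache update moved into the attention block, a custom or quantized KV cache can be swapped in independently of whichever SDPA implementation is used, which appears to be the code-sharing benefit the commit messages point at.
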