Merged
39 commits
b91bf0b
[ExecuTorch][Llama] Split custom sdpa op and kv cache
kimishpatel Dec 20, 2024
b981f06
Changes to split kv cache and sdpa
kimishpatel Dec 20, 2024
5eb4c6f
Update on "Changes to split kv cache and sdpa"
kimishpatel Dec 20, 2024
e105c4c
Update base for Update on " [ExecuTorch][BE] Split kv cache and SDPA …
kimishpatel Dec 21, 2024
275144b
Update on " [ExecuTorch][BE] Split kv cache and SDPA for better code …
kimishpatel Dec 21, 2024
703f76e
Update base for Update on " [ExecuTorch][BE] Split kv cache and SDPA …
kimishpatel Jan 7, 2025
d20dd95
Update on " [ExecuTorch][BE] Split kv cache and SDPA for better code …
kimishpatel Jan 7, 2025
df1383b
Update base for Update on " [ExecuTorch][BE] Split kv cache and SDPA …
kimishpatel Jan 13, 2025
3a6b545
Update on " [ExecuTorch][BE] Split kv cache and SDPA for better code …
kimishpatel Jan 13, 2025
148354d
Update base for Update on " [ExecuTorch][BE] Split kv cache and SDPA …
kimishpatel Jan 13, 2025
3468f0c
Update on " [ExecuTorch][BE] Split kv cache and SDPA for better code …
kimishpatel Jan 13, 2025
f87afa4
Update base for Update on " [ExecuTorch][BE] Split kv cache and SDPA …
kimishpatel Jan 14, 2025
84ef14b
Update on " [ExecuTorch][BE] Split kv cache and SDPA for better code …
kimishpatel Jan 14, 2025
f6a87ee
Update base for Update on " [ExecuTorch][BE] Split kv cache and SDPA …
kimishpatel Jan 15, 2025
6e8cff5
Update on " [ExecuTorch][BE] Split kv cache and SDPA for better code …
kimishpatel Jan 15, 2025
13c7da9
Update base for Update on " [ExecuTorch][BE] Split kv cache and SDPA …
kimishpatel Jan 15, 2025
ed78ae3
Update on " [ExecuTorch][BE] Split kv cache and SDPA for better code …
kimishpatel Jan 15, 2025
ee290d0
Update base for Update on " [ExecuTorch][BE] Split kv cache and SDPA …
kimishpatel Jan 16, 2025
305350d
Update on " [ExecuTorch][BE] Split kv cache and SDPA for better code …
kimishpatel Jan 16, 2025
5eb26e8
Update base for Update on " [ExecuTorch][BE] Split kv cache and SDPA …
kimishpatel Jan 16, 2025
049b31b
Update on " [ExecuTorch][BE] Split kv cache and SDPA for better code …
kimishpatel Jan 16, 2025
75044ad
Update base for Update on " [ExecuTorch][BE] Split kv cache and SDPA …
kimishpatel Jan 16, 2025
c13944d
Update on " [ExecuTorch][BE] Split kv cache and SDPA for better code …
kimishpatel Jan 16, 2025
14e7cdc
Update base for Update on " [ExecuTorch][BE] Split kv cache and SDPA …
kimishpatel Jan 16, 2025
7d1c3f4
Update on " [ExecuTorch][BE] Split kv cache and SDPA for better code …
kimishpatel Jan 16, 2025
a466446
Update base for Update on " [ExecuTorch][BE] Split kv cache and SDPA …
kimishpatel Jan 16, 2025
6f79856
Update on " [ExecuTorch][BE] Split kv cache and SDPA for better code …
kimishpatel Jan 16, 2025
0937059
Update base for Update on " [ExecuTorch][BE] Split kv cache and SDPA …
kimishpatel Jan 21, 2025
64ab3f5
Update on " [ExecuTorch][BE] Split kv cache and SDPA for better code …
kimishpatel Jan 21, 2025
3c2d80c
Update base for Update on " [ExecuTorch][BE] Split kv cache and SDPA …
kimishpatel Jan 21, 2025
5d61242
Update on " [ExecuTorch][BE] Split kv cache and SDPA for better code …
kimishpatel Jan 21, 2025
3b84b51
Update base for Update on " [ExecuTorch][BE] Split kv cache and SDPA …
kimishpatel Jan 21, 2025
a10e400
Update on " [ExecuTorch][BE] Split kv cache and SDPA for better code …
kimishpatel Jan 21, 2025
1f8c183
Update base for Update on " [ExecuTorch][BE] Split kv cache and SDPA …
kimishpatel Jan 22, 2025
7b41972
Update on " [ExecuTorch][BE] Split kv cache and SDPA for better code …
kimishpatel Jan 22, 2025
fb74693
Update base for Update on " [ExecuTorch][BE] Split kv cache and SDPA …
kimishpatel Jan 22, 2025
c0bf723
Update on " [ExecuTorch][BE] Split kv cache and SDPA for better code …
kimishpatel Jan 22, 2025
928d08a
Update base for Update on " [ExecuTorch][BE] Split kv cache and SDPA …
kimishpatel Jan 22, 2025
b6a4eb5
Update on " [ExecuTorch][BE] Split kv cache and SDPA for better code …
kimishpatel Jan 22, 2025
2 changes: 2 additions & 0 deletions .ci/scripts/test_llama.sh
@@ -112,6 +112,8 @@ fi

if [[ "${MODE}" =~ .*quantize_kv.* ]]; then
QUANTIZE_KV_CACHE=ON
# quantize_kv cache transform uses custom kv cache update op
CUSTOM=ON
else
QUANTIZE_KV_CACHE=OFF
fi
2 changes: 1 addition & 1 deletion backends/qualcomm/quantizer/custom_annotation.py
@@ -374,7 +374,7 @@ def get_custom_quant_ios_dtype(
"""
This function is specific for llama inputs and outputs
"""
if node.op == "placeholder" and "attention_sdpa_kv_cache_past_" in node.name:
if node.op == "placeholder" and "attention_kv_cache_past_" in node.name:
return kv_dtype

# Tag index put node before copy node, because copy is a skipped node in qnn
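
The name change above reflects that the KV cache buffers no longer live under the SDPA submodule after this split, so exported placeholder names drop the "sdpa" segment. A minimal, hedged sketch of that kind of name-based annotation check is below; the graph traversal, helper name, and uint8 dtype are illustrative assumptions, not code from this PR.

# Hedged sketch: collect KV-cache placeholders by name after the module split.
# The helper name, traversal, and dtype below are assumptions for illustration.
import torch
from torch.fx import GraphModule


def collect_kv_cache_placeholder_dtypes(gm: GraphModule, kv_dtype=torch.uint8):
    """Return {placeholder_name: dtype} for KV-cache inputs of an exported graph."""
    tagged = {}
    for node in gm.graph.nodes:
        # After the refactor the cache buffers surface as
        # "...attention_kv_cache_past_..." placeholders (no "sdpa" segment).
        if node.op == "placeholder" and "attention_kv_cache_past_" in node.name:
            tagged[node.name] = kv_dtype
    return tagged
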
2 changes: 2 additions & 0 deletions examples/models/llama/export_llama_lib.py
@@ -665,6 +665,8 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
# export_to_edge
builder_exported = _prepare_for_llama_export(args).export()

builder_exported.run_canonical_optimizations()

if args.export_only:
exit()

46 changes: 15 additions & 31 deletions examples/models/llama/llama_transformer.py
@@ -232,22 +232,16 @@ def __init__(
max_seq_length: int,
n_heads: int,
head_dim: int,
transpose_cache: bool,
enable_dynamic_shape: bool,
dtype=torch.float32,
):
super().__init__()
self.max_seq_length = max_seq_length
self.is_transposed = transpose_cache
if transpose_cache:
cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
else:
cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)

self.max_batch_size = max_batch_size
self.n_heads = n_heads
self.head_dim = head_dim
self.transpose_cache = transpose_cache
self.enable_dynamic_shape = enable_dynamic_shape
self.register_buffer(
"k_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
@@ -259,12 +253,12 @@ def __init__(
def update(
self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
# input_pos: [S], k_val: [B, H, S, D] or [B, S, H, D] depending on transpose_cache
# input_pos: [S], k_val: [B, H, S, D]
if self.enable_dynamic_shape:
start_pos = input_pos[0].item()
torch._check_is_size(start_pos)
torch._check(start_pos < self.max_seq_length)
dim_to_slice = 2 if self.transpose_cache else 1
dim_to_slice = 2
seq_length = k_val.size(dim_to_slice)
# Replace the entry in the cache for this token
# The following lines are equivalent to:
@@ -283,28 +277,22 @@ def update(
else:
k_out = self.k_cache
v_out = self.v_cache
if self.transpose_cache:
k_out[:, :, input_pos] = k_val
v_out[:, :, input_pos] = v_val
else:
k_out[:, input_pos] = k_val
v_out[:, input_pos] = v_val
k_out[:, :, input_pos] = k_val
v_out[:, :, input_pos] = v_val

return k_out, v_out


class SDPA(nn.Module):
def __init__(
self,
kv_cache: KVCache,
dim: int,
head_dim: int,
n_rep: int,
max_seq_len: int,
enable_dynamic_shape: bool,
):
super().__init__()
self.kv_cache = kv_cache
self.dim = dim
self.head_dim = head_dim
self.n_rep = n_rep
@@ -314,18 +302,13 @@ def __init__(
def forward(
self,
input_pos: torch.Tensor,
q: torch.Tensor, # Already have rotary embeddings. (bs, seqlen, n_local_heads, head_dim)
k: torch.Tensor, # Already have rotary embeddings. (bs, seqlen, n_local_kv_heads, head_dim)
v: torch.Tensor, # (bs, seqlen, n_local_kv_heads, head_dim)
q: torch.Tensor, # Already have rotary embeddings. (bs, n_local_heads, seqlen, head_dim)
k: torch.Tensor, # Already have rotary embeddings. (bs, n_local_kv_heads, seqlen, head_dim)
v: torch.Tensor, # (bs, n_local_kv_heads, seqlen, head_dim)
bsz,
seqlen,
mask: torch.Tensor,
) -> torch.Tensor:
q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
k = k.transpose(1, 2)
v = v.transpose(1, 2)

k, v = self.kv_cache.update(input_pos, k, v)
if self.enable_dynamic_shape:
start_pos = input_pos[-1].item()
torch._check_is_size(start_pos)
@@ -336,6 +319,8 @@ def forward(
else:
attn_mask = mask[None, None, input_pos]

# TODO(kimishpatel): This should not be necessary because scaled_dot_product_attention
# can natively support GQA now. But needs enable_gqa=True
k = k.repeat_interleave(self.n_rep, dim=1)
v = v.repeat_interleave(self.n_rep, dim=1)
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)
@@ -383,11 +368,9 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
args.max_seq_len,
self.n_kv_heads,
self.head_dim,
not args.use_sdpa_with_kv_cache_op, # if we are using the custom op don't transpose the cache. Expect untransposed q k v
args.enable_dynamic_shape,
)
self.SDPA = SDPA(
kv_cache=self.kv_cache,
dim=self.n_local_heads * self.head_dim,
head_dim=self.head_dim,
n_rep=self.n_rep,
@@ -414,15 +397,16 @@ def forward(
# RoPE relative positional embeddings
q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)

q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
k = k.transpose(1, 2)
v = v.transpose(1, 2)

if self.use_kv_cache:
assert input_pos is not None
k, v = self.kv_cache.update(input_pos, k, v)
output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
return self.wo(output)

q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
k = k.transpose(1, 2)
v = v.transpose(1, 2)

# grouped multiquery attention: expand out keys and values
k = k.repeat_interleave(self.n_rep, dim=1)
v = v.repeat_interleave(self.n_rep, dim=1)
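
In this file the refactor does three things: KVCache always stores tensors as (bs, n_heads, max_seq_len, head_dim) and drops the transpose_cache flag, the cache update moves out of SDPA into Attention.forward, and q/k/v are transposed once before the update. The sketch below condenses that data flow into a single function with toy dimensions; the class and argument names mirror the diff, but the harness, sizes, and the folding of SDPA into one function are illustrative assumptions only.

# Hedged, self-contained sketch of the new layout and call order:
# the cache is (bs, n_heads, max_seq_len, head_dim); attention updates the
# cache itself and then runs scaled_dot_product_attention, which no longer
# owns the cache. Dimensions and the decode harness are illustrative.
from typing import Tuple

import torch
import torch.nn.functional as F


class MiniKVCache(torch.nn.Module):
    def __init__(self, max_batch_size: int, max_seq_len: int, n_heads: int, head_dim: int):
        super().__init__()
        cache_shape = (max_batch_size, n_heads, max_seq_len, head_dim)
        self.register_buffer("k_cache", torch.zeros(cache_shape))
        self.register_buffer("v_cache", torch.zeros(cache_shape))

    def update(
        self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # input_pos: [S]; k_val/v_val: (bs, n_heads, S, head_dim) -- always the "transposed" layout.
        self.k_cache[:, :, input_pos] = k_val
        self.v_cache[:, :, input_pos] = v_val
        return self.k_cache, self.v_cache


def attention_step(q, k, v, input_pos, kv_cache, n_rep, mask):
    # q/k/v arrive as (bs, seqlen, n_heads, head_dim) and are transposed once here,
    # matching the new Attention.forward order: transpose -> cache update -> SDPA.
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    k, v = kv_cache.update(input_pos, k, v)
    k = k.repeat_interleave(n_rep, dim=1)  # GQA expansion (see the TODO about enable_gqa)
    v = v.repeat_interleave(n_rep, dim=1)
    attn_mask = mask[None, None, input_pos]
    return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)


if __name__ == "__main__":
    bs, n_heads, n_kv_heads, head_dim, max_seq = 1, 8, 2, 16, 32
    cache = MiniKVCache(bs, max_seq, n_kv_heads, head_dim)
    causal_mask = torch.tril(torch.ones(max_seq, max_seq, dtype=torch.bool))
    q = torch.randn(bs, 1, n_heads, head_dim)
    k = torch.randn(bs, 1, n_kv_heads, head_dim)
    v = torch.randn(bs, 1, n_kv_heads, head_dim)
    out = attention_step(q, k, v, torch.tensor([0]), cache, n_heads // n_kv_heads, causal_mask)
    print(out.shape)  # (bs, n_heads, 1, head_dim)
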
27 changes: 6 additions & 21 deletions examples/models/llama/source_transformation/attention_sink.py
@@ -111,7 +111,6 @@ def __init__(
self,
n_heads: int,
head_dim: int,
transpose_cache: bool,
enable_dynamic_shape: bool,
rope: RopeWithAttentionSink,
window_size: int,
@@ -125,7 +124,6 @@ def __init__(
max_seq_length=window_size + sink_size,
n_heads=n_heads,
head_dim=head_dim,
transpose_cache=transpose_cache,
enable_dynamic_shape=enable_dynamic_shape,
dtype=dtype,
)
@@ -161,28 +159,17 @@ def evict_tokens(self, input_pos: torch.Tensor, seq_len: int) -> int:
input_pos_item + self.position_shift - self.sink_size - num_to_evict
)
num_empty_space = self.window_size - num_to_keep
dim_to_slice = 2 if self.transpose_cache else 1
dim_to_slice = 2
k_to_keep = self.k_cache.narrow(
dim_to_slice,
self.sink_size + num_to_evict, # pyre-ignore [6]
num_to_keep, # pyre-ignore [6]
)
if self.transpose_cache:
k_to_keep = self.rope.rerotate_k(
k=k_to_keep.transpose(1, 2),
original_position=( # pyre-ignore [6]
self.sink_size + num_to_evict
),
new_position=self.sink_size,
).transpose(1, 2)
else:
k_to_keep = self.rope.rerotate_k(
k=k_to_keep,
original_position=( # pyre-ignore [6]
self.sink_size + num_to_evict
),
new_position=self.sink_size,
)
k_to_keep = self.rope.rerotate_k(
k=k_to_keep.transpose(1, 2),
original_position=(self.sink_size + num_to_evict), # pyre-ignore [6]
new_position=self.sink_size,
).transpose(1, 2)
self.k_cache = torch.cat(
[
self.k_cache.narrow(dim_to_slice, 0, self.sink_size),
@@ -278,7 +265,6 @@ def _replace_attention(
kv_cache_with_attention_sink = KVCacheWithAttentionSink(
n_heads=kv_cache.n_heads,
head_dim=kv_cache.head_dim,
transpose_cache=kv_cache.transpose_cache,
enable_dynamic_shape=kv_cache.enable_dynamic_shape,
rope=rope_with_attention_sink,
max_batch_size=kv_cache.max_batch_size,
@@ -288,7 +274,6 @@ def _replace_attention(
dtype=kv_cache.k_cache.dtype,
)
child_module.kv_cache = kv_cache_with_attention_sink
child_module.SDPA.kv_cache = kv_cache_with_attention_sink
child_module.forward = types.MethodType( # pyre-ignore
attention_sink_forward, child_module
)
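
With the transpose_cache flag gone, KVCacheWithAttentionSink always slices the cache along dim 2 and re-rotates the kept keys by transposing to (bs, seq, n_heads, head_dim) for rerotate_k and back. The standalone sketch below shows that eviction step with simplified bookkeeping; rerotate_k is a stub for RopeWithAttentionSink.rerotate_k (the real one re-applies RoPE for the shifted positions), and all sizes are toy assumptions.

# Hedged sketch of the attention-sink eviction path after the refactor:
# the cache layout is fixed at (bs, n_heads, seq, head_dim), so the slice
# dimension is always 2 and rerotation transposes to (bs, seq, n_heads, head_dim)
# and back. `rerotate_k` is a stand-in for RopeWithAttentionSink.rerotate_k.
import torch


def rerotate_k(k: torch.Tensor, original_position: int, new_position: int) -> torch.Tensor:
    # Stub: the real implementation re-applies RoPE for the shifted positions.
    # Expects (bs, seq, n_heads, head_dim), returns the same shape.
    return k


def evict_and_shift(k_cache: torch.Tensor, sink_size: int, num_to_evict: int, num_to_keep: int):
    dim_to_slice = 2  # cache is always (bs, n_heads, seq, head_dim) now
    k_to_keep = k_cache.narrow(dim_to_slice, sink_size + num_to_evict, num_to_keep)
    # rerotate_k works on (bs, seq, n_heads, head_dim), so transpose in and out.
    k_to_keep = rerotate_k(
        k=k_to_keep.transpose(1, 2),
        original_position=sink_size + num_to_evict,
        new_position=sink_size,
    ).transpose(1, 2)
    num_empty = k_cache.size(dim_to_slice) - sink_size - num_to_keep
    return torch.cat(
        [
            k_cache.narrow(dim_to_slice, 0, sink_size),   # sink tokens stay put
            k_to_keep,                                     # shifted, re-rotated window
            torch.zeros_like(k_cache.narrow(dim_to_slice, 0, num_empty)),
        ],
        dim=dim_to_slice,
    )


if __name__ == "__main__":
    cache = torch.randn(1, 4, 16, 8)  # (bs, n_heads, seq, head_dim)
    shifted = evict_and_shift(cache, sink_size=2, num_to_evict=3, num_to_keep=6)
    print(shifted.shape)  # torch.Size([1, 4, 16, 8])
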