
Commit 6738350
Author: Martin Yuan
Parent: 15c772c

Add abstract base class for attention mechanisms with unified interface

File tree: 12 files changed (+416 / -368 lines)


backends/qualcomm/tests/test_qnn_delegate.py
Lines changed: 2 additions & 1 deletion

@@ -38,7 +38,8 @@
     update_spill_fill_size,
 )

-from executorch.examples.models.llama.llama_transformer import ModelArgs, MOEFeedForward
+from executorch.examples.models.llama.model_args import ModelArgs
+from executorch.examples.models.llama.llama_transformer import MOEFeedForward

 from executorch.examples.qualcomm.utils import setup_common_args_and_variables

examples/models/llama/attention.py (new file)
Lines changed: 238 additions & 0 deletions
from abc import ABC, abstractmethod
from typing import Optional, Tuple, Any

import torch
import torch.nn as nn
import torch.nn.functional as F

from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.rope import Rope


class Attention(nn.Module, ABC):
    """Abstract base class for attention mechanisms with unified interface."""

    @abstractmethod
    def forward(
        self,
        x: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        input_pos: Optional[torch.Tensor] = None,
        in_cache_state: Optional[Any] = None,
        out_cache_state: Optional[Any] = None,
    ) -> Tuple[torch.Tensor, Optional[Any]]:
        """Forward pass for the attention mechanism.

        Args:
            x: Input tensor of shape (batch_size, seq_len, dim)
            freqs_cos, freqs_sin: Rotary position embedding frequencies
            mask: Optional attention mask
            input_pos: Positions for KV cache updates
            in_cache_state, out_cache_state: Optional cache states

        Returns:
            Tuple of (output tensor, updated cache state)
        """
        pass

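
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this commit): a minimal, hypothetical
# subclass showing how a custom attention variant plugs into the unified
# interface above. Only the forward() signature is taken from the Attention
# ABC; the class name and its internals are invented for illustration.
# ---------------------------------------------------------------------------
class _ExampleNaiveAttention(Attention):
    """Single-head, cache-free attention used only to illustrate the interface."""

    def __init__(self, dim: int):
        super().__init__()
        self.wq = nn.Linear(dim, dim, bias=False)
        self.wk = nn.Linear(dim, dim, bias=False)
        self.wv = nn.Linear(dim, dim, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        input_pos: Optional[torch.Tensor] = None,
        in_cache_state: Optional[Any] = None,
        out_cache_state: Optional[Any] = None,
    ) -> Tuple[torch.Tensor, Optional[Any]]:
        # Ignores rotary frequencies and caching; computes plain (causal) SDPA.
        q, k, v = self.wq(x), self.wk(x), self.wv(x)
        y = F.scaled_dot_product_attention(
            q, k, v, attn_mask=mask, is_causal=mask is None
        )
        return y, None
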
class KVCache(nn.Module):
    def __init__(
        self,
        max_batch_size: int,
        max_seq_length: int,
        n_heads: int,
        head_dim: int,
        enable_dynamic_shape: bool,
        dtype=torch.float32,
    ):
        super().__init__()
        self.max_seq_length = max_seq_length
        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)

        self.max_batch_size = max_batch_size
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.enable_dynamic_shape = enable_dynamic_shape
        self.register_buffer(
            "k_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
        )
        self.register_buffer(
            "v_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
        )

    def update(
        self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # input_pos: [S], k_val: [B, H, S, D]
        if self.enable_dynamic_shape:
            start_pos = input_pos[0].item()
            torch._check_is_size(start_pos)
            torch._check(start_pos < self.max_seq_length)
            dim_to_slice = 2
            seq_length = k_val.size(dim_to_slice)
            # Replace the entry in the cache for this token.
            # The following lines are equivalent to:
            #   cache_k[:bsz, start_pos : start_pos + seqlen] = xk
            #   cache_v[:bsz, start_pos : start_pos + seqlen] = xv
            # when dim_to_slice is 1.
            # We use .narrow() here to make the compiler happy.
            # pyre-ignore: Incompatible parameter type [6]
            narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
            # pyre-ignore: Incompatible parameter type [6]
            narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)

            narrowed_k.copy_(k_val)
            narrowed_v.copy_(v_val)
            return self.k_cache, self.v_cache
        else:
            k_out = self.k_cache
            v_out = self.v_cache
            k_out[:, :, input_pos] = k_val
            v_out[:, :, input_pos] = v_val

            return k_out, v_out

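
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this commit): one incremental decoding step
# against the KVCache above. The sizes are arbitrary; shapes follow the
# "input_pos: [S], k_val: [B, H, S, D]" convention documented in update().
# ---------------------------------------------------------------------------
def _kv_cache_example() -> None:
    cache = KVCache(
        max_batch_size=1,
        max_seq_length=16,
        n_heads=2,
        head_dim=4,
        enable_dynamic_shape=False,
    )
    input_pos = torch.tensor([3])        # write position(s), shape [S]
    k_val = torch.randn(1, 2, 1, 4)      # [B, H, S, D] for a single new token
    v_val = torch.randn(1, 2, 1, 4)
    k_out, v_out = cache.update(input_pos, k_val, v_val)
    assert k_out.shape == (1, 2, 16, 4)  # the full cache buffers are returned
    assert torch.equal(k_out[:, :, 3], k_val[:, :, 0])
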
class SDPA(nn.Module):
    def __init__(
        self,
        dim: int,
        head_dim: int,
        n_rep: int,
        max_seq_len: int,
        enable_dynamic_shape: bool,
    ):
        super().__init__()
        self.dim = dim
        self.head_dim = head_dim
        self.n_rep = n_rep
        self.max_seq_len = max_seq_len
        self.enable_dynamic_shape = enable_dynamic_shape

    def forward(
        self,
        input_pos: torch.Tensor,
        q: torch.Tensor,  # Already has rotary embeddings. (bs, n_local_heads, seqlen, head_dim)
        k: torch.Tensor,  # Already has rotary embeddings. (bs, n_local_kv_heads, seqlen, head_dim)
        v: torch.Tensor,  # (bs, n_local_kv_heads, seqlen, head_dim)
        bsz,
        seqlen,
        mask: torch.Tensor,
    ) -> torch.Tensor:
        if self.enable_dynamic_shape:
            start_pos = input_pos[-1].item()
            torch._check_is_size(start_pos)
            torch._check(start_pos < self.max_seq_len)
            seq_length = q.size(2)
            # pyre-ignore: Incompatible parameter type [6]
            attn_mask = mask.narrow(0, start_pos, seq_length)
        else:
            attn_mask = mask[None, None, input_pos]

        # TODO(kimishpatel): This should not be necessary because
        # scaled_dot_product_attention can natively support GQA now,
        # but that path requires enable_gqa=True.
        k = k.repeat_interleave(self.n_rep, dim=1)
        v = v.repeat_interleave(self.n_rep, dim=1)
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)

        return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

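
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this commit): for a single-token decode step
# the two mask-selection paths in SDPA.forward pick the same row of the causal
# mask; .narrow() is just the export-friendly spelling of indexing by input_pos.
# ---------------------------------------------------------------------------
def _sdpa_mask_example() -> None:
    causal_mask = torch.tril(torch.ones(8, 8, dtype=torch.bool))
    input_pos = torch.tensor([5])      # current decode position
    seq_length = 1                     # one new token
    narrowed = causal_mask.narrow(0, int(input_pos[-1].item()), seq_length)
    indexed = causal_mask[None, None, input_pos]
    assert torch.equal(narrowed, indexed[0, 0])
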
class AttentionMHA(Attention):
    def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
        super().__init__()
        # Architecture configuration
        self.use_kv_cache = args.use_kv_cache
        self.n_heads = args.n_heads
        self.n_kv_heads = self.n_heads if args.n_kv_heads is None else args.n_kv_heads
        assert self.n_heads % self.n_kv_heads == 0, "Head counts must be divisible"

        # Model parallelism preparation (currently 1 for single device)
        model_parallel_size = 1
        self.n_local_heads = self.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size

        # Grouped-query attention repetition factor (queries per KV head)
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.head_dim
        self.max_batch_size = args.max_batch_size
        self.max_seq_len = args.max_seq_len
        self.dim = args.dim

        # Projection layers (combined heads)
        self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)

        # Layer-specific configuration
        self.layer_id = layer_id
        self.rope = rope  # Rotary position embedding implementation

        # Causal mask buffer (not saved in the model state dict)
        causal_mask = torch.tril(
            torch.ones(
                self.max_seq_len, self.max_seq_len, dtype=torch.bool, device="cpu"
            )
        )
        self.register_buffer("mask", causal_mask, persistent=False)

        # KV cache and cache-aware SDPA, initialized only when caching is enabled
        if self.use_kv_cache:
            self.kv_cache = KVCache(
                args.max_batch_size,
                args.max_seq_len,
                self.n_kv_heads,
                self.head_dim,
                args.enable_dynamic_shape,
            )
            self.SDPA = SDPA(  # Optimized attention implementation
                dim=self.n_local_heads * self.head_dim,
                head_dim=self.head_dim,
                n_rep=self.n_rep,
                max_seq_len=self.max_seq_len,
                enable_dynamic_shape=args.enable_dynamic_shape,
            )

    def forward(
        self,
        x: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        input_pos: Optional[torch.Tensor] = None,
        in_cache_state: Optional[Any] = None,
        out_cache_state: Optional[Any] = None,
    ) -> Tuple[torch.Tensor, Optional[Any]]:
        bsz, seqlen, _ = x.shape

        # QKV projections with view operations to split heads
        q, k, v = self.wq(x), self.wk(x), self.wv(x)
        q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim)  # Split into heads
        k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        # Rotary position embeddings (applied to both queries and keys)
        q, k = self.rope(q, k, freqs_cos, freqs_sin)

        # Transpose for attention computation: (bs, heads, seqlen, head_dim)
        q, k = q.transpose(1, 2), k.transpose(1, 2)
        v = v.transpose(1, 2)

        # KV cache path (optimized for incremental decoding)
        if self.use_kv_cache:
            assert input_pos is not None, "input_pos required for cache updates"
            k, v = self.kv_cache.update(input_pos, k, v)  # Update cache
            # Use the cache-aware SDPA implementation
            output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
            return self.wo(output), None  # No external cache state to return

        # Non-cached path (full-sequence processing)
        # Expand KV heads to match Q heads for grouped-query attention
        k = k.repeat_interleave(self.n_rep, dim=1)
        v = v.repeat_interleave(self.n_rep, dim=1)

        # Use PyTorch's optimized attention implementation
        output = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=self.mask[:seqlen, :seqlen],  # Causal mask
            dropout_p=0.0,
        )
        # Recombine heads and project to the output dimension
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output), None
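
A minimal smoke test for the new AttentionMHA class, written as a sketch under stated assumptions: the args object below is a SimpleNamespace stand-in carrying only the attributes AttentionMHA reads (a real ModelArgs would normally be used), and the rope argument is a hypothetical identity stand-in, whereas the real Rope applies rotary embeddings. It only checks output shapes on the non-cached path.

import torch
from types import SimpleNamespace

from executorch.examples.models.llama.attention import AttentionMHA

# Stand-in for ModelArgs (assumption: only the attributes read in __init__).
args = SimpleNamespace(
    use_kv_cache=False,
    n_heads=4,
    n_kv_heads=2,
    head_dim=16,
    dim=64,
    max_batch_size=1,
    max_seq_len=32,
    enable_dynamic_shape=False,
)

# Hypothetical identity stand-in for Rope; the real implementation rotates
# q and k using freqs_cos / freqs_sin.
def identity_rope(q, k, freqs_cos, freqs_sin):
    return q, k

attn = AttentionMHA(args, layer_id=0, rope=identity_rope)
x = torch.randn(1, 8, args.dim)
freqs = torch.zeros(8, args.head_dim // 2)  # placeholder; unused by the stand-in
out, cache_state = attn(x, freqs, freqs)
print(out.shape)  # torch.Size([1, 8, 64])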
