
Commit 09786d8

Martin Yuan committed
Add abstract base class for attention mechanisms with unified interface
1 parent 15c772c commit 09786d8

File tree: 2 files changed (+140, −98 lines)

examples/models/llama/attention.py

Lines changed: 135 additions & 0 deletions
@@ -0,0 +1,135 @@
from abc import ABC, abstractmethod
from typing import Any, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from executorch.examples.models.llama.llama_transformer import ModelArgs

# NOTE: Rope, KVCache, and SDPA are referenced below but are not imported by this
# commit; they are assumed to be available from their existing definitions elsewhere
# in the llama example code.


class Attention(nn.Module, ABC):
    """Abstract base class for attention mechanisms with unified interface."""

    @abstractmethod
    def forward(
        self,
        x: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        input_pos: Optional[torch.Tensor] = None,
        in_cache_state: Optional[Any] = None,
        out_cache_state: Optional[Any] = None,
    ) -> Tuple[torch.Tensor, Optional[Any]]:
        """Forward pass for the attention mechanism.

        Args:
            x: Input tensor of shape (batch_size, seq_len, dim).
            freqs_cos, freqs_sin: Rotary position embedding frequencies.
            mask: Optional attention mask.
            input_pos: Positions for KV cache updates.
            in_cache_state, out_cache_state: Optional cache states.

        Returns:
            Tuple of (output tensor, updated cache state).
        """
        pass


class AttentionMHA(Attention):
    def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
        super().__init__()
        # Architecture configuration
        self.use_kv_cache = args.use_kv_cache
        self.n_heads = args.n_heads
        self.n_kv_heads = self.n_heads if args.n_kv_heads is None else args.n_kv_heads
        assert self.n_heads % self.n_kv_heads == 0, "n_heads must be divisible by n_kv_heads"

        # Model parallelism preparation (currently 1 for single device)
        model_parallel_size = 1
        self.n_local_heads = self.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size

        # Grouped/multi-query attention repetition factor
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.head_dim
        self.max_batch_size = args.max_batch_size
        self.max_seq_len = args.max_seq_len
        self.dim = args.dim

        # Projection layers (combined heads)
        self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)

        # Layer-specific configuration
        self.layer_id = layer_id
        self.rope = rope  # Rotary position embedding implementation

        # Causal mask buffer (not saved in the model state)
        causal_mask = torch.tril(
            torch.ones(self.max_seq_len, self.max_seq_len, dtype=torch.bool, device="cpu")
        )
        self.register_buffer("mask", causal_mask, persistent=False)

        # KV cache initialization if enabled
        if self.use_kv_cache:
            self.kv_cache = KVCache(
                args.max_batch_size,
                args.max_seq_len,
                self.n_kv_heads,
                self.head_dim,
                args.enable_dynamic_shape,
            )
            self.SDPA = SDPA(  # Optimized attention implementation
                dim=self.n_local_heads * self.head_dim,
                head_dim=self.head_dim,
                n_rep=self.n_rep,
                max_seq_len=self.max_seq_len,
                enable_dynamic_shape=args.enable_dynamic_shape,
            )

    def forward(
        self,
        x: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        input_pos: Optional[torch.Tensor] = None,
        in_cache_state: Optional[Any] = None,
        out_cache_state: Optional[Any] = None,
    ) -> Tuple[torch.Tensor, Optional[Any]]:
        bsz, seqlen, _ = x.shape

        # QKV projections, then view operations to split the combined heads
        q, k, v = self.wq(x), self.wk(x), self.wv(x)
        q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        # Rotary position embeddings (applied to both queries and keys)
        q, k = self.rope(q, k, freqs_cos, freqs_sin)

        # Transpose for attention computation: (bs, heads, seqlen, head_dim)
        q, k = q.transpose(1, 2), k.transpose(1, 2)
        v = v.transpose(1, 2)

        # KV cache path (optimized for incremental decoding)
        if self.use_kv_cache:
            assert input_pos is not None, "input_pos is required for cache updates"
            k, v = self.kv_cache.update(input_pos, k, v)  # Update the cache
            # Use the optimized SDPA implementation with the cache
            output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
            return self.wo(output), None  # KVCache holds its state internally; nothing to return

        # Non-cached path (full-sequence processing)
        # Expand KV heads to match Q heads for grouped multi-query attention
        k = k.repeat_interleave(self.n_rep, dim=1)
        v = v.repeat_interleave(self.n_rep, dim=1)

        # Use PyTorch's optimized attention implementation
        output = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=self.mask[:seqlen, :seqlen],  # Causal mask
            dropout_p=0.0,
        )
        # Recombine heads and project back to the model dimension
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output), None
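
To make the new contract concrete, here is a minimal sketch (not part of the commit) of a hypothetical subclass that plugs into the Attention ABC. NaiveAttention, its dim/n_heads constructor, and the choice to ignore RoPE and KV caching are illustrative assumptions only; the point is the unified signature and the (output, cache_state) return value.

# Hypothetical example only: a bare-bones Attention subclass that honors the
# unified interface. It ignores RoPE and caching purely to keep the sketch short.
from typing import Any, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from executorch.examples.models.llama.attention import Attention  # module added by this commit


class NaiveAttention(Attention):
    def __init__(self, dim: int, n_heads: int):
        super().__init__()
        assert dim % n_heads == 0
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.wqkv = nn.Linear(dim, 3 * dim, bias=False)
        self.wo = nn.Linear(dim, dim, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        input_pos: Optional[torch.Tensor] = None,
        in_cache_state: Optional[Any] = None,
        out_cache_state: Optional[Any] = None,
    ) -> Tuple[torch.Tensor, Optional[Any]]:
        bsz, seqlen, _ = x.shape
        q, k, v = self.wqkv(x).chunk(3, dim=-1)
        # Split heads: (bsz, n_heads, seqlen, head_dim)
        q = q.view(bsz, seqlen, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(bsz, seqlen, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(bsz, seqlen, self.n_heads, self.head_dim).transpose(1, 2)
        # Fall back to a causal mask when the caller does not supply one
        out = F.scaled_dot_product_attention(
            q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=mask is None
        )
        out = out.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(out), None  # no cache state to propagate

A caller written against the interface always unpacks the pair, e.g. out, cache_state = attn(x, freqs_cos, freqs_sin), regardless of which concrete implementation sits behind it.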

examples/models/llama/llama_transformer.py

Lines changed: 5 additions & 98 deletions
@@ -13,6 +13,9 @@
 
 import torch
 import torch.nn.functional as F
+from executorch.examples.models.llama.attention import (
+    AttentionMHA,
+)
 
 from executorch.examples.models.llama.rope import (
     hf_apply_rotary_emb,
@@ -328,102 +331,6 @@ def forward(
         return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
 
 
-class Attention(nn.Module):
-    def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
-        super().__init__()
-        self.use_kv_cache = args.use_kv_cache
-        self.n_heads = args.n_heads
-        self.n_kv_heads = self.n_heads if args.n_kv_heads is None else args.n_kv_heads
-        assert self.n_heads % self.n_kv_heads == 0
-        model_parallel_size = 1
-        self.n_local_heads = self.n_heads // model_parallel_size
-        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
-        self.n_rep = self.n_local_heads // self.n_local_kv_heads
-        self.head_dim = args.head_dim
-        self.max_batch_size = args.max_batch_size
-        self.max_seq_len = args.max_seq_len
-        self.dim = args.dim
-        self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False)
-        self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
-        self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
-        self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)
-
-        self.layer_id = layer_id
-
-        self.rope = rope
-
-        causal_mask = torch.tril(
-            torch.ones(
-                self.max_seq_len,
-                self.max_seq_len,
-                dtype=torch.bool,
-                device="cpu",
-            )
-        )
-        self.register_buffer("mask", causal_mask, persistent=False)
-
-        if self.use_kv_cache:
-            self.kv_cache = KVCache(
-                args.max_batch_size,
-                args.max_seq_len,
-                self.n_kv_heads,
-                self.head_dim,
-                args.enable_dynamic_shape,
-            )
-            self.SDPA = SDPA(
-                dim=self.n_local_heads * self.head_dim,
-                head_dim=self.head_dim,
-                n_rep=self.n_rep,
-                max_seq_len=self.max_seq_len,
-                enable_dynamic_shape=args.enable_dynamic_shape,
-            )
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        freqs_cos: torch.Tensor,
-        freqs_sin: torch.Tensor,
-        input_pos: Optional[torch.Tensor] = None,
-    ):
-        bsz, seqlen, _ = x.shape
-
-        # QKV
-        q, k, v = self.wq(x), self.wk(x), self.wv(x)
-        # We need view_copy elimination
-        q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim)
-        k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
-        v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
-
-        # RoPE relative positional embeddings
-        q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)
-
-        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-        k = k.transpose(1, 2)
-        v = v.transpose(1, 2)
-
-        if self.use_kv_cache:
-            assert input_pos is not None
-            k, v = self.kv_cache.update(input_pos, k, v)
-            output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
-            return self.wo(output)
-
-        # grouped multiquery attention: expand out keys and values
-        k = k.repeat_interleave(self.n_rep, dim=1)
-        v = v.repeat_interleave(self.n_rep, dim=1)
-
-        assert hasattr(self, "mask")
-
-        mask = self.mask[:seqlen, :seqlen]
-
-        output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
-
-        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
-
-        output = self.wo(output)
-
-        return output
-
-
 class FeedForward(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
@@ -490,7 +397,7 @@ def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
         self.n_heads = args.n_heads
         self.dim = args.dim
         self.head_dim = args.head_dim
-        self.attention = Attention(args, layer_id, rope)
+        self.attention = AttentionMHA(args, layer_id, rope)
         if args.moe:
             self.block_sparse_moe = MOEFeedForward(args)
         else:
@@ -500,7 +407,7 @@ def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
 
     def forward(self, x, freqs_cos, freqs_sin, input_pos=None):  # x: 1xN
         h = self.attention.forward(
-            self.attention_norm(x), freqs_cos, freqs_sin, input_pos
+            self.attention_norm(x), freqs_cos, freqs_sin, input_pos=input_pos
         )
 
         h = x + h