92 changes: 76 additions & 16 deletions examples/models/llama/attention.py
@@ -123,14 +123,12 @@ def __init__(
head_dim: int,
n_rep: int,
max_context_len: int,
enable_dynamic_shape: bool,
):
super().__init__()
self.dim = dim
self.head_dim = head_dim
self.n_rep = n_rep
self.max_context_len = max_context_len
self.enable_dynamic_shape = enable_dynamic_shape

def forward(
self,
@@ -142,21 +140,12 @@ def forward(
seqlen,
mask: torch.Tensor,
) -> torch.Tensor:
if self.enable_dynamic_shape:
start_pos = input_pos[-1].item()
torch._check_is_size(start_pos)
torch._check(start_pos < self.max_context_len)
seq_length = q.size(2)
# pyre-ignore: Incompatible parameter type [6]
attn_mask = mask.narrow(0, start_pos, seq_length)
else:
attn_mask = mask[None, None, input_pos]

# TODO(kimishpatel): This should not be necessary because scaled_dot_product_attention
# can natively support GQA now. But needs enable_gqa=True
k = k.repeat_interleave(self.n_rep, dim=1)
v = v.repeat_interleave(self.n_rep, dim=1)
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)
y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

@@ -236,21 +225,79 @@ def __init__(
enable_dynamic_shape: bool,
dtype=torch.float32,
):
self.window_size = max_context_length
"""
Why we want the kv cache size to be twice the context length:
Sliding window attention without a ring buffer
pos 0 1 2 3 4 5 6 7 8 9 10
0 x 0 0 0 0 0 0 0 0 0 0
1 x x 0 0 0 0 0 0 0 0 0
2 x x x 0 0 0 0 0 0 0 0
3 x x x x 0 0 0 0 0 0 0
4 0 x x x x 0 0 0 0 0 0
5 0 0 x x x x 0 0 0 0 0
6 0 0 0 x x x x 0 0 0 0
7 0 0 0 0 x x x x 0 0 0
8 0 0 0 0 0 x x x x 0 0
9 0 0 0 0 0 0 x x x x 0
10 0 0 0 0 0 0 0 x x x x

So when doing attention for pos = 5 and seq_len = 4, our attention
mask would be
5 0 0 x x x x 0 0 0 0 0
6 0 0 0 x x x x 0 0 0 0
7 0 0 0 0 x x x x 0 0 0
8 0 0 0 0 0 x x x x 0 0
Thus the token at pos = 5 is able to attend to tokens at pos 2, 3 and 4.
This is how training is done.

Now let's consider a ring kv cache of size 4. When we are at pos = 5,
before updating the kv cache, the state of the cache would be
[4 1 2 3]; that is, the token at pos = 0 has been evicted. During the
attention calculation at pos = 5 with seq_len = 4, we update the cache and
the positions stored in it become [8 5 6 7]. Note that pos = 5 can now only
attend to itself, not to 2, 3 and 4 as it would during training.
Not having kept 2, 3 and 4 in the cache means we get divergent behavior.
The worst case is when the update length equals the length of the cache,
as in our example with pos = 5 and seq_len = 4.
Thus we need a larger cache. How much larger? By the sliding window size,
i.e. twice the max_context_length.
How does that help? At pos = 5 our cache would hold
[0, 1, 2, 3, 4, NA, NA, NA], and after the cache update we would have
[8, 1, 2, 3, 4, 5, 6, 7]. We kicked out the token at pos = 0, but the
current step still has access to the tokens in [pos - sliding_window_size, pos].

To make sure we don't over-attend, i.e. pos = 5 does not attend to pos = 1,
the mask calculation has to account for the sliding window size.
"""
super().__init__(
max_batch_size,
max_context_length,
max_context_length * 2,
n_heads,
head_dim,
enable_dynamic_shape,
dtype,
)
self.cache_positions_manager = CachePositionsManager(max_context_length)
self.cache_positions_manager = CachePositionsManager(self.max_context_length)
self.is_ring_buffer = True

def create_causal_mask_for_ring_buffer(self, start_pos, seq_len):
pos_q = start_pos + torch.arange(seq_len, dtype=torch.long).view(-1, 1)
cache_positions = self.cache_positions_manager.cache_positions
delta = pos_q - cache_positions
attn_mask = (cache_positions >= 0) & (delta >= 0) & (delta < self.window_size)
attn_mask = torch.where(attn_mask == True, 0, float("-inf")) # noqa E712
return attn_mask
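
For context on the docstring above, here is a minimal standalone sketch (not code from this PR) that replays its example against the mask logic in create_causal_mask_for_ring_buffer: a sliding window of 4, a cache of 2 * 4 = 8 slots, and a chunk of 4 queries starting at pos 5. The modulo slot assignment is an assumption made only for this illustration; the real bookkeeping lives in CachePositionsManager.

```python
import torch

window_size = 4
cache_size = 2 * window_size
start_pos, seq_len = 5, 4

# cache_positions[i] = absolute token position currently stored in slot i
# (-1 marks a slot that has never been written).
cache_positions = torch.full((cache_size,), -1, dtype=torch.long)
for pos in range(start_pos + seq_len):       # simulate writes for positions 0..8
    cache_positions[pos % cache_size] = pos  # assumed ring-buffer placement
# cache_positions is now [8, 1, 2, 3, 4, 5, 6, 7], matching the docstring.

# Same computation as create_causal_mask_for_ring_buffer above.
pos_q = start_pos + torch.arange(seq_len, dtype=torch.long).view(-1, 1)
delta = pos_q - cache_positions
valid = (cache_positions >= 0) & (delta >= 0) & (delta < window_size)
attn_mask = torch.where(valid, 0.0, float("-inf"))
print(attn_mask)
```

The row for pos = 5 admits only the slots holding positions 2, 3, 4 and 5, i.e. the same window the token saw at training time; with a cache of only window_size slots those entries would already have been overwritten.
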

def update(
self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
# input_pos: [S], k_val: [B, H, S, D]
seq_len = k_val.size(2)
assert seq_len <= self.k_cache.size(
2
), f"Update sequence length({seq_len}) for kv cache must be smaller than the cache size({self.k_cache.size(2)})"
indices = self.cache_positions_manager.calculate_positions_and_update_indices(
input_pos, seq_len
)
@@ -286,6 +333,7 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
self.attention_qkv_bias = args.attention_qkv_bias
self.use_qk_norm = args.use_qk_norm
self.qk_norm_before_rope = args.qk_norm_before_rope
self.enable_dynamic_shape = args.enable_dynamic_shape

if self.use_qk_norm:
q_norm_dim = self.head_dim
@@ -331,7 +379,6 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
head_dim=self.head_dim,
n_rep=self.n_rep,
max_context_len=self.max_context_len,
enable_dynamic_shape=args.enable_dynamic_shape,
)

def forward(
@@ -368,8 +415,21 @@ def forward(

if self.use_kv_cache:
assert input_pos is not None
if self.enable_dynamic_shape:
start_pos = input_pos[-1].item()
torch._check_is_size(start_pos)
torch._check(start_pos < self.max_context_len)
seq_length = q.size(2)
# pyre-ignore: Incompatible parameter type [6]
attn_mask = self.mask.narrow(0, start_pos, seq_length)
else:
attn_mask = self.mask[None, None, input_pos]
k, v = self.kv_cache.update(input_pos, k, v)
output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
if getattr(self.kv_cache, "is_ring_buffer", False):
attn_mask = self.kv_cache.create_causal_mask_for_ring_buffer(
input_pos[0].item(), seqlen
)
output = self.SDPA(input_pos, q, k, v, bsz, seqlen, attn_mask)
return self.wo(output), None

# grouped multiquery attention: expand out keys and values
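
With the mask handling removed from the SDPA modules, the selection in Attention.forward above is now the only place the mask is sliced. Below is a minimal sketch (assumptions noted in comments, not code from this PR) of why the dynamic-shape branch uses narrow() rather than integer indexing: for a contiguous chunk the two select identical rows, but narrow() expresses the slice as a plain start/length, which exports more cleanly under dynamic shapes together with the torch._check calls.

```python
import torch

max_context_len = 8
# Standard causal mask: 0 where attention is allowed, -inf above the diagonal.
causal_mask = torch.triu(
    torch.full((max_context_len, max_context_len), float("-inf")), diagonal=1
)

input_pos = torch.tensor([3])  # assumed: holds the chunk's starting position
seq_length = 2                 # assumed: number of query tokens in the chunk
start_pos = input_pos[-1].item()

narrowed = causal_mask.narrow(0, start_pos, seq_length)
indexed = causal_mask[torch.arange(start_pos, start_pos + seq_length)]
assert torch.equal(narrowed, indexed)
```
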
72 changes: 39 additions & 33 deletions examples/models/llama/source_transformation/sdpa.py
@@ -22,15 +22,11 @@ class SDPACustom(torch.nn.Module):
def __init__(
self,
dim: int,
max_context_len,
enable_dynamic_shape,
use_attention_mask: bool = False,
):
super().__init__()
self.dim = dim
self.max_context_len = max_context_len
self.use_attention_mask = use_attention_mask
self.enable_dynamic_shape = enable_dynamic_shape

def forward(
self,
@@ -42,16 +38,6 @@ def forward(
seqlen,
mask,
):
if self.use_attention_mask:
if self.enable_dynamic_shape:
start_pos = input_pos[-1].item()
torch._check_is_size(start_pos)
torch._check(start_pos < self.max_context_len)
seq_length = q.size(2)
mask = mask.narrow(0, start_pos, seq_length)
else:
mask = mask[input_pos]

q = q.transpose(1, 2) # (bs, seqlen, n_local_heads, head_dim)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
@@ -96,8 +82,6 @@ def _replace_sdpa_with_custom_op(
name,
SDPACustom(
child.dim,
child.max_context_len,
child.enable_dynamic_shape,
use_attention_mask=use_attention_mask,
),
)
@@ -133,12 +117,15 @@ class QuantizedSDPA(torch.nn.Module):
zero points, we need to pass kv_cache to SDPA.
"""

def __init__(self, dim: int, kv_cache: QuantizedKVCache):
def __init__(
self, dim: int, kv_cache: QuantizedKVCache, use_attention_mask: bool = False
):
super().__init__()
self.dim = dim
self.quantized_dtype = torch.int8
self.float_dtype = torch.float32
self.kv_cache = kv_cache
self.use_attention_mask = use_attention_mask

def forward(
self,
@@ -176,22 +163,40 @@ def forward(
v_scale_fp32 = self.kv_cache.v_cache_scales

start_pos = input_pos[0].item()
output = torch.ops.llama.custom_quantized_sdpa(
q_quantized,
k_quantized,
v_quantized,
start_pos,
None,
0,
True,
None,
q_zero_point_int8,
q_scale_fp32,
k_zero_point_int8,
k_scale_fp32,
v_zero_point_int8,
v_scale_fp32,
)
if self.use_attention_mask:
output = torch.ops.llama.custom_quantized_sdpa(
q_quantized,
k_quantized,
v_quantized,
start_pos,
mask,
0,
False,
None,
q_zero_point_int8,
q_scale_fp32,
k_zero_point_int8,
k_scale_fp32,
v_zero_point_int8,
v_scale_fp32,
)
else:
output = torch.ops.llama.custom_quantized_sdpa(
q_quantized,
k_quantized,
v_quantized,
start_pos,
None,
0,
True,
None,
q_zero_point_int8,
q_scale_fp32,
k_zero_point_int8,
k_scale_fp32,
v_zero_point_int8,
v_scale_fp32,
)

return output.view(bsz, seqlen, self.dim)
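
The two branches above differ only in two arguments to the custom op: the fifth (the attention mask vs None) and the seventh (False vs True, which from context appears to be a causal flag). A possible consolidation is sketched below, with the argument meanings inferred from this diff alone and not checked against the op's schema; it reuses the names already in scope in QuantizedSDPA.forward, so it is illustrative rather than standalone.

```python
# Sketch only; all names come from QuantizedSDPA.forward above.
attn_mask = mask if self.use_attention_mask else None
is_causal = not self.use_attention_mask  # inferred meaning of the 7th argument
output = torch.ops.llama.custom_quantized_sdpa(
    q_quantized,
    k_quantized,
    v_quantized,
    start_pos,
    attn_mask,
    0,
    is_causal,
    None,
    q_zero_point_int8,
    q_scale_fp32,
    k_zero_point_int8,
    k_scale_fp32,
    v_zero_point_int8,
    v_scale_fp32,
)
```
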

@@ -201,6 +206,7 @@ def _update_attention_module_with_quantized_sdpa(
):
sdpa = getattr(module, "SDPA", None)
assert sdpa is not None
# TODO: add support for SDPA with attention mask
# pyre-ignore
setattr(module, "SDPA", QuantizedSDPA(sdpa.dim, kv_cache)) # noqa: B010

@@ -31,7 +31,7 @@ def __init__(
self.dim = dim
self.head_dim = head_dim
self.n_rep = n_rep
self.SDPA = SDPA(dim, head_dim, n_rep, max_context_len, enable_dynamic_shape)
self.SDPA = SDPA(dim, head_dim, n_rep, max_context_len)
self.kv_cache = None

def forward(self, x, freqs_cos, freqs_sin, **kwargs):
@@ -159,15 +159,9 @@ def test_forward_functionality(self):
k_quantized, v_quantized = model.attention.kv_cache.update(input_pos, k, v)

# Run the forward pass with the quantized SDPA
try:
output = model.attention.SDPA(
input_pos, q, k_quantized, v_quantized, bsz, seqlen, None
)
output = model.attention.SDPA(
input_pos, q, k_quantized, v_quantized, bsz, seqlen, None
)

# Verify the output shape
self.assertEqual(output.shape, (bsz, seqlen, self.dim))
except Exception:
# If the forward pass fails, it might be due to missing custom ops
self.skipTest(
"Custom ops not available, skipping forward functionality test"
)
# Verify the output shape
self.assertEqual(output.shape, (bsz, seqlen, self.dim))
@@ -71,8 +71,8 @@ def test_simple(self, is_dynamic_shape=False):
self.seq_len = 3
self._init_cache()
q, k_val, v_val = self._init_kv()
self.float_sdpa = SDPACustom(self.dim, self.max_context_len, True)
self.quantized_sdpa = SDPACustom(self.dim, self.max_context_len, True)
self.float_sdpa = SDPACustom(self.dim)
self.quantized_sdpa = SDPACustom(self.dim)
k, v = self.custom_kv_cache.update(input_pos, k_val, v_val)
float_out = self.float_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
k, v = self.quantized_kv_cache.update(input_pos, k_val, v_val)
11 changes: 11 additions & 0 deletions examples/models/llama/tests/TARGETS
@@ -49,3 +49,14 @@ python_unittest(
"//executorch/examples/models/llama:llama_transformer",
],
)

python_unittest(
name = "test_ring_attention",
srcs = [
"test_ring_attention.py",
],
deps = [
"//caffe2:torch",
"//executorch/examples/models/llama:llama_transformer",
],
)