Commit 1e3d978

Custom SDPA in attention
1 parent 8145cda commit 1e3d978

File tree

1 file changed: +42 additions, -22 deletions

extension/llm/modules/attention.py

Lines changed: 42 additions & 22 deletions
@@ -10,6 +10,7 @@
 import torch
 import torchtune.modules.attention as TorchTuneAttention
 from executorch.extension.llm.modules.kv_cache import KVCache as InferenceKVCache
+from executorch.extension.llm.custom_ops import custom_ops
 from torch import nn
 from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention
 from torchtune.modules.kv_cache import KVCache
@@ -146,6 +147,7 @@ def __init__(
         # Use flex attention if supported and we are sample packing
         self._attention_call = _sdpa_or_flex_attention()
         self._sdpa = SDPA(
+            max_seq_len=self.max_seq_len,
             num_kv_heads=self.num_kv_heads,
             num_heads=self.num_heads,
             head_dim=self.head_dim,
@@ -310,7 +312,7 @@ def false_fn(y):
             self.kv_cache.v_cache.copy_(v)
             self.kv_cache.cache_pos.copy_(cache_pos)

-        output = self._sdpa(q, k, v, b, s_x, mask=mask)
+        output = self._sdpa(q, k, v, b, s_x, mask=mask, input_pos=input_pos)
         return self.output_proj(output)

@@ -322,6 +324,7 @@ class SDPA(nn.Module):

     def __init__(
         self,
+        max_seq_len: int,
         num_kv_heads: int,
         num_heads: int,
         head_dim: int,
@@ -331,6 +334,7 @@ def __init__(
         kv_cache,
     ) -> None:
         super().__init__()
+        self.max_seq_len = max_seq_len
         self.num_kv_heads = num_kv_heads
         self.num_heads = num_heads
         self.head_dim = head_dim
@@ -348,32 +352,48 @@ def forward(
         bsz: int,
         seq_len: int,
         mask: Optional[_MaskType] = None,
+        # The args below are only used by the ET custom SDPA op.
+        input_pos: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        # View + expand + reshape bring num_kv_heads to num_heads for k and v
-        # to match q.
-
-        # [bsz, n_h, s, h_d]
-        q = q.transpose(1, 2)
-        k = k.transpose(1, 2)
-        v = v.transpose(1, 2)
-
-        # Expand the key and value tensors to have the same shape
-        # as the query tensor by copying values across the relevant dim
-        if self.num_heads != self.num_kv_heads:
-            expand_shape = (-1, -1, self.q_per_kv, -1, -1)
-            k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)
-            v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2)
-
-        output = self._attention_fn(
+        start_pos = input_pos[0][-1].item() - seq_len + 1
+        torch._check_is_size(start_pos)
+        torch._check(start_pos <= self.max_seq_len)
+        output = torch.ops.llama.custom_sdpa(
             q,
             k,
             v,
-            mask=mask,
-            dropout_p=self.attn_dropout,
-            is_causal=self.kv_cache is None and mask is None and self.is_causal,
+            start_pos,
+            None,  # attention mask
+            0,  # dropout probability; ignored by the op
+            True,  # is_causal. TODO: flip to False if the KV cache is enabled?
         )
-        # Reshape the output to be the same shape as the input
-        return output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
+        return output.view(bsz, seq_len, -1)
+
+        # # View + expand + reshape bring num_kv_heads to num_heads for k and v
+        # # to match q.
+
+        # # [bsz, n_h, s, h_d]
+        # q = q.transpose(1, 2)
+        # k = k.transpose(1, 2)
+        # v = v.transpose(1, 2)
+
+        # # Expand the key and value tensors to have the same shape
+        # # as the query tensor by copying values across the relevant dim
+        # if self.num_heads != self.num_kv_heads:
+        #     expand_shape = (-1, -1, self.q_per_kv, -1, -1)
+        #     k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)
+        #     v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2)
+
+        # output = self._attention_fn(
+        #     q,
+        #     k,
+        #     v,
+        #     mask=mask,
+        #     dropout_p=self.attn_dropout,
+        #     is_causal=self.kv_cache is None and mask is None and self.is_causal,
+        # )
+        # # Reshape the output to be the same shape as the input
+        # return output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)


 def _replace_mha_with_inference_mha(module: torch.nn.Module) -> None:
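For illustration, here is a minimal sketch of the start_pos arithmetic the new SDPA.forward relies on, assuming input_pos holds the KV-cache positions of the tokens processed in the current step (the tensor values and max_seq_len below are hypothetical, not taken from the commit):

import torch

# Hypothetical decode step: 4 new tokens landing in cache slots 4..7.
input_pos = torch.tensor([[4, 5, 6, 7]])
seq_len = 4
max_seq_len = 128  # stand-in for self.max_seq_len

# Last cache position minus (seq_len - 1) gives the first slot written this step.
start_pos = input_pos[0][-1].item() - seq_len + 1  # 7 - 4 + 1 = 4
torch._check_is_size(start_pos)         # export-time guard: start_pos >= 0
torch._check(start_pos <= max_seq_len)  # export-time guard: stays within the cache

start_pos is then passed to the custom SDPA op in place of the mask/dropout/is_causal combination used by the stock torchtune attention path.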
