Commit 079bab3

Custom SDPA in attention
1 parent 8145cda commit 079bab3

File tree: 1 file changed, +21 -1 lines

extension/llm/modules/attention.py (21 additions, 1 deletion)
@@ -10,6 +10,7 @@
 import torch
 import torchtune.modules.attention as TorchTuneAttention
 from executorch.extension.llm.modules.kv_cache import KVCache as InferenceKVCache
+from executorch.extension.llm.custom_ops import custom_ops
 from torch import nn
 from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention
 from torchtune.modules.kv_cache import KVCache
@@ -146,6 +147,7 @@ def __init__(
         # Use flex attention if supported and we are sample packing
         self._attention_call = _sdpa_or_flex_attention()
         self._sdpa = SDPA(
+            max_seq_len=self.max_seq_len,
             num_kv_heads=self.num_kv_heads,
             num_heads=self.num_heads,
             head_dim=self.head_dim,
@@ -310,7 +312,7 @@ def false_fn(y):
             self.kv_cache.v_cache.copy_(v)
             self.kv_cache.cache_pos.copy_(cache_pos)
 
-        output = self._sdpa(q, k, v, b, s_x, mask=mask)
+        output = self._sdpa(q, k, v, b, s_x, mask=mask, input_pos=input_pos)
         return self.output_proj(output)
 
 
@@ -322,6 +324,7 @@ class SDPA(nn.Module):
 
     def __init__(
         self,
+        max_seq_len: int,
         num_kv_heads: int,
         num_heads: int,
         head_dim: int,
@@ -331,6 +334,7 @@ def __init__(
         kv_cache,
     ) -> None:
         super().__init__()
+        self.max_seq_len = max_seq_len
         self.num_kv_heads = num_kv_heads
         self.num_heads = num_heads
         self.head_dim = head_dim
@@ -348,7 +352,23 @@ def forward(
         bsz: int,
         seq_len: int,
         mask: Optional[_MaskType] = None,
+        # Below args are only used for ET custom sdpa op.
+        input_pos: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        start_pos = input_pos[0][-1].item() - seq_len + 1
+        torch._check_is_size(start_pos)
+        torch._check(start_pos <= self.max_seq_len)
+        output = torch.ops.llama.custom_sdpa(
+            q,
+            k,
+            v,
+            start_pos,
+            None,  # Attention mask
+            0,  # dropout probability. Ignored by the code
+            True,  # is_causal TODO: flip to false if kv cache is enabled???
+        )
+        return output.view(bsz, seq_len, -1)
+
         # View + expand + reshape bring num_kv_heads to num_heads for k and v
         # to match q.
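For context on what the new input_pos argument feeds: below is a minimal sketch, outside the diff, of how start_pos is derived before it reaches torch.ops.llama.custom_sdpa. The shapes and the 2048 bound are illustrative assumptions rather than values from this commit; the sketch only assumes input_pos holds the absolute token positions for the current call, laid out as [bsz, seq_len].

    import torch

    # Hypothetical decode step: one sequence, 4 new tokens at absolute
    # positions 16..19 (shape [1, seq_len], chosen for illustration).
    seq_len = 4
    input_pos = torch.arange(16, 16 + seq_len).unsqueeze(0)

    # Mirrors the diff: last position of the first batch element, minus the
    # number of tokens in this call, plus one, gives the first new position.
    start_pos = input_pos[0][-1].item() - seq_len + 1
    assert start_pos == 16

    # As in the diff, these hints tell the exporter that start_pos is a
    # non-negative size bounded by max_seq_len, so the custom op can be
    # traced with a dynamic start position.
    torch._check_is_size(start_pos)
    torch._check(start_pos <= 2048)  # 2048 stands in for self.max_seq_len

Presumably start_pos is what lets custom_sdpa apply causal masking relative to the entries already in the KV cache, which would also explain why the mask argument is passed as None in the diff.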
