 import torch
 import torchtune.modules.attention as TorchTuneAttention
 from executorch.extension.llm.modules.kv_cache import KVCache as InferenceKVCache
+from executorch.extension.llm.custom_ops import custom_ops
 from torch import nn
 from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention
 from torchtune.modules.kv_cache import KVCache
@@ -146,6 +147,7 @@ def __init__(
         # Use flex attention if supported and we are sample packing
         self._attention_call = _sdpa_or_flex_attention()
         self._sdpa = SDPA(
+            max_seq_len=self.max_seq_len,
             num_kv_heads=self.num_kv_heads,
             num_heads=self.num_heads,
             head_dim=self.head_dim,
@@ -310,7 +312,7 @@ def false_fn(y):
             self.kv_cache.v_cache.copy_(v)
             self.kv_cache.cache_pos.copy_(cache_pos)
 
-        output = self._sdpa(q, k, v, b, s_x, mask=mask)
+        output = self._sdpa(q, k, v, b, s_x, mask=mask, input_pos=input_pos)
         return self.output_proj(output)
 
 
@@ -322,6 +324,7 @@ class SDPA(nn.Module):
 
     def __init__(
         self,
+        max_seq_len: int,
         num_kv_heads: int,
         num_heads: int,
         head_dim: int,
@@ -331,6 +334,7 @@ def __init__(
         kv_cache,
     ) -> None:
         super().__init__()
+        self.max_seq_len = max_seq_len
         self.num_kv_heads = num_kv_heads
         self.num_heads = num_heads
         self.head_dim = head_dim
@@ -348,7 +352,23 @@ def forward(
         bsz: int,
         seq_len: int,
         mask: Optional[_MaskType] = None,
+        # The args below are only used by the ET custom SDPA op.
+        input_pos: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        start_pos = input_pos[0][-1].item() - seq_len + 1
+        torch._check_is_size(start_pos)
+        torch._check(start_pos <= self.max_seq_len)
+        output = torch.ops.llama.custom_sdpa(
+            q,
+            k,
+            v,
+            start_pos,
+            None,  # Attention mask
+            0,  # Dropout probability; ignored by the custom op.
+            True,  # is_causal. TODO: flip to False if kv cache is enabled???
+        )
+        return output.view(bsz, seq_len, -1)
+
         # View + expand + reshape bring num_kv_heads to num_heads for k and v
         # to match q.
 
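
For reference, a minimal sketch of how the new input_pos argument maps to the start_pos handed to torch.ops.llama.custom_sdpa. The batch size, cache length, and chunk length below are illustrative assumptions, not values from this diff, and the snippet only exercises the index arithmetic rather than calling the custom op itself.

import torch

# Assumed decode step: the KV cache already holds 16 tokens and this forward
# pass feeds a chunk of 4 new tokens, so input_pos has shape [bsz, seq_len].
seq_len = 4
input_pos = torch.arange(16, 16 + seq_len).unsqueeze(0)  # positions 16..19

# Same derivation as in SDPA.forward above: take the last cache position in
# the chunk and step back (seq_len - 1) slots to recover its first position.
start_pos = input_pos[0][-1].item() - seq_len + 1
assert start_pos == 16  # custom_sdpa reads/updates the cache from this offset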