 import torch
 import torchtune.modules.attention as TorchTuneAttention
 from executorch.extension.llm.modules.kv_cache import KVCache as InferenceKVCache
+from executorch.extension.llm.custom_ops import custom_ops
 from torch import nn
 from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention
 from torchtune.modules.kv_cache import KVCache
@@ -146,6 +147,7 @@ def __init__(
         # Use flex attention if supported and we are sample packing
         self._attention_call = _sdpa_or_flex_attention()
         self._sdpa = SDPA(
+            max_seq_len=self.max_seq_len,
             num_kv_heads=self.num_kv_heads,
             num_heads=self.num_heads,
             head_dim=self.head_dim,
@@ -310,7 +312,7 @@ def false_fn(y):
             self.kv_cache.v_cache.copy_(v)
             self.kv_cache.cache_pos.copy_(cache_pos)
 
-        output = self._sdpa(q, k, v, b, s_x, mask=mask)
+        output = self._sdpa(q, k, v, b, s_x, mask=mask, input_pos=input_pos)
         return self.output_proj(output)
 
 
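Note on the new `input_pos` plumbing in the hunk above: the attention module now threads the token positions through to SDPA so the delegate can compute where the current chunk starts in the cache. Below is a minimal sketch, not part of the commit, assuming torchtune's convention of `input_pos` holding absolute positions with shape `[batch x seq_len]`; the helper name `start_pos_from` is made up for illustration and mirrors the arithmetic added to `SDPA.forward` further down.

    import torch

    def start_pos_from(input_pos: torch.Tensor, seq_len: int) -> int:
        # Index of the first token in the current chunk, same formula as the
        # new SDPA.forward below.
        return input_pos[0][-1].item() - seq_len + 1

    prefill_pos = torch.arange(4).unsqueeze(0)      # [[0, 1, 2, 3]]: 4-token prefill
    decode_pos = torch.tensor([[4]])                # next single-token decode step

    print(start_pos_from(prefill_pos, seq_len=4))   # 0
    print(start_pos_from(decode_pos, seq_len=1))    # 4
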
@@ -322,6 +324,7 @@ class SDPA(nn.Module):
 
     def __init__(
         self,
+        max_seq_len: int,
         num_kv_heads: int,
         num_heads: int,
         head_dim: int,
@@ -331,6 +334,7 @@ def __init__(
         kv_cache,
     ) -> None:
         super().__init__()
+        self.max_seq_len = max_seq_len
         self.num_kv_heads = num_kv_heads
         self.num_heads = num_heads
         self.head_dim = head_dim
@@ -348,32 +352,48 @@ def forward(
         bsz: int,
         seq_len: int,
         mask: Optional[_MaskType] = None,
+        # Below args are only used for ET custom sdpa op.
+        input_pos: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        # View + expand + reshape bring num_kv_heads to num_heads for k and v
-        # to match q.
-
-        # [bsz, n_h, s, h_d]
-        q = q.transpose(1, 2)
-        k = k.transpose(1, 2)
-        v = v.transpose(1, 2)
-
-        # Expand the key and value tensors to have the same shape
-        # as the query tensor by copying values across the relevant dim
-        if self.num_heads != self.num_kv_heads:
-            expand_shape = (-1, -1, self.q_per_kv, -1, -1)
-            k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)
-            v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2)
-
-        output = self._attention_fn(
+        start_pos = input_pos[0][-1].item() - seq_len + 1
+        torch._check_is_size(start_pos)
+        torch._check(start_pos <= self.max_seq_len)
+        output = torch.ops.llama.custom_sdpa(
             q,
             k,
             v,
-            mask=mask,
-            dropout_p=self.attn_dropout,
-            is_causal=self.kv_cache is None and mask is None and self.is_causal,
+            start_pos,
+            None,  # Attention mask
+            0,  # dropout probability. Ignored by the code
+            True,  # is_causal TODO: flip to false if kv cache is enabled???
         )
-        # Reshape the output to be the same shape as the input
-        return output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
+        return output.view(bsz, seq_len, -1)
+
+        # # View + expand + reshape bring num_kv_heads to num_heads for k and v
+        # # to match q.
+
+        # # [bsz, n_h, s, h_d]
+        # q = q.transpose(1, 2)
+        # k = k.transpose(1, 2)
+        # v = v.transpose(1, 2)
+
+        # # Expand the key and value tensors to have the same shape
+        # # as the query tensor by copying values across the relevant dim
+        # if self.num_heads != self.num_kv_heads:
+        #     expand_shape = (-1, -1, self.q_per_kv, -1, -1)
+        #     k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)
+        #     v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2)
+
+        # output = self._attention_fn(
+        #     q,
+        #     k,
+        #     v,
+        #     mask=mask,
+        #     dropout_p=self.attn_dropout,
+        #     is_causal=self.kv_cache is None and mask is None and self.is_causal,
+        # )
+        # # Reshape the output to be the same shape as the input
+        # return output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
 
 
 def _replace_mha_with_inference_mha(module: torch.nn.Module) -> None:
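On the `torch._check_is_size` / `torch._check` pair added in the last hunk: `input_pos[0][-1].item()` becomes an unbacked symbolic integer under `torch.export`, and the two checks assert that it is a valid, bounded size so tracing does not fail on a data-dependent guard. The standalone sketch below shows only that general pattern, assuming a recent PyTorch; the `TakePrefix` module and `max_len` bound are hypothetical and unrelated to the ExecuTorch custom op.

    import torch
    from torch.export import export

    class TakePrefix(torch.nn.Module):
        """Toy module: keep the first `n` rows, where n arrives as a tensor."""

        def __init__(self, max_len: int):
            super().__init__()
            self.max_len = max_len

        def forward(self, x: torch.Tensor, n_tensor: torch.Tensor) -> torch.Tensor:
            n = n_tensor.item()              # unbacked SymInt when exporting
            torch._check_is_size(n)          # assert n is a non-negative size
            torch._check(n <= self.max_len)  # give the compiler an upper bound
            return x.narrow(0, 0, n)

    ep = export(TakePrefix(max_len=8), (torch.zeros(8, 4), torch.tensor(3)))
    print(ep)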