# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

+from typing import Tuple
+
from torch import nn, Tensor


@@ -17,22 +19,32 @@ class SDPA(nn.Module):

    def __init__(
        self,
+        num_kv_heads: int,
+        num_heads: int,
+        head_dim: int,
+        q_per_kv: int,
+        attn_dropout: float,
+        is_causal: bool,
        attention_fn,
        kv_cache,
-        q_per_kv,
    ) -> None:
        super().__init__()
-        self._attention_fn = attention_fn
-        self.kv_cache = kv_cache
+        self.num_kv_heads = num_kv_heads
+        self.num_heads = num_heads
+        self.head_dim = head_dim
        self.q_per_kv = q_per_kv
+        self.attn_dropout = attn_dropout
+        self.is_causal = is_causal
+        self._attention_fn = attention_fn
+        self._kv_cache = kv_cache

    def kv_cache_update(
        self,
        input_pos: Tensor,
        k: Tensor,
        v: Tensor,
    ) -> Tuple[Tensor, Tensor]:
-        k, v = self.kv_cache.update(input_pos, k, v)
+        k, v = self._kv_cache.update(input_pos, k, v)
        return k, v

    def sdpa(
@@ -72,7 +84,7 @@ def sdpa(
            v,
            mask=mask,
            dropout_p=self.attn_dropout,
-            is_causal=self.kv_cache is None and mask is None and self.is_causal,
+            is_causal=self._kv_cache is None and mask is None and self.is_causal,
        )
        # Reshape the output to be the same shape as the input
-        return output.transpose(1, 2).contiguous().view(b, s_x, -1)
+        return output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
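For reference, a minimal usage sketch of the refactored constructor and the interfaces it expects. The DummyKVCache class, the attention_fn wrapper, and the concrete head counts below are hypothetical stand-ins, not part of this change; they only illustrate the new argument list, assuming SDPA is importable from the module shown above.

import torch.nn.functional as F

class DummyKVCache:
    """Hypothetical cache exposing the update() method called by kv_cache_update()."""
    def update(self, input_pos, k, v):
        # A real cache would write k and v at input_pos and return the full cached tensors.
        return k, v

def attention_fn(q, k, v, mask=None, dropout_p=0.0, is_causal=False):
    # Thin wrapper so the mask= keyword used in sdpa() maps onto attn_mask=.
    return F.scaled_dot_product_attention(
        q, k, v, attn_mask=mask, dropout_p=dropout_p, is_causal=is_causal
    )

sdpa_module = SDPA(
    num_kv_heads=8,
    num_heads=32,
    head_dim=128,
    q_per_kv=32 // 8,  # query heads per key/value head
    attn_dropout=0.0,
    is_causal=True,
    attention_fn=attention_fn,
    kv_cache=DummyKVCache(),
)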