 from torch import nn
 from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention
 from torchtune.modules.kv_cache import KVCache
+from torchtune.modules.sdpa import SDPA
 
 logger = logging.getLogger(__name__)
 
@@ -126,6 +127,8 @@ def __init__(
         self.head_dim = head_dim
         self.max_seq_len = max_seq_len
         self.is_causal = is_causal
+        # Number of queries per k, v
+        self.q_per_kv = self.num_heads // self.num_kv_heads
 
         # Set layers
         self.kv_cache = kv_cache
@@ -139,6 +142,11 @@ def __init__(
 
         # Use flex attention if supported and we are sample packing
         self._attention_call = _sdpa_or_flex_attention()
+        self._sdpa = SDPA(
+            attention_fn=self._attention_call,
+            kv_cache=self.kv_cache,
+            q_per_kv=self.q_per_kv,
+        )
 
     def setup_cache(
         self, batch_size: int, dtype: torch.dtype, max_seq_len: int
@@ -227,18 +235,12 @@ def forward(
 
         # q has shape [b, s_x, num_heads * head_dim]
         q = self.q_proj(x)
-
-        # number of queries per key/value
-        q_per_kv = self.num_heads // self.num_kv_heads
-        q = q.view(b, s_x, self.num_kv_heads * q_per_kv, self.head_dim)
+        q = q.view(b, s_x, self.num_kv_heads * self.q_per_kv, self.head_dim)
 
         # Apply positional embeddings
         if self.pos_embeddings is not None:
             q = self.pos_embeddings(q, input_pos=input_pos)
 
-        # [b, n_h, s_x, h_d]
-        q = q.transpose(1, 2)
-
         # Normalize q
         if self.q_norm is not None:
             q = self.q_norm(q)
@@ -261,48 +263,17 @@ def forward(
         # Apply positional embeddings
         # k: [b, s_y, n_kv, h_d]
         k = k.view(b, s_y, -1, self.head_dim)
+        v = v.view(b, s_y, -1, self.head_dim)
         if self.pos_embeddings is not None:
             k = self.pos_embeddings(k, input_pos=input_pos)
 
-        # View + expand + reshape bring num_kv_heads to num_heads for k and v
-        # to match q.
-
-        # k: [b, s_y, n_kv, 1, h_d]
-        # v: [b, s_y, n_kv, 1, h_d]
-        k = k.view(b, s_y, self.num_kv_heads, 1, self.head_dim)
-        v = v.view(b, s_y, self.num_kv_heads, 1, self.head_dim)
-
-        # If needed, expand the key and value tensors to have the same shape
-        # as the query tensor by copying values across the relevant dim
-        if self.num_heads != self.num_kv_heads:
-            k = k.expand(b, s_y, self.num_kv_heads, q_per_kv, self.head_dim)
-            v = v.expand(b, s_y, self.num_kv_heads, q_per_kv, self.head_dim)
-
-        # [b, s, n_h, h_d]
-        k = k.reshape(b, s_y, -1, self.head_dim)
-        v = v.reshape(b, s_y, -1, self.head_dim)
-
-        # [b, n_h, s, h_d]
-        k = k.transpose(1, 2)
-        v = v.transpose(1, 2)
-
         # Normalize k
         if self.k_norm is not None:
             k = self.k_norm(k)
 
         # Update key-value cache
         if self.kv_cache is not None:
-            k, v = self.kv_cache.update(k, v)
-
-        output = self._attention_call(
-            q,
-            k,
-            v,
-            mask=mask,
-            dropout_p=self.attn_dropout,
-            is_causal=self.kv_cache is None and mask is None and self.is_causal,
-        )
+            k, v = self._sdpa.kv_cache_update(input_pos, k, v)
 
-        # reshape the output to be the same shape as the input
-        output = output.transpose(1, 2).contiguous().view(b, s_x, -1)
+        output = self._sdpa.sdpa(q, k, v, b, s_x)
         return self.output_proj(output)
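
For reference, here is a minimal sketch of what the new `torchtune.modules.sdpa.SDPA` module could look like, reconstructed from the logic this commit removes from `forward()`. Only the constructor arguments (`attention_fn`, `kv_cache`, `q_per_kv`) and the method names (`kv_cache_update`, `sdpa`) appear in the diff above; the method bodies, the extra `attn_dropout`, `is_causal`, and `mask` parameters, and the use of `repeat_interleave` in place of the original `expand` + `reshape` are assumptions, not the actual implementation.

```python
import torch
from torch import nn


class SDPA(nn.Module):
    """Owns the KV-cache update and the scaled dot-product attention call.

    Hypothetical sketch: only the names match the diff; the bodies mirror
    the logic that used to live in MultiHeadAttention.forward().
    """

    def __init__(
        self,
        attention_fn,  # e.g. the callable returned by _sdpa_or_flex_attention()
        kv_cache,  # KVCache instance or None
        q_per_kv: int,  # num_heads // num_kv_heads
        attn_dropout: float = 0.0,  # assumed extra args, not shown in the diff
        is_causal: bool = True,
    ) -> None:
        super().__init__()
        self._attention_fn = attention_fn
        self.kv_cache = kv_cache
        self.q_per_kv = q_per_kv
        self.attn_dropout = attn_dropout
        self.is_causal = is_causal

    def kv_cache_update(self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        # input_pos is accepted for interface parity; the cache in the removed
        # code tracks positions itself. Assumed to return the full cached k, v
        # in the same [b, s, n_kv, h_d] layout it receives.
        k, v = self.kv_cache.update(k, v)
        return k, v

    def sdpa(
        self,
        q: torch.Tensor,  # [b, s_x, n_h, h_d]
        k: torch.Tensor,  # [b, s_y, n_kv, h_d]
        v: torch.Tensor,  # [b, s_y, n_kv, h_d]
        bsz: int,
        seq_len: int,
        mask=None,
    ) -> torch.Tensor:
        # [b, s, n_heads, h_d] -> [b, n_heads, s, h_d]
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        # Grouped-query attention: repeat each kv head q_per_kv times so k and v
        # match q's head count (stand-in for the expand + reshape deleted above).
        if self.q_per_kv > 1:
            k = k.repeat_interleave(self.q_per_kv, dim=1)
            v = v.repeat_interleave(self.q_per_kv, dim=1)
        output = self._attention_fn(
            q,
            k,
            v,
            mask=mask,
            dropout_p=self.attn_dropout,
            is_causal=self.kv_cache is None and mask is None and self.is_causal,
        )
        # Back to [b, s_x, n_h * h_d] so output_proj sees the usual layout.
        return output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
```

Factoring this out leaves `forward()` with just the projections, positional embeddings, and norms, and gives downstream code a single module boundary to override when a different attention kernel or cache layout is needed.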