 from torch import nn
 from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention
 from torchtune.modules.kv_cache import KVCache
+from torchtune.modules.sdpa import SDPA

 logger = logging.getLogger(__name__)

@@ -126,6 +127,8 @@ def __init__(
         self.head_dim = head_dim
         self.max_seq_len = max_seq_len
         self.is_causal = is_causal
+        # Number of queries per k, v
+        self.q_per_kv = self.num_heads // self.num_kv_heads

         # Set layers
         self.kv_cache = kv_cache
@@ -139,6 +142,11 @@ def __init__(

         # Use flex attention if supported and we are sample packing
         self._attention_call = _sdpa_or_flex_attention()
+        self._sdpa = SDPA(
+            attention_fn=self._attention_call,
+            kv_cache=self.kv_cache,
+            q_per_kv=self.q_per_kv,
+        )

     def setup_cache(self, batch_size: int, dtype: torch.dtype) -> None:
         """Setup key value caches for attention calculation. If called
@@ -228,18 +236,12 @@ def forward(

         # q has shape [b, s_x, num_heads * head_dim]
         q = self.q_proj(x)
-
-        # number of queries per key/value
-        q_per_kv = self.num_heads // self.num_kv_heads
-        q = q.view(b, s_x, self.num_kv_heads * q_per_kv, self.head_dim)
+        q = q.view(b, s_x, self.num_kv_heads * self.q_per_kv, self.head_dim)

         # Apply positional embeddings
         if self.pos_embeddings is not None:
             q = self.pos_embeddings(q, input_pos=input_pos)

-        # [b, n_h, s_x, h_d]
-        q = q.transpose(1, 2)
-
         # Normalize q
         if self.q_norm is not None:
             q = self.q_norm(q)
@@ -262,48 +264,17 @@ def forward(
         # Apply positional embeddings
         # k: [b, s_y, n_kv, h_d]
         k = k.view(b, s_y, self.num_kv_heads, self.head_dim)
+        v = v.view(b, s_y, self.num_kv_heads, self.head_dim)
         if self.pos_embeddings is not None:
             k = self.pos_embeddings(k, input_pos=input_pos)

-        # View + expand + reshape bring num_kv_heads to num_heads for k and v
-        # to match q.
-
-        # k: [b, s_y, n_kv, 1, h_d]
-        # v: [b, s_y, n_kv, 1, h_d]
-        k = k.view(b, s_y, self.num_kv_heads, 1, self.head_dim)
-        v = v.view(b, s_y, self.num_kv_heads, 1, self.head_dim)
-
-        # Expand the key and value tensors to have the same shape
-        # as the query tensor by copying values across the relevant dim
-        if self.num_heads != self.num_kv_heads:
-            k = k.expand(b, s_y, self.num_kv_heads, q_per_kv, self.head_dim)
-            v = v.expand(b, s_y, self.num_kv_heads, q_per_kv, self.head_dim)
-
-        # [b, s, n_h, h_d]
-        k = k.reshape(b, s_y, -1, self.head_dim)
-        v = v.reshape(b, s_y, -1, self.head_dim)
-
-        # [b, n_h, s, h_d]
-        k = k.transpose(1, 2)
-        v = v.transpose(1, 2)
-
         # Normalize k
         if self.k_norm is not None:
             k = self.k_norm(k)

         # Update key-value cache
         if self.kv_cache is not None:
-            k, v = self.kv_cache.update(input_pos, k, v)
-
-        output = self._attention_call(
-            q,
-            k,
-            v,
-            mask=mask,
-            dropout_p=self.attn_dropout,
-            is_causal=self.kv_cache is None and mask is None and self.is_causal,
-        )
+            k, v = self._sdpa.kv_cache_update(input_pos, k, v)

-        # reshape the output to be the same shape as the input
-        output = output.transpose(1, 2).contiguous().view(b, s_x, -1)
+        output = self._sdpa.sdpa(q, k, v, b, s_x)
         return self.output_proj(output)
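For context, the diff constructs `SDPA(attention_fn=..., kv_cache=..., q_per_kv=...)` in `__init__` and calls `kv_cache_update` and `sdpa` on it in `forward`, but the module itself is not shown. Below is a minimal sketch of what such a wrapper could look like, assuming it simply absorbs the kv-cache update, GQA head expansion, and attention call that were removed from `forward` above; only the constructor arguments and method names are taken from the diff, everything else (shapes, mask/dropout handling) is an assumption, not the actual `torchtune.modules.sdpa` implementation.

```python
# Hypothetical sketch of the SDPA wrapper imported from torchtune.modules.sdpa.
# Not the actual implementation: shapes, defaults, and mask/dropout plumbing
# are assumed for illustration.
from typing import Callable, Optional

import torch
from torch import nn


class SDPA(nn.Module):
    """Encapsulates kv-cache update, GQA head expansion, and the attention call."""

    def __init__(
        self,
        attention_fn: Callable,
        kv_cache: Optional[nn.Module],
        q_per_kv: int,
    ) -> None:
        super().__init__()
        self._attention_fn = attention_fn
        self.kv_cache = kv_cache
        self.q_per_kv = q_per_kv

    def kv_cache_update(
        self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Delegate to the underlying cache and return the cached k/v,
        # mirroring the old `self.kv_cache.update(input_pos, k, v)` call.
        # The real module may transpose k/v to match the cache layout first.
        return self.kv_cache.update(input_pos, k, v)

    def sdpa(
        self,
        q: torch.Tensor,  # [b, s_x, n_kv * q_per_kv, h_d]
        k: torch.Tensor,  # [b, s_y, n_kv, h_d]
        v: torch.Tensor,  # [b, s_y, n_kv, h_d]
        bsz: int,
        seq_len: int,
    ) -> torch.Tensor:
        b, s_y, n_kv, h_d = k.shape
        if self.q_per_kv > 1:
            # Expand k/v across q_per_kv so every query head has a matching
            # kv head, as the removed inline code in forward() used to do.
            k = k.unsqueeze(3).expand(b, s_y, n_kv, self.q_per_kv, h_d).reshape(b, s_y, -1, h_d)
            v = v.unsqueeze(3).expand(b, s_y, n_kv, self.q_per_kv, h_d).reshape(b, s_y, -1, h_d)
        # Move to [b, n_h, s, h_d], the layout expected by the attention kernel.
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        output = self._attention_fn(
            q,
            k,
            v,
            mask=None,        # mask handling omitted in this sketch
            dropout_p=0.0,    # dropout plumbing omitted in this sketch
            is_causal=self.kv_cache is None,
        )
        # Back to [b, s_x, n_h * h_d] to match the layer's output projection.
        return output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
```

Whatever the real module looks like, the design intent visible in the diff is the same: factoring these steps into one component keeps the attention layer's `forward` free of head-layout bookkeeping and makes the scaled dot-product attention path a single swappable unit.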