@@ -232,22 +232,16 @@ def __init__(
         max_seq_length: int,
         n_heads: int,
         head_dim: int,
-        transpose_cache: bool,
         enable_dynamic_shape: bool,
         dtype=torch.float32,
     ):
         super().__init__()
         self.max_seq_length = max_seq_length
-        self.is_transposed = transpose_cache
-        if transpose_cache:
-            cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
-        else:
-            cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
+        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
 
         self.max_batch_size = max_batch_size
         self.n_heads = n_heads
         self.head_dim = head_dim
-        self.transpose_cache = transpose_cache
         self.enable_dynamic_shape = enable_dynamic_shape
         self.register_buffer(
             "k_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
@@ -259,12 +253,12 @@ def __init__(
     def update(
         self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # input_pos: [S], k_val: [B, H, S, D] or [B, S, H, D] depending on transpose_cache
+        # input_pos: [S], k_val: [B, H, S, D]
         if self.enable_dynamic_shape:
            start_pos = input_pos[0].item()
            torch._check_is_size(start_pos)
            torch._check(start_pos < self.max_seq_length)
-            dim_to_slice = 2 if self.transpose_cache else 1
+            dim_to_slice = 2
            seq_length = k_val.size(dim_to_slice)
            # Replace the entry in the cache for this token
            # The following lines are equivalent to:
@@ -283,28 +277,22 @@ def update(
         else:
             k_out = self.k_cache
             v_out = self.v_cache
-            if self.transpose_cache:
-                k_out[:, :, input_pos] = k_val
-                v_out[:, :, input_pos] = v_val
-            else:
-                k_out[:, input_pos] = k_val
-                v_out[:, input_pos] = v_val
+            k_out[:, :, input_pos] = k_val
+            v_out[:, :, input_pos] = v_val
 
             return k_out, v_out
 
 
 class SDPA(nn.Module):
     def __init__(
         self,
-        kv_cache: KVCache,
         dim: int,
         head_dim: int,
         n_rep: int,
         max_seq_len: int,
         enable_dynamic_shape: bool,
     ):
         super().__init__()
-        self.kv_cache = kv_cache
         self.dim = dim
         self.head_dim = head_dim
         self.n_rep = n_rep
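With `transpose_cache` gone, the cache buffers always use the `(max_batch_size, n_heads, max_seq_length, head_dim)` layout and `update` always writes along dimension 2. A minimal sketch of that indexing, with made-up shapes (1 batch, 8 heads, 128 positions, 64-dim heads) standing in for the real configuration:

import torch

# Cache layout after this change: (max_batch_size, n_heads, max_seq_length, head_dim).
cache_shape = (1, 8, 128, 64)
k_cache = torch.zeros(cache_shape)
v_cache = torch.zeros(cache_shape)

# One decode step writing position 5; k_val/v_val arrive as (bs, n_heads, seqlen, head_dim).
input_pos = torch.tensor([5])
k_val = torch.randn(1, 8, 1, 64)
v_val = torch.randn(1, 8, 1, 64)

# Equivalent of the non-dynamic-shape branch of KVCache.update: index dim 2 unconditionally.
k_cache[:, :, input_pos] = k_val
v_cache[:, :, input_pos] = v_val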
@@ -314,18 +302,13 @@ def __init__(
     def forward(
         self,
         input_pos: torch.Tensor,
-        q: torch.Tensor,  # Already have rotary embeddings. (bs, seqlen, n_local_heads, head_dim)
-        k: torch.Tensor,  # Already have rotary embeddings. (bs, seqlen, n_local_kv_heads, head_dim)
-        v: torch.Tensor,  # (bs, seqlen, n_local_kv_heads, head_dim)
+        q: torch.Tensor,  # Already have rotary embeddings. (bs, n_local_heads, seqlen, head_dim)
+        k: torch.Tensor,  # Already have rotary embeddings. (bs, n_local_kv_heads, seqlen, head_dim)
+        v: torch.Tensor,  # (bs, n_local_kv_heads, seqlen, head_dim)
         bsz,
         seqlen,
         mask: torch.Tensor,
     ) -> torch.Tensor:
-        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-        k = k.transpose(1, 2)
-        v = v.transpose(1, 2)
-
-        k, v = self.kv_cache.update(input_pos, k, v)
         if self.enable_dynamic_shape:
             start_pos = input_pos[-1].item()
             torch._check_is_size(start_pos)
@@ -336,6 +319,8 @@ def forward(
         else:
             attn_mask = mask[None, None, input_pos]
 
+        # TODO(kimishpatel): This should not be necessary because scaled_dot_product_attention
+        # can natively support GQA now. But needs enable_gqa=True
         k = k.repeat_interleave(self.n_rep, dim=1)
         v = v.repeat_interleave(self.n_rep, dim=1)
         y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)
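The TODO above points at the fact that recent PyTorch releases (enable_gqa landed around 2.5) let scaled_dot_product_attention broadcast a smaller number of KV heads itself, which would make the repeat_interleave expansion unnecessary. A hedged sketch of that alternative, using illustrative shapes rather than the model's real ones:

import torch
import torch.nn.functional as F

# 8 query heads attending over 2 KV heads (n_rep = 4); shapes are illustrative only.
bsz, n_heads, n_kv_heads, seqlen, head_dim = 1, 8, 2, 16, 64
q = torch.randn(bsz, n_heads, seqlen, head_dim)
k = torch.randn(bsz, n_kv_heads, seqlen, head_dim)
v = torch.randn(bsz, n_kv_heads, seqlen, head_dim)

# With enable_gqa=True (PyTorch >= 2.5), no repeat_interleave(n_rep, dim=1) is needed.
y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)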
@@ -383,11 +368,9 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
                 args.max_seq_len,
                 self.n_kv_heads,
                 self.head_dim,
-                not args.use_sdpa_with_kv_cache_op,  # if we are using the custom op don't transpose the cache. Expect untransposed q k v
                 args.enable_dynamic_shape,
             )
             self.SDPA = SDPA(
-                kv_cache=self.kv_cache,
                 dim=self.n_local_heads * self.head_dim,
                 head_dim=self.head_dim,
                 n_rep=self.n_rep,
@@ -414,15 +397,16 @@ def forward(
         # RoPE relative positional embeddings
         q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)
 
+        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+
         if self.use_kv_cache:
             assert input_pos is not None
+            k, v = self.kv_cache.update(input_pos, k, v)
             output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
             return self.wo(output)
 
-        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-        k = k.transpose(1, 2)
-        v = v.transpose(1, 2)
-
         # grouped multiquery attention: expand out keys and values
         k = k.repeat_interleave(self.n_rep, dim=1)
         v = v.repeat_interleave(self.n_rep, dim=1)
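Putting the last hunk together: after this change the attention layer, not SDPA, transposes q/k/v to (bs, n_heads, seqlen, head_dim), writes the KV cache, and only then calls the attention kernel. A self-contained sketch of that decode-step flow, using placeholder shapes, equal query/KV head counts (so the GQA expansion is skipped), and a plain tril mask instead of the real ModelArgs/mask setup:

import torch
import torch.nn.functional as F

bsz, seqlen, n_heads, head_dim, max_seq_len = 1, 1, 8, 64, 128

# Stand-in cache buffers in the fixed (bs, n_heads, max_seq_len, head_dim) layout.
k_cache = torch.zeros(bsz, n_heads, max_seq_len, head_dim)
v_cache = torch.zeros(bsz, n_heads, max_seq_len, head_dim)

# q/k/v as produced after RoPE, still (bs, seqlen, n_heads, head_dim).
q = torch.randn(bsz, seqlen, n_heads, head_dim)
k = torch.randn(bsz, seqlen, n_heads, head_dim)
v = torch.randn(bsz, seqlen, n_heads, head_dim)

q, k, v = (t.transpose(1, 2) for t in (q, k, v))  # -> (bs, n_heads, seqlen, head_dim)

# Caller-side cache update (what Attention.forward now does before invoking SDPA).
input_pos = torch.tensor([3])
k_cache[:, :, input_pos] = k
v_cache[:, :, input_pos] = v

# Slice the causal mask the same way SDPA's non-dynamic-shape branch does.
mask = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool))
attn_mask = mask[None, None, input_pos]  # (1, 1, seqlen, max_seq_len)

y = F.scaled_dot_product_attention(q, k_cache, v_cache, attn_mask=attn_mask, dropout_p=0.0)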