@@ -123,14 +123,12 @@ def __init__(
         head_dim: int,
         n_rep: int,
         max_context_len: int,
-        enable_dynamic_shape: bool,
     ):
         super().__init__()
         self.dim = dim
         self.head_dim = head_dim
         self.n_rep = n_rep
         self.max_context_len = max_context_len
-        self.enable_dynamic_shape = enable_dynamic_shape

     def forward(
         self,
@@ -142,21 +140,12 @@ def forward(
         seqlen,
         mask: torch.Tensor,
     ) -> torch.Tensor:
-        if self.enable_dynamic_shape:
-            start_pos = input_pos[-1].item()
-            torch._check_is_size(start_pos)
-            torch._check(start_pos < self.max_context_len)
-            seq_length = q.size(2)
-            # pyre-ignore: Incompatible parameter type [6]
-            attn_mask = mask.narrow(0, start_pos, seq_length)
-        else:
-            attn_mask = mask[None, None, input_pos]

         # TODO(kimishpatel): This should not be necessary because scaled_dot_product_attention
         # can natively support GQA now. But needs enable_gqa=True
         k = k.repeat_interleave(self.n_rep, dim=1)
         v = v.repeat_interleave(self.n_rep, dim=1)
-        y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)
+        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

         return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

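The TODO above notes that the `repeat_interleave` calls become unnecessary once SDPA's native GQA support is used. As a minimal sketch (not part of this PR), assuming a PyTorch build where `F.scaled_dot_product_attention` accepts `enable_gqa` (added in 2.5), the two paths give matching results:

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: 8 query heads, 2 KV heads, so n_rep = 4.
bsz, n_heads, n_kv_heads, n_rep, seqlen, head_dim = 1, 8, 2, 4, 16, 64
q = torch.randn(bsz, n_heads, seqlen, head_dim)
k = torch.randn(bsz, n_kv_heads, seqlen, head_dim)
v = torch.randn(bsz, n_kv_heads, seqlen, head_dim)

# Additive causal mask: 0.0 where attention is allowed, -inf elsewhere.
causal = torch.triu(torch.ones(seqlen, seqlen, dtype=torch.bool), diagonal=1)
mask = torch.zeros(seqlen, seqlen).masked_fill(causal, float("-inf"))

# Current approach in this file: expand KV heads to match the query heads.
y_repeat = F.scaled_dot_product_attention(
    q,
    k.repeat_interleave(n_rep, dim=1),
    v.repeat_interleave(n_rep, dim=1),
    attn_mask=mask,
    dropout_p=0.0,
)

# Native GQA path the TODO refers to: no repeat_interleave needed.
y_gqa = F.scaled_dot_product_attention(
    q, k, v, attn_mask=mask, dropout_p=0.0, enable_gqa=True
)

print(torch.allclose(y_repeat, y_gqa, atol=1e-5))  # expected: True
```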
@@ -236,21 +225,79 @@ def __init__(
         enable_dynamic_shape: bool,
         dtype=torch.float32,
     ):
+        self.window_size = max_context_length
+        """
+        Why we want the kv cache size to be twice the context length:
+
+        Sliding window attention without a ring buffer
+        pos   0  1  2  3  4  5  6  7  8  9  10
+        0     x  0  0  0  0  0  0  0  0  0  0
+        1     x  x  0  0  0  0  0  0  0  0  0
+        2     x  x  x  0  0  0  0  0  0  0  0
+        3     x  x  x  x  0  0  0  0  0  0  0
+        4     0  x  x  x  x  0  0  0  0  0  0
+        5     0  0  x  x  x  x  0  0  0  0  0
+        6     0  0  0  x  x  x  x  0  0  0  0
+        7     0  0  0  0  x  x  x  x  0  0  0
+        8     0  0  0  0  0  x  x  x  x  0  0
+        9     0  0  0  0  0  0  x  x  x  x  0
+        10    0  0  0  0  0  0  0  x  x  x  x
+
+        When doing attention for pos = 5 and seq_len = 4, the attention
+        mask would be
+        5     0  0  x  x  x  x  0  0  0  0  0
+        6     0  0  0  x  x  x  x  0  0  0  0
+        7     0  0  0  0  x  x  x  x  0  0  0
+        8     0  0  0  0  0  x  x  x  x  0  0
+        so the token at pos = 5 can attend to the tokens at pos 2, 3 and 4.
+        This is how attention is computed during training.
+
+        Now consider a ring kv cache of size 4. At pos = 5, before updating
+        the kv cache, its state is [4, 1, 2, 3]; the token at pos = 0 has
+        been evicted. During the attention calculation at pos = 5 with
+        seq_len = 4, the cache is updated and the positions stored in it
+        become [8, 5, 6, 7]. Token 5 can now attend only to itself, not to
+        2, 3 and 4 as it would during training, so evicting 2, 3 and 4 gives
+        divergent behavior. The worst case is when the update length equals
+        the cache length, as here (pos = 5, seq_len = 4).
+        Thus we need a larger cache, larger by exactly the sliding window
+        size, i.e. twice max_context_length.
+        To see why that helps: at pos = 5 the cache holds
+        [0, 1, 2, 3, 4, NA, NA, NA], and after the update it holds
+        [8, 1, 2, 3, 4, 5, 6, 7]. Only the token at pos = 0 is kicked out,
+        and the current step still has access to the tokens in
+        [pos - sliding_window_size, pos].
+
+        To make sure we do not over-attend, i.e. that pos = 5 does not attend
+        to pos = 1, the mask calculation has to account for the sliding
+        window size. (A standalone sketch of that mask computation follows
+        this hunk.)
+        """
         super().__init__(
             max_batch_size,
-            max_context_length,
+            max_context_length * 2,
             n_heads,
             head_dim,
             enable_dynamic_shape,
             dtype,
         )
-        self.cache_positions_manager = CachePositionsManager(max_context_length)
+        self.cache_positions_manager = CachePositionsManager(self.max_context_length)
+        self.is_ring_buffer = True
+
+    def create_causal_mask_for_ring_buffer(self, start_pos, seq_len):
+        pos_q = start_pos + torch.arange(seq_len, dtype=torch.long).view(-1, 1)
+        cache_positions = self.cache_positions_manager.cache_positions
+        delta = pos_q - cache_positions
+        attn_mask = (cache_positions >= 0) & (delta >= 0) & (delta < self.window_size)
+        attn_mask = torch.where(attn_mask == True, 0, float("-inf"))  # noqa E712
+        return attn_mask

     def update(
         self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # input_pos: [S], k_val: [B, H, S, D]
         seq_len = k_val.size(2)
+        assert seq_len <= self.k_cache.size(
+            2
+        ), f"Update sequence length({seq_len}) exceeds the kv cache size({self.k_cache.size(2)})"
         indices = self.cache_positions_manager.calculate_positions_and_update_indices(
             input_pos, seq_len
         )
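To tie `create_causal_mask_for_ring_buffer` back to the pos = 5, seq_len = 4 walkthrough in the docstring, here is a standalone sketch of the same mask logic. It assumes `cache_positions` records the original token position stored in each of the 2 * window_size slots, with -1 for empty slots (this layout of `CachePositionsManager` is an assumption for illustration):

```python
import torch

def ring_buffer_mask(
    start_pos: int, seq_len: int, cache_positions: torch.Tensor, window_size: int
) -> torch.Tensor:
    # Query positions for this update: start_pos, ..., start_pos + seq_len - 1.
    pos_q = start_pos + torch.arange(seq_len, dtype=torch.long).view(-1, 1)
    delta = pos_q - cache_positions
    # A query may attend to a slot only if the slot is filled (>= 0), the cached
    # token is not in the future (delta >= 0), and it lies inside the sliding
    # window (delta < window_size).
    allowed = (cache_positions >= 0) & (delta >= 0) & (delta < window_size)
    return torch.where(allowed, 0.0, float("-inf"))

# After the update at pos = 5 with seq_len = 4, the 2 * window_size = 8 slots
# hold token positions [8, 1, 2, 3, 4, 5, 6, 7], as in the docstring example.
cache_positions = torch.tensor([8, 1, 2, 3, 4, 5, 6, 7])
mask = ring_buffer_mask(start_pos=5, seq_len=4, cache_positions=cache_positions, window_size=4)
print(mask[0])  # row for pos = 5: only the slots holding pos 2, 3, 4, 5 are unmasked
```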
@@ -286,6 +333,7 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
         self.attention_qkv_bias = args.attention_qkv_bias
         self.use_qk_norm = args.use_qk_norm
         self.qk_norm_before_rope = args.qk_norm_before_rope
+        self.enable_dynamic_shape = args.enable_dynamic_shape

         if self.use_qk_norm:
             q_norm_dim = self.head_dim
@@ -331,7 +379,6 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
                 head_dim=self.head_dim,
                 n_rep=self.n_rep,
                 max_context_len=self.max_context_len,
-                enable_dynamic_shape=args.enable_dynamic_shape,
             )

     def forward(
@@ -368,8 +415,22 @@ def forward(

         if self.use_kv_cache:
             assert input_pos is not None
+            if self.enable_dynamic_shape:
+                start_pos = input_pos[-1].item()
+                torch._check_is_size(start_pos)
+                torch._check(start_pos < self.max_context_len)
+                seq_length = q.size(2)
+                # pyre-ignore: Incompatible parameter type [6]
+                attn_mask = self.mask.narrow(0, start_pos, seq_length)
+            else:
+                # mask is always 2D
+                attn_mask = self.mask[input_pos]
             k, v = self.kv_cache.update(input_pos, k, v)
-            output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
+            if getattr(self.kv_cache, "is_ring_buffer", False):
+                attn_mask = self.kv_cache.create_causal_mask_for_ring_buffer(
+                    input_pos[0].item(), seqlen
+                )
+            output = self.SDPA(input_pos, q, k, v, bsz, seqlen, attn_mask)
             return self.wo(output), None

         # grouped multiquery attention: expand out keys and values
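The mask handling that moved into `Attention.forward` builds the same rows of the causal mask in two ways: the dynamic-shape branch slices them out of `self.mask` with `narrow` starting from a scalar position (which exports cleanly), while the eager branch indexes the 2D mask with `input_pos` directly. Below is a small sketch of that equivalence, assuming a standard 2D additive causal mask and a contiguous chunk of positions; the mask construction is illustrative, not copied from this file. Note the diff's dynamic-shape branch reads `input_pos[-1].item()`, which matches the scalar start used here only when `input_pos` carries a single start position, as in the export path.

```python
import torch

max_context_len = 8
# Assumed layout of self.mask: [max_context_len, max_context_len] additive causal
# mask, 0.0 where attention is allowed and -inf where it is not.
mask = torch.zeros(max_context_len, max_context_len)
mask = mask.masked_fill(
    torch.triu(torch.ones_like(mask, dtype=torch.bool), diagonal=1), float("-inf")
)

# A 3-token chunk starting at position 2.
input_pos = torch.tensor([2, 3, 4])  # one entry per token, as the eager branch expects
start_pos = 2                        # scalar chunk start, as the dynamic-shape branch uses
seq_length = input_pos.numel()

rows_by_indexing = mask[input_pos]                      # attn_mask = self.mask[input_pos]
rows_by_narrow = mask.narrow(0, start_pos, seq_length)  # attn_mask = self.mask.narrow(0, start_pos, seq_length)

# Both select the same seq_length rows of the causal mask.
print(torch.equal(rows_by_indexing, rows_by_narrow))  # expected: True
```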