@@ -123,14 +123,12 @@ def __init__(
         head_dim: int,
         n_rep: int,
         max_context_len: int,
-        enable_dynamic_shape: bool,
     ):
         super().__init__()
         self.dim = dim
         self.head_dim = head_dim
         self.n_rep = n_rep
         self.max_context_len = max_context_len
-        self.enable_dynamic_shape = enable_dynamic_shape

     def forward(
         self,
@@ -142,21 +140,12 @@ def forward(
         seqlen,
         mask: torch.Tensor,
     ) -> torch.Tensor:
-        if self.enable_dynamic_shape:
-            start_pos = input_pos[-1].item()
-            torch._check_is_size(start_pos)
-            torch._check(start_pos < self.max_context_len)
-            seq_length = q.size(2)
-            # pyre-ignore: Incompatible parameter type [6]
-            attn_mask = mask.narrow(0, start_pos, seq_length)
-        else:
-            attn_mask = mask[None, None, input_pos]

         # TODO(kimishpatel): This should not be necessary because scaled_dot_product_attention
         # can natively support GQA now. But needs enable_gqa=True
         k = k.repeat_interleave(self.n_rep, dim=1)
         v = v.repeat_interleave(self.n_rep, dim=1)
-        y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)
+        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

         return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

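After this change the SDPA wrapper no longer builds its own mask: the caller passes a ready-made additive mask, and the wrapper only expands K/V for grouped-query attention before calling `F.scaled_dot_product_attention`. Below is a minimal standalone sketch of that path; the shapes, `n_rep`, and the toy causal mask are illustrative assumptions, not values taken from this diff.

```python
import torch
import torch.nn.functional as F

# Assumed toy sizes: 8 query heads sharing 2 kv heads over a 4-token chunk.
bsz, n_heads, n_kv_heads, seqlen, head_dim = 1, 8, 2, 4, 16
n_rep = n_heads // n_kv_heads  # each kv head serves n_rep query heads

q = torch.randn(bsz, n_heads, seqlen, head_dim)
k = torch.randn(bsz, n_kv_heads, seqlen, head_dim)
v = torch.randn(bsz, n_kv_heads, seqlen, head_dim)

# Additive mask supplied by the caller: 0 where attention is allowed, -inf elsewhere.
mask = torch.triu(torch.full((seqlen, seqlen), float("-inf")), diagonal=1)

# GQA handled by expanding K/V to match the number of query heads.
k = k.repeat_interleave(n_rep, dim=1)  # [B, n_heads, S, D]
v = v.repeat_interleave(n_rep, dim=1)
y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
print(y.shape)  # torch.Size([1, 8, 4, 16])
```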
@@ -236,21 +225,79 @@ def __init__(
         enable_dynamic_shape: bool,
         dtype=torch.float32,
     ):
+        self.window_size = max_context_length
+        """
+        Why we want the kv cache size to be twice the context length:
+        Sliding window attention without a ring buffer
+        pos 0 1 2 3 4 5 6 7 8 9 10
+        0   x 0 0 0 0 0 0 0 0 0 0
+        1   x x 0 0 0 0 0 0 0 0 0
+        2   x x x 0 0 0 0 0 0 0 0
+        3   x x x x 0 0 0 0 0 0 0
+        4   0 x x x x 0 0 0 0 0 0
+        5   0 0 x x x x 0 0 0 0 0
+        6   0 0 0 x x x x 0 0 0 0
+        7   0 0 0 0 x x x x 0 0 0
+        8   0 0 0 0 0 x x x x 0 0
+        9   0 0 0 0 0 0 x x x x 0
+        10  0 0 0 0 0 0 0 x x x x
+
+        So when doing attention for pos = 5 and seq_len = 4, our attention
+        mask would be
+        5   0 0 x x x x 0 0 0 0 0
+        6   0 0 0 x x x x 0 0 0 0
+        7   0 0 0 0 x x x x 0 0 0
+        8   0 0 0 0 0 x x x x 0 0
+        Thus the token at pos = 5 is able to attend to tokens at pos 2, 3 and 4.
+        This is how training is done.
+
+        Now let's consider a ring kv cache of size 4. When we are at pos = 5,
+        before updating the kv cache its state is [4 1 2 3]; that is, the token
+        at pos = 0 has been evicted. During the attention calculation at pos = 5
+        with seq_len = 4 we update the cache, and the positions it holds become
+        [8 5 6 7]. Note that pos = 5 can now only attend to itself, not to
+        2, 3 and 4 as it would during training. Not having kept 2, 3 and 4 in
+        the cache means we get divergent behavior. The worst case is when the
+        update is as long as the cache itself, as in our case with pos = 5 and
+        seq_len = 4.
+        Thus we need a larger cache. How much larger? By the sliding window
+        size, i.e. twice max_context_length.
+        How does that help? Let's see. At pos = 5 the cache holds
+        [0, 1, 2, 3, 4, NA, NA, NA]; after the cache update it holds
+        [8, 1, 2, 3, 4, 5, 6, 7]. The token at pos = 0 is kicked out, but the
+        current step still has access to the tokens in
+        [pos - sliding_window_size, pos].
+
+        To make sure we don't over-attend, i.e. pos = 5 does not attend to
+        pos = 1, the mask calculation has to account for the sliding window
+        size.
+        """
         super().__init__(
             max_batch_size,
-            max_context_length,
+            max_context_length * 2,
             n_heads,
             head_dim,
             enable_dynamic_shape,
             dtype,
         )
-        self.cache_positions_manager = CachePositionsManager(max_context_length)
+        self.cache_positions_manager = CachePositionsManager(self.max_context_length)
+        self.is_ring_buffer = True
+
+    def create_causal_mask_for_ring_buffer(self, start_pos, seq_len):
+        pos_q = start_pos + torch.arange(seq_len, dtype=torch.long).view(-1, 1)
+        cache_positions = self.cache_positions_manager.cache_positions
+        delta = pos_q - cache_positions
+        attn_mask = (cache_positions >= 0) & (delta >= 0) & (delta < self.window_size)
+        attn_mask = torch.where(attn_mask == True, 0, float("-inf"))  # noqa E712
+        return attn_mask

     def update(
         self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # input_pos: [S], k_val: [B, H, S, D]
         seq_len = k_val.size(2)
+        assert seq_len <= self.k_cache.size(
+            2
+        ), f"Update sequence length ({seq_len}) for kv cache must not exceed the cache size ({self.k_cache.size(2)})"
         indices = self.cache_positions_manager.calculate_positions_and_update_indices(
             input_pos, seq_len
         )
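The ring-buffer mask construction can be checked in isolation. Here is a hedged, self-contained re-creation of the `create_causal_mask_for_ring_buffer` logic using the docstring's example (sliding window of 4, cache of size 8, four tokens starting at pos = 5). The `cache_positions` tensor is hand-built to stand in for what `CachePositionsManager` would hold after the update, so treat it as an assumption rather than the manager's actual output.

```python
import torch

window_size, cache_size = 4, 8
start_pos, seq_len = 5, 4

# Assumed cache state after writing positions 5..8 into a size-8 ring:
# slot 0 now holds pos 8, slots 1..7 still hold pos 1..7.
cache_positions = torch.tensor([8, 1, 2, 3, 4, 5, 6, 7], dtype=torch.long)

# Same arithmetic as the method above: one query row per new position.
pos_q = start_pos + torch.arange(seq_len, dtype=torch.long).view(-1, 1)  # [[5],[6],[7],[8]]
delta = pos_q - cache_positions
allowed = (cache_positions >= 0) & (delta >= 0) & (delta < window_size)
attn_mask = torch.where(allowed, 0.0, float("-inf"))

# The row for pos = 5 allows the slots holding pos 2, 3, 4 and 5 (the
# training-time sliding window) and masks out pos 1 and all future tokens.
print(attn_mask)
```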
@@ -286,6 +333,7 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
         self.attention_qkv_bias = args.attention_qkv_bias
         self.use_qk_norm = args.use_qk_norm
         self.qk_norm_before_rope = args.qk_norm_before_rope
+        self.enable_dynamic_shape = args.enable_dynamic_shape

         if self.use_qk_norm:
             q_norm_dim = self.head_dim
@@ -331,7 +379,6 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
             head_dim=self.head_dim,
             n_rep=self.n_rep,
             max_context_len=self.max_context_len,
-            enable_dynamic_shape=args.enable_dynamic_shape,
         )

     def forward(
337384 def forward (
@@ -368,8 +415,22 @@ def forward(
368415
369416 if self .use_kv_cache :
370417 assert input_pos is not None
418+ if self .enable_dynamic_shape :
419+ start_pos = input_pos [- 1 ].item ()
420+ torch ._check_is_size (start_pos )
421+ torch ._check (start_pos < self .max_context_len )
422+ seq_length = q .size (2 )
423+ # pyre-ignore: Incompatible parameter type [6]
424+ attn_mask = self .mask .narrow (0 , start_pos , seq_length )
425+ else :
426+ # mask is always 2D
427+ attn_mask = self .mask [input_pos ]
371428 k , v = self .kv_cache .update (input_pos , k , v )
372- output = self .SDPA (input_pos , q , k , v , bsz , seqlen , self .mask )
429+ if getattr (self .kv_cache , "is_ring_buffer" , False ):
430+ attn_mask = self .kv_cache .create_causal_mask_for_ring_buffer (
431+ input_pos [0 ].item (), seqlen
432+ )
433+ output = self .SDPA (input_pos , q , k , v , bsz , seqlen , attn_mask )
373434 return self .wo (output ), None
374435
375436 # grouped multiquery attention: expand out keys and values
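For reference, a small sketch of the two non-ring mask paths that now live in `Attention.forward`, assuming `self.mask` is a 2D `[max_context_len, max_context_len]` causal mask. The `input_pos` conventions below (a single start position in the dynamic-shape export path, the full list of positions otherwise) are assumptions about how the model is driven, not something this diff establishes.

```python
import torch

max_context_len = 8
# 2D additive causal mask: 0 on/below the diagonal, -inf above it.
mask = torch.triu(
    torch.full((max_context_len, max_context_len), float("-inf")), diagonal=1
)

start_pos, seq_length = 3, 2

# Dynamic-shape path: input_pos carries only the start position; narrow()
# keeps the slice friendly to symbolic shapes during export.
input_pos_dynamic = torch.tensor([start_pos], dtype=torch.long)
attn_mask_dynamic = mask.narrow(0, input_pos_dynamic[-1].item(), seq_length)

# Static path: input_pos enumerates every position in the current chunk.
input_pos_static = torch.arange(start_pos, start_pos + seq_length)
attn_mask_static = mask[input_pos_static]

# Both select rows 3 and 4 of the causal mask.
assert torch.equal(attn_mask_dynamic, attn_mask_static)
```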