@@ -64,6 +64,25 @@ def causal_mask_function(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int)
    return kv_idx <= q_idx


+def sliding_window_overlay(sliding_window: int) -> Callable:
+    """
+    This is an overlay depicting a sliding window pattern. Add it on top of a causal mask for a proper sliding
+    window mask.
+    """
+
+    def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
+        return kv_idx > q_idx - sliding_window
+
+    return inner_mask
+
+
+def sliding_window_causal_mask_function(sliding_window: int) -> Callable:
+    """
+    This returns the mask function used to create a sliding window mask.
+    """
+    return and_masks(sliding_window_overlay(sliding_window), causal_mask_function)
+
+
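To make the composition concrete, here is a minimal sketch in plain Python of what the combined predicate does (batch/head indices dropped for brevity); it assumes, as the `and_masks` call above suggests, that `and_masks` simply intersects the two mask functions:

def _causal(q_idx, kv_idx):
    return kv_idx <= q_idx

def _overlay(sliding_window):
    return lambda q_idx, kv_idx: kv_idx > q_idx - sliding_window

def _sliding_window_causal(sliding_window):
    window_fn = _overlay(sliding_window)
    # Intersection of the two predicates, mirroring and_masks(overlay, causal)
    return lambda q_idx, kv_idx: _causal(q_idx, kv_idx) and window_fn(q_idx, kv_idx)

mask_fn = _sliding_window_causal(3)
# Query position 4 may attend to key positions 2, 3 and 4 only (causal + window of 3)
print([kv for kv in range(6) if mask_fn(4, kv)])  # [2, 3, 4]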
def _vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable:
    """
    Used to vmap our mask_functions over the q_idx and kv_idx dimensions of the inputs. Optionally, vmap over
@@ -280,12 +299,65 @@ def eager_mask(
    return mask


+def flash_attention_mask(
+    batch_size: int,
+    cache_position: ms.Tensor,
+    kv_length: int,
+    kv_offset: int = 0,
+    mask_function: Callable = causal_mask_function,
+    attention_mask: Optional[ms.Tensor] = None,
+    **kwargs,
+):
311+ """
312+ Create the attention mask necesary to use FA2. Since FA2 is un-padded by definition, here we simply return
313+ `None` if the mask is fully causal, or we return the 2D mask which will then be used to extract the seq_lens.
314+ We just slice it in case of sliding window.
315+
316+ Args:
317+ batch_size (`int`):
318+ The batch size of the input sequence.
319+ cache_position (`ms.Tensor`):
320+ A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
321+ kv_length (`int`):
322+ The size that the key and value states will have during the attention computation.
323+ kv_offset (`int`, optional):
324+ An optional offset to indicate at which first position the key and values states will refer to.
325+ mask_function (`Callable`):
326+ The mask factory function describing the mask pattern.
327+ attention_mask (`ms.Tensor`, optional):
328+ The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
329+ """
+    if attention_mask is not None:
+        # Here we need to slice from the right if using sliding or chunked (for full attention, this is equivalent to doing nothing)
+        attention_mask = attention_mask[:, -kv_length:]
+        # We only return an actual mask if there is at least 1 padding token, otherwise we return `None` and use `is_causal` in FA2
+        # (note that the attention_mask is a boolean dtype here)
+        if attention_mask.all():
+            attention_mask = None
+
+    return attention_mask
+
+
+def flex_attention_mask(
+    batch_size: int,
+    cache_position: ms.Tensor,
+    kv_length: int,
+    kv_offset: int = 0,
+    mask_function: Callable = causal_mask_function,
+    attention_mask: Optional[ms.Tensor] = None,
+    **kwargs,
+):
+    raise NotImplementedError("`flex_attention` is not supported yet.")
+
+
class AttentionMaskInterface(GeneralInterface):
    # Class instance object, so that a call to `register` can be reflected into all other files correctly, even if
    # a new instance is created (in order to locally override a given function)
    _global_mapping = {
+        "sdpa": sdpa_mask,
        "eager": eager_mask,
-        "flash_attention_2": eager_mask,
+        "flash_attention_2": flash_attention_mask,
+        "flex_attention": flex_attention_mask,
    }


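As a hedged usage sketch (not part of the diff): the registry above is looked up by `config._attn_implementation`, and the FA2 entry only produces a real mask when padding is present. `ALL_MASK_ATTENTION_FUNCTIONS` is assumed to be the `AttentionMaskInterface` instance used later in this file, and the tensors are purely illustrative:

import mindspore as ms

mask_builder = ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"]  # -> flash_attention_mask

padded = ms.Tensor([[True, True, True, False]])    # last token is padding
unpadded = ms.Tensor([[True, True, True, True]])   # no padding at all

# With padding: the right-sliced 2D boolean mask comes back so FA2 can derive seq_lens from it
print(mask_builder(batch_size=1, cache_position=ms.ops.arange(4), kv_length=4, attention_mask=padded))

# Without padding: `None` comes back and FA2 relies on its own `is_causal` handling instead
print(mask_builder(batch_size=1, cache_position=ms.ops.arange(4), kv_length=4, attention_mask=unpadded))  # None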
@@ -308,13 +380,13 @@ def _preprocess_mask_arguments(
    Args:
        config (`PretrainedConfig`):
            The model config.
-        input_embeds (`torch.Tensor`):
+        input_embeds (`ms.Tensor`):
            The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the
            batch size, query length and dtype.
-        attention_mask (`torch.Tensor`, optional):
+        attention_mask (`ms.Tensor`, optional):
            The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length).
            It can also be an already prepared 4D mask, in which case it is returned as-is.
-        cache_position (`torch.Tensor`):
+        cache_position (`ms.Tensor`):
            A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
        past_key_values (`Cache`, optional):
            The past key values, if we use a cache.
@@ -325,7 +397,7 @@ def _preprocess_mask_arguments(
    Returns:
        early_exit (`bool`):
            Whether we should early exit mask creation, and return the mask as-is.
-        attention_mask (`torch.Tensor` or `BlockMask` or `None`):
+        attention_mask (`ms.Tensor` or `BlockMask` or `None`):
            The attention mask to either return immediately, or to use in downstream mask creation.
        kv_length (`int`):
            The size that the key and value states will have during the attention computation.
@@ -375,13 +447,13 @@ def create_causal_mask(
    Args:
        config (`PretrainedConfig`):
            The model config.
-        input_embeds (`torch.Tensor`):
+        input_embeds (`ms.Tensor`):
            The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the
            batch size, query length and dtype.
-        attention_mask (`torch.Tensor`, optional):
+        attention_mask (`ms.Tensor`, optional):
            The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length).
            It can also be an already prepared 4D mask, in which case it is returned as-is.
-        cache_position (`torch.Tensor`):
+        cache_position (`ms.Tensor`):
            A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
        past_key_values (`Cache`, optional):
            The past key values, if we use a cache.
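For orientation, a hedged sketch of how a modeling file might call this builder; `model_config`, `inputs_embeds`, `attention_mask_2d`, `cache_position` and `past_key_values` are illustrative placeholders rather than names defined in this diff:

# Builds the mask once per forward pass, using whichever backend the config selects
causal_mask = create_causal_mask(
    config=model_config,
    input_embeds=inputs_embeds,
    attention_mask=attention_mask_2d,
    cache_position=cache_position,
    past_key_values=past_key_values,
)
# Depending on config._attn_implementation this is a 4D mask (sdpa/eager), a 2D mask or
# None (flash_attention_2); the flex_attention entry currently raises NotImplementedError.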
@@ -435,6 +507,86 @@ def create_causal_mask(
    return causal_mask


+def create_sliding_window_causal_mask(
+    config: PretrainedConfig,
+    input_embeds: ms.Tensor,
+    attention_mask: Optional[ms.Tensor],
+    cache_position: ms.Tensor,
+    past_key_values: Optional[Cache],
+    or_mask_function: Optional[Callable] = None,
+    and_mask_function: Optional[Callable] = None,
+) -> Optional[Union[ms.Tensor, BlockMask]]:
519+ """
520+ Create a sliding window causal mask based on the attention implementation used (stored in the config). This type
521+ of attention pattern was mostly democratized by Mistral. If `past_key_values` has an HybridCache structure, this
522+ function will return the mask corresponding to one of the "sliding_attention" layers (to align to what is needed in the
523+ `modeling_xxx.py` files).
524+
525+ Args:
526+ config (`PretrainedConfig`):
527+ The model config.
528+ input_embeds (`ms.Tensor`):
529+ The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the
530+ batch size, query length and dtype.
531+ attention_mask (`ms.Tensor`, optional):
532+ The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length).
533+ It can also be an already prepared 4D mask, in which case it is returned as-is.
534+ cache_position (`ms.Tensor`):
535+ A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
536+ past_key_values (`Cache`, optional):
537+ The past key values, if we use a cache.
538+ or_mask_function (`Callable`, optional):
539+ An optional mask function to combine with the sliding causal mask function (by doing the union of both). This is
540+ useful to easily overlay another mask on top of the sliding causal one, for example for image tokens handling.
541+ and_mask_function (`Callable`, optional):
542+ An optional mask function to combine with the sliding causal mask function (by doing the intersection of both). This is
543+ useful to easily overlay another mask on top of the sliding causal one, for example for image tokens handling.
544+ """
+    # If we have a HybridCache structure, here we want to create the mask for the sliding layers
+    if hasattr(past_key_values, "is_sliding") and True in past_key_values.is_sliding:
+        layer_idx = past_key_values.is_sliding.index(True)
+    else:
+        layer_idx = 0
+
+    early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments(
+        config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx
+    )
+    if early_exit:
+        return attention_mask
+
+    sliding_window = getattr(config, "sliding_window", None)
+    if sliding_window is None:
+        raise ValueError("Could not find a `sliding_window` argument in the config, or it is not set")
+
+    batch_size, dtype = input_embeds.shape[0], input_embeds.dtype
+    mask_factory_function = sliding_window_causal_mask_function(sliding_window)
+    mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation]
+
+    # Do not allow skip if we are compiling (this is to match BC)
+    # TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it
+    allow_is_causal_skip = not past_key_values.is_compileable if past_key_values is not None else True
+
+    # Allow slight deviations from sliding causal mask
+    if or_mask_function is not None or and_mask_function is not None:
+        raise NotImplementedError("`or_mask_function` or `and_mask_function` arguments are not supported yet.")
+
+    # We now create the mask
+    causal_mask = mask_interface(
+        batch_size=batch_size,
+        cache_position=cache_position,
+        kv_length=kv_length,
+        kv_offset=kv_offset,
+        mask_function=mask_factory_function,
+        attention_mask=attention_mask,
+        allow_is_causal_skip=allow_is_causal_skip,  # additional kwarg for sdpa
+        local_size=sliding_window,  # Additional kwarg for sdpa
+        dtype=dtype,  # Additional kwarg for eager
+        config=config,  # Pass the config as well, in case someone wants to easily have their own mask_interface
+    )
+    return causal_mask
+
+
LAYER_PATTERN_TO_MASK_FUNCTION_MAPPING = {
    "full_attention": create_causal_mask,
+    "sliding_attention": create_sliding_window_causal_mask,
}