``state_dict`` contains elements corresponding to only the current
partition, or to the entire model.

smdistributed.modelparallel.torch.nn.FlashAttentionLayer
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. function:: smdistributed.modelparallel.torch.nn.FlashAttentionLayer(attention_dropout_prob=0.1, attention_head_size=None, scale_attention_scores=True, scale_attn_by_layer_idx=False, layer_idx=None, scale=None, triton_flash_attention=False, use_alibi=False)

This class supports
`FlashAttention <https://github.com/HazyResearch/flash-attention>`_
for PyTorch 2.0.
It takes the ``qkv`` matrix as an argument through its ``forward`` class method,
computes attention scores and probabilities,
and then performs the matrix multiplication with the value layers.

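
To clarify what this layer computes, the following is a minimal sketch of the
equivalent unfused attention in plain PyTorch. The shapes follow the ``forward``
signature described below; this reference code is an illustration, not part of
the library.

.. code:: python

    import math

    import torch

    def reference_attention(qkv, scale=None, dropout_p=0.0):
        # qkv: (batch_size, seq_len, 3, num_heads, head_size)
        q, k, v = (t.transpose(1, 2) for t in qkv.unbind(dim=2))  # each: (b, h, s, d)
        if scale is None:
            scale = 1.0 / math.sqrt(q.size(-1))                   # default 1/sqrt(head_size)
        scores = torch.matmul(q, k.transpose(-2, -1)) * scale     # (b, h, s, s)
        probs = torch.softmax(scores, dim=-1)
        probs = torch.nn.functional.dropout(probs, p=dropout_p)
        # Output shape matches FlashAttentionLayer.forward:
        # (batch_size, num_heads, seq_len, head_size)
        return torch.matmul(probs, v)
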
Through this class, the smp library supports
custom attention masks such as Attention with
Linear Biases (ALiBi), and you can activate them by setting
``triton_flash_attention`` and ``use_alibi`` to ``True``.

Note that the Triton flash attention kernel does not support dropout
on the attention probabilities. It uses a standard lower-triangular
causal mask when causal mode is enabled. It also runs only
on P4d and P4de instances, with fp16 or bf16.

This class computes the scale factor to apply when computing attention.
By default, ``scale`` is set to ``None``, and it's automatically calculated.
When ``scale_attention_scores`` is ``True`` (which is the default), you must pass a value
to ``attention_head_size``. When ``scale_attn_by_layer_idx`` is ``True``,
you must pass a value to ``layer_idx``. If both factors are used, they are
multiplied as follows: ``(1/(sqrt(attention_head_size) * (layer_idx+1)))``.
This scale calculation can be bypassed if you specify a custom scaling
factor to ``scale``. In other words, if you specify a value to ``scale``, the set of parameters
(``scale_attention_scores``, ``attention_head_size``, ``scale_attn_by_layer_idx``, ``layer_idx``)
is ignored.

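
As a worked example of this rule, the following hypothetical helper (not part of
the library) reproduces the default scale selection described above.

.. code:: python

    import math

    def default_scale(attention_head_size=None, layer_idx=None,
                      scale_attention_scores=True, scale_attn_by_layer_idx=False):
        # Each enabled factor divides the scale, as described above.
        scale = 1.0
        if scale_attention_scores:
            scale /= math.sqrt(attention_head_size)
        if scale_attn_by_layer_idx:
            scale /= layer_idx + 1
        return scale

    print(default_scale(attention_head_size=64))   # 0.125
    print(default_scale(attention_head_size=64, layer_idx=3,
                        scale_attn_by_layer_idx=True))   # 0.03125
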
**Parameters**

* ``attention_dropout_prob`` (float): (default: 0.1) specifies the dropout probability
  to apply to attention.
* ``attention_head_size`` (int): Required when ``scale_attention_scores`` is ``True``.
  When ``scale_attention_scores`` is passed, this contributes
  ``1/sqrt(attention_head_size)`` to the scale factor.
* ``scale_attention_scores`` (boolean): (default: True) determines whether
  to multiply the scale factor by 1/sqrt(attention_head_size).
* ``layer_idx`` (int): Required when ``scale_attn_by_layer_idx`` is ``True``.
  The layer id to use for scaling attention by layer id.
  It contributes 1/(layer_idx + 1) to the scale factor.
* ``scale_attn_by_layer_idx`` (boolean): (default: False) determines whether
  to multiply the scale factor by 1/(layer_idx + 1).
* ``scale`` (float): (default: None) If passed, this scale factor is applied
  directly, bypassing all of the previous arguments.
* ``triton_flash_attention`` (bool): (default: False) If passed, the Triton
  implementation of flash attention is used. This is required to support
  Attention with Linear Biases (ALiBi) (see the next argument). Note that this
  version of the kernel doesn't support dropout.
* ``use_alibi`` (bool): (default: False) If passed, it enables Attention with
  Linear Biases (ALiBi) using the mask provided.

.. method:: forward(self, qkv, attn_mask=None, causal=False)

Returns a single ``torch.Tensor`` of shape ``(batch_size x num_heads x seq_len x head_size)``,
which represents the output of the attention computation.

**Parameters**

* ``qkv``: ``torch.Tensor`` in the form of ``(batch_size x seqlen x 3 x num_heads x head_size)``.
* ``attn_mask``: ``torch.Tensor`` in the form of ``(batch_size x 1 x 1 x seqlen)``.
  By default it is ``None``, and using this mask requires ``triton_flash_attention``
  and ``use_alibi`` to be set. See how to generate the mask in the ALiBi code
  snippet further below.
* ``causal``: When passed, it uses the standard lower-triangular mask. The default
  is ``False``.

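
For example, a minimal call without an ALiBi mask might look like the following
sketch. The batch size, sequence length, and head dimensions are illustrative
assumptions, not requirements of the library.

.. code:: python

    import torch
    import smdistributed.modelparallel.torch as smp

    batch_size, seq_len, num_heads, head_size = 2, 1024, 16, 64

    flash_attention = smp.nn.FlashAttentionLayer(
        attention_dropout_prob=0.1,
        attention_head_size=head_size,  # required because scale_attention_scores=True
    )

    qkv = torch.randn(
        batch_size, seq_len, 3, num_heads, head_size,
        dtype=torch.bfloat16, device="cuda",
    )

    # Standard lower-triangular (causal) masking.
    out = flash_attention(qkv, causal=True)  # (batch_size, num_heads, seq_len, head_size)
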
When using ALiBi, the layer needs an attention mask prepared as in the following example.

.. code:: python

    import torch

    def generate_alibi_attn_mask(attention_mask, batch_size, seq_length,
                                 num_attention_heads, alibi_bias_max=8):

        device, dtype = attention_mask.device, attention_mask.dtype
        alibi_attention_mask = torch.zeros(
            1, num_attention_heads, 1, seq_length, dtype=dtype, device=device
        )

        # Linear bias over key positions: 0 for the latest position, more negative further back.
        alibi_bias = torch.arange(1 - seq_length, 1, dtype=dtype, device=device).view(
            1, 1, 1, seq_length
        )
        # Per-head slopes: head h is scaled by 1 / 2**(h * alibi_bias_max / num_attention_heads).
        m = torch.arange(1, num_attention_heads + 1, dtype=dtype, device=device)
        m.mul_(alibi_bias_max / num_attention_heads)
        alibi_bias = alibi_bias * (1.0 / (2 ** m.view(1, num_attention_heads, 1, 1)))

        alibi_attention_mask.add_(alibi_bias)
        alibi_attention_mask = alibi_attention_mask[..., :seq_length, :seq_length]
        if attention_mask is not None and attention_mask.bool().any():
            # Use the in-place masked_fill_ so that padded positions are actually masked out.
            alibi_attention_mask.masked_fill_(
                attention_mask.bool().view(batch_size, 1, 1, seq_length), float("-inf")
            )

        return alibi_attention_mask
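
Putting the pieces together, an ALiBi-enabled call might look like the following
sketch, reusing the ``generate_alibi_attn_mask`` helper defined above. The shapes
and the all-zero padding mask are illustrative assumptions.

.. code:: python

    import torch
    import smdistributed.modelparallel.torch as smp

    batch_size, seq_len, num_heads, head_size = 2, 1024, 16, 64

    # The Triton kernel is required for custom masks such as ALiBi, and it does
    # not apply dropout to the attention probabilities.
    flash_attention = smp.nn.FlashAttentionLayer(
        attention_head_size=head_size,
        triton_flash_attention=True,
        use_alibi=True,
    )

    # Padding mask with no padded positions (illustrative assumption).
    padding_mask = torch.zeros(
        batch_size, 1, 1, seq_len, dtype=torch.bfloat16, device="cuda"
    )
    alibi_mask = generate_alibi_attn_mask(
        padding_mask, batch_size, seq_len, num_heads
    )

    qkv = torch.randn(
        batch_size, seq_len, 3, num_heads, head_size,
        dtype=torch.bfloat16, device="cuda",
    )
    out = flash_attention(qkv, attn_mask=alibi_mask, causal=True)
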

smdistributed.modelparallel.torch Context Managers and Util Functions
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^