"""Flex attention monkey patch"""

import sys
from typing import Optional, Tuple, Union

import torch
import transformers


def patch_flex_wrapper():
    # TODO remove this patch when transformers#37285 is merged and in a release
    is_torch_2_6 = torch.__version__.startswith("2.6")
    is_transformers_below_4_51 = transformers.__version__ < "4.51.0"

    if not (is_torch_2_6 and is_transformers_below_4_51):
        return

    from torch.nn.attention.flex_attention import flex_attention

    class WrappedFlexAttention:
        """
        Singleton class so that flex attention is compiled only once, the first time it is called.
        """

        _instance = None
        _is_flex_compiled = False
        _compiled_flex_attention = None

        def __new__(cls, *args, **kwargs):
            if cls._instance is None:
                # Create a new instance if one doesn't already exist
                cls._instance = super().__new__(cls)
            return cls._instance

        @torch.compiler.disable(recursive=False)
        def __init__(self):
            """
            Initialize or update the singleton instance.
            """
            if not self._is_flex_compiled:
                self._compiled_flex_attention = torch.compile(
                    flex_attention,
                    dynamic=False,
                    mode="max-autotune-no-cudagraphs",
                    fullgraph=True,
                )
                self._is_flex_compiled = True

        def __call__(self):
            return self._compiled_flex_attention

    transformers.integrations.flex_attention.WrappedFlexAttention = WrappedFlexAttention


def patch_flex_make_mask():
    is_torch_2_6 = torch.__version__.startswith("2.6")
    is_transformers_eq_4_51 = transformers.__version__ == "4.51.0"

    if not (is_torch_2_6 and is_transformers_eq_4_51):
        return

    from torch.nn.attention.flex_attention import (
        BlockMask,
    )
    from torch.nn.attention.flex_attention import (
        create_block_mask as create_block_causal_mask_flex,
    )

    Offset = Union[torch.Tensor, int]

    def patched_make_flex_block_causal_mask(
        attention_mask_2d: torch.Tensor,
        attention_chunk_size: Optional[int] = None,
        query_length=None,
        key_length=None,
        offsets: Optional[Tuple[Offset, Offset]] = None,
    ) -> "BlockMask":
        """
        Create a block-causal document mask for a batch of sequences, both packed and unpacked.
        Builds the block-causal logic and passes it to
        :func:`torch.nn.attention.flex_attention.create_block_mask`. The resulting
        BlockMask is a compressed representation of the full block-causal mask and is
        essential for performant computation of flex attention.
        See: https://pytorch.org/blog/flexattention/

        Args:
            attention_mask_2d (torch.Tensor): Attention mask for packed and padded sequences
            of shape (batch_size, total_seq_len). e.g.

            For unpacked sequence:
            [[1, 1, 1, 1, 0, 0, 0],
             [1, 1, 1, 1, 1, 0, 0]]

            For packed sequence:
            [[1, 1, 1, 2, 2, 2, 0],
             [1, 1, 2, 2, 2, 3, 3]]

        Returns:
            BlockMask
        """
        batch_size, total_seq_len = attention_mask_2d.shape
        if not key_length:
            key_length = total_seq_len
        if not query_length:
            query_length = total_seq_len
        attention_mask_2d = torch.nn.functional.pad(
            attention_mask_2d, value=0, pad=(0, key_length)
        )
        device = attention_mask_2d.device
        document_ids = attention_mask_2d.clone()

        if attention_chunk_size is not None:
            # create an arange, then floor-divide by the chunk size to get
            # chunk ids such as [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
            document_ids = (
                document_ids.fill_(1).cumsum(-1) - 1
            ) // attention_chunk_size

        # Instead of passing a tensor mask, flex attention requires a mask_mod function
        # that determines which elements of QK^T should be included in the attention
        # computation prior to the softmax. For sample packing, we need the logic for
        # both the causal mask and the document mask. See PyTorch's official blog post
        # for more details: https://pytorch.org/blog/flexattention/#mask-mods
        def causal_mask_mod(
            batch_idx, head_idx, q_idx, kv_idx
        ):  # pylint: disable=unused-argument
            """
            Defines the logic of a block-causal mask by combining a standard causal mask
            with a block-diagonal document mask.

            See :func:`~torchtune.modules.attention_utils.create_block_causal_mask`
            for an illustration.
            """
            causal_mask = q_idx >= kv_idx  # not valid when decoding
            document_mask = (
                document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx]
            )
            padding_mask = attention_mask_2d[batch_idx, q_idx] > 0
            final_mask = causal_mask & padding_mask & document_mask
            return final_mask

        if offsets is not None:
            q_offset = offsets[0]
            kv_offset = offsets[1]

            def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
                offset_q = q_idx + q_offset
                offset_kv = kv_idx + kv_offset
                return causal_mask_mod(batch_idx, head_idx, offset_q, offset_kv)

        else:
            mask_mod = causal_mask_mod
        return create_block_causal_mask_flex(
            mask_mod=mask_mod,
            B=batch_size,
            H=None,  # attention head
            Q_LEN=query_length,
            KV_LEN=key_length,
            device=device,
            _compile=True,
        )

    # Modeling modules bind make_flex_block_causal_mask by name at import time, so
    # rebind the patched version in every already-imported modeling module as well
    # (llama4 modules are intentionally left untouched).
    for n in tuple(sys.modules):
        if ".modeling_" in n and "llama4" not in n:
            if hasattr(sys.modules[n], "make_flex_block_causal_mask"):
                print(n)
                sys.modules[n].make_flex_block_causal_mask = (
                    patched_make_flex_block_causal_mask
                )

    transformers.integrations.flex_attention.make_flex_block_causal_mask = (
        patched_make_flex_block_causal_mask
    )
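
A minimal usage sketch, assuming the two patch helpers above are imported from this module and applied before the model is constructed; the call site, model name, and ordering below are illustrative assumptions, not part of this commit:

# Illustrative only: model choice and call site are assumptions, not from this commit.
from transformers import AutoModelForCausalLM

# Both patches are version-gated no-ops when the torch/transformers checks fail,
# so they can be called unconditionally before building the model.
patch_flex_wrapper()
patch_flex_make_mask()

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",  # hypothetical model
    attn_implementation="flex_attention",
)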