@@ -7,15 +7,11 @@
 from transformers import PreTrainedModel
 from transformers.activations import gelu_new
 from transformers.models.bert import modeling_bert
-from transformers.models.bert.modeling_bert import BertEncoder, BertOnlyMLMHead, BertPooler
+from transformers.models.bert.modeling_bert import BertEncoder, BertOnlyMLMHead, BertPooler, BertSelfAttention
 from transformers.pytorch_utils import Conv1D
-from transformers.utils import is_flash_attn_2_available, logging
+from transformers.utils import logging
 from transformers.utils.import_utils import _is_package_available

-if is_flash_attn_2_available():
-    from flash_attn import flash_attn_func, flash_attn_varlen_func
-    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
-
 if _is_package_available("xformers"):
     from xformers.ops import fmha
     import xformers.ops as xops
@@ -38,29 +34,12 @@ def create_block_diagonal_mask(seqlens: torch.LongTensor) -> torch.Tensor:
     return mask  # shape: (total_len, total_len)


-class BertSelfFlashAttention(nn.Module):
-    def __init__(self, config, position_embedding_type=None):
-        super().__init__()
-        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
-            raise ValueError(
-                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
-                f"heads ({config.num_attention_heads})"
-            )
-        if not _is_package_available("xformers"):
-            raise RuntimeError("xformers is not installed for BertSelfFlashAttention")
-        LOG.info("BertSelfFlashAttention is successfully initialized")
-        self.num_attention_heads = config.num_attention_heads
-        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
-        self.all_head_size = self.num_attention_heads * self.attention_head_size
-        self.split_size = config.hidden_size
-        self.embed_dim = config.hidden_size
-        self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
-        self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
-        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+class BertSelfFlashAttention(BertSelfAttention):

     def split_heads(self, x: torch.Tensor) -> torch.Tensor:
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        return x.view(new_x_shape)
+        x = x.view(new_x_shape)
+        return x

     def forward(
         self,
@@ -72,20 +51,20 @@ def forward(
         past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         output_attentions: Optional[bool] = False,
     ) -> Tuple[torch.Tensor]:
-        query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
-        dtype = query.dtype
-        query_layer = self.split_heads(query).to(torch.bfloat16)
-        key_layer = self.split_heads(key).to(torch.bfloat16)
-        value_layer = self.split_heads(value).to(torch.bfloat16)
+
+        query_layer = self.split_heads(self.query(hidden_states)).to(torch.bfloat16)
+        key_layer = self.split_heads(self.key(hidden_states)).to(torch.bfloat16)
+        value_layer = self.split_heads(self.value(hidden_states)).to(torch.bfloat16)
         attn_dropout = self.dropout.p if self.training else 0.0

+        dtype = hidden_states.dtype
         attn_output = xops.memory_efficient_attention(
             query_layer, key_layer, value_layer, attn_bias=attention_mask, p=attn_dropout
         )

-        attn_output = attn_output.to(dtype)
-        attn_output = self.c_proj(attn_output)
-        attn_output = self.dropout(attn_output)
+        new_context_layer_shape = attn_output.size()[:-2] + (self.all_head_size,)
+        attn_output = attn_output.view(new_context_layer_shape).to(dtype)
+
         # The BertLayer expects a tuple
         return (attn_output,)

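For reference, a minimal sketch of how the patched attention could be wired into a stock Hugging Face BERT model. The import path modeling_bert_flash and the checkpoint name are assumptions for illustration, not part of this commit; the only facts taken from the diff are that BertSelfFlashAttention now inherits its __init__ (and therefore its query/key/value projections) from BertSelfAttention, so pretrained weights load without remapping.

from transformers import BertModel

from modeling_bert_flash import BertSelfFlashAttention  # hypothetical module path for the file in this diff

model = BertModel.from_pretrained("bert-base-uncased")

for layer in model.encoder.layer:
    patched = BertSelfFlashAttention(model.config)
    # Same parameter names as BertSelfAttention (query/key/value), so the
    # stock module's state dict copies over directly.
    patched.load_state_dict(layer.attention.self.state_dict())
    layer.attention.self = patched

Note that the new forward hands attention_mask straight to xops.memory_efficient_attention as attn_bias, so callers would need to pass None or an xformers-compatible bias in the query dtype (for packed sequences, something like the block-diagonal mask built by create_block_diagonal_mask or fmha.BlockDiagonalMask.from_seqlens) rather than the usual additive extended mask.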