2 files changed: 4 additions, 4 deletions

First changed file:

@@ -69,7 +69,7 @@ def forward(
         return (attn_output,)
 
 
-modeling_bert.BERT_SELF_ATTENTION_CLASSES.update({"flash_attention_2": BertSelfFlashAttention})
+modeling_bert.BERT_SELF_ATTENTION_CLASSES.update({"xformers": BertSelfFlashAttention})
 
 
 class PositionalEncodingLayer(nn.Module):
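
The only change in this hunk is the registry key: the custom BertSelfFlashAttention class is now exposed as "xformers" rather than shadowing the stock "flash_attention_2" entry. Below is a minimal sketch of how that key is typically resolved, assuming a transformers build in which BertAttention picks its self-attention class out of BERT_SELF_ATTENTION_CLASSES using config._attn_implementation as the key:

    from transformers import BertConfig
    from transformers.models.bert import modeling_bert

    # Sketch, not the repo's code: after the patch above, the custom class is
    # reachable under the "xformers" key next to the stock entries ("eager", "sdpa", ...).
    config = BertConfig()
    config._attn_implementation = "xformers"  # custom key, set directly on the config

    attn_cls = modeling_bert.BERT_SELF_ATTENTION_CLASSES[config._attn_implementation]
    self_attn = attn_cls(config)  # assumes BertSelfFlashAttention accepts the stock (config, ...) signature
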
@@ -231,7 +231,7 @@ class CehrBertPreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["BertLayer"]
     _supports_sdpa = True
-    _supports_flash_attn_2 = True
+    # _supports_flash_attn_2 = True
 
     def _init_weights(self, module):
         """Initialize the weights."""
@@ -286,7 +286,7 @@ def forward(
         # [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
         # The flash attention requires the original attention_mask
-        if not getattr(self.config, "_attn_implementation", "eager") == "flash_attention_2":
+        if not getattr(self.config, "_attn_implementation", "eager") == "xformers":
             if seq_lens is not None:
                 attention_mask = create_block_diagonal_mask(seq_lens)
             else:
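
When the "xformers" implementation is not selected, the forward pass builds a block-diagonal mask from seq_lens so that tokens in packed sequences cannot attend across sequence boundaries. create_block_diagonal_mask is defined elsewhere in the repo; the sketch below is only a hypothetical illustration of what such a helper typically computes, not the project's implementation:

    import torch

    def create_block_diagonal_mask_sketch(seq_lens):
        """Hypothetical illustration (not the repo's helper): boolean block-diagonal
        mask for sequences packed back to back; position i may attend to position j
        only if both fall inside the same segment."""
        seq_lens = torch.as_tensor(seq_lens)
        # segment id of every packed position, e.g. seq_lens=[3, 2] -> [0, 0, 0, 1, 1]
        segment_ids = torch.repeat_interleave(torch.arange(len(seq_lens)), seq_lens)
        # True where query and key positions share a segment
        return segment_ids[:, None] == segment_ids[None, :]

    mask = create_block_diagonal_mask_sketch([3, 2])
    # mask.shape == (5, 5); the top-left 3x3 and bottom-right 2x2 blocks are True
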
Second changed file:

@@ -132,7 +132,7 @@ def load_and_create_model(
         **model_args.as_dict(),
     )
     model = CehrBertForPreTraining(model_config)
-    if model_args.attn_implementation == "flash_attention_2":
+    if model_args.attn_implementation == "xformers":
         model.gradient_checkpointing_enable()
         model.enable_input_require_grads()
         model.use_memory_efficient_attention = True  # or model.use_flash_attention_2 = True
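
When "xformers" is selected, the training path enables gradient checkpointing, input gradients, and a memory-efficient-attention flag on the model. For context, here is a hedged sketch of the xformers kernel such an implementation typically wraps; the tensor shapes, the packed-sequence variant, and the CUDA/fp16 setup are assumptions, not this repo's code:

    import torch
    import xformers.ops as xops

    # Requires a CUDA device and the xformers package. Shapes follow xformers'
    # convention of [batch, seq_len, num_heads, head_dim].
    batch, seq_len, num_heads, head_dim = 2, 128, 12, 64
    q = torch.randn(batch, seq_len, num_heads, head_dim, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    # Dense attention over full sequences.
    out = xops.memory_efficient_attention(q, k, v)

    # Packed variant: xformers' block-diagonal bias plays the role of the
    # create_block_diagonal_mask helper above (sequences concatenated, batch dim 1).
    q_packed = torch.randn(1, 3 + 2, num_heads, head_dim, device="cuda", dtype=torch.float16)
    k_packed = torch.randn_like(q_packed)
    v_packed = torch.randn_like(q_packed)
    bias = xops.fmha.BlockDiagonalMask.from_seqlens([3, 2])
    out_packed = xops.memory_efficient_attention(q_packed, k_packed, v_packed, attn_bias=bias)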