
Commit e5f3ae0

updated the attn_implementation to xformers
1 parent e1698ad commit e5f3ae0

2 files changed: +4 −4 lines

src/cehrbert/models/hf_models/hf_cehrbert.py
3 additions, 3 deletions

@@ -69,7 +69,7 @@ def forward(
        return (attn_output,)


-modeling_bert.BERT_SELF_ATTENTION_CLASSES.update({"flash_attention_2": BertSelfFlashAttention})
+modeling_bert.BERT_SELF_ATTENTION_CLASSES.update({"xformers": BertSelfFlashAttention})


class PositionalEncodingLayer(nn.Module):
@@ -231,7 +231,7 @@ class CehrBertPreTrainedModel(PreTrainedModel):
    supports_gradient_checkpointing = True
    _no_split_modules = ["BertLayer"]
    _supports_sdpa = True
-    _supports_flash_attn_2 = True
+    # _supports_flash_attn_2 = True

    def _init_weights(self, module):
        """Initialize the weights."""
@@ -286,7 +286,7 @@ def forward(
        # [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        # The flash attention requires the original attention_mask
-        if not getattr(self.config, "_attn_implementation", "eager") == "flash_attention_2":
+        if not getattr(self.config, "_attn_implementation", "eager") == "xformers":
            if seq_lens is not None:
                attention_mask = create_block_diagonal_mask(seq_lens)
            else:

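Note: BertSelfFlashAttention is now registered under the "xformers" key of BERT_SELF_ATTENTION_CLASSES, the registry that transformers' modeling_bert consults via config._attn_implementation. Below is a minimal sketch of the xformers memory-efficient attention call that such an attention class typically wraps; the tensor shapes, dtype/device choices, and the block-diagonal bias are illustrative assumptions, not code taken from this repository.

# Illustrative sketch only; not part of this commit.
# Assumes xformers is installed and a CUDA device is available.
import torch
import xformers.ops as xops

# Two sequences of length 64 packed into one batch row of 128 tokens.
num_heads, head_dim = 12, 64
q = torch.randn(1, 128, num_heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# A block-diagonal bias keeps the packed sequences from attending to each other
# (analogous in spirit to create_block_diagonal_mask(seq_lens) in the diff above).
attn_bias = xops.fmha.BlockDiagonalMask.from_seqlens([64, 64])

# Memory-efficient attention over (batch, seq_len, num_heads, head_dim) tensors.
out = xops.memory_efficient_attention(q, k, v, attn_bias=attn_bias)
print(out.shape)  # torch.Size([1, 128, 12, 64])
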
src/cehrbert/runners/hf_cehrbert_pretrain_runner.py
1 addition, 1 deletion

@@ -132,7 +132,7 @@ def load_and_create_model(
        **model_args.as_dict(),
    )
    model = CehrBertForPreTraining(model_config)
-    if model_args.attn_implementation == "flash_attention_2":
+    if model_args.attn_implementation == "xformers":
        model.gradient_checkpointing_enable()
        model.enable_input_require_grads()
        model.use_memory_efficient_attention = True  # or model.use_flash_attention_2 = True
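For completeness, a small sketch of how this runner branch lines up with the registry key added in hf_cehrbert.py; the import below assumes the cehrbert package is installed, and the printout is only meant to show that the runner's "xformers" string must match the registered key.

# Illustrative sketch only; assumes the cehrbert package is installed.
from transformers.models.bert import modeling_bert

# Importing the module executes the registry update from the first file above.
import cehrbert.models.hf_models.hf_cehrbert  # noqa: F401

# The runner's check `model_args.attn_implementation == "xformers"` only takes
# effect because the same key now resolves to the custom attention class:
print(modeling_bert.BERT_SELF_ATTENTION_CLASSES["xformers"])
# expected: <class 'cehrbert.models.hf_models.hf_cehrbert.BertSelfFlashAttention'>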
