
Commit bb65e9d

implemented xformers in the bert attention layer
1 parent bbfa072 commit bb65e9d

File tree

1 file changed: +13 −34 lines changed

src/cehrbert/models/hf_models/hf_cehrbert.py

Lines changed: 13 additions & 34 deletions
@@ -7,15 +7,11 @@
 from transformers import PreTrainedModel
 from transformers.activations import gelu_new
 from transformers.models.bert import modeling_bert
-from transformers.models.bert.modeling_bert import BertEncoder, BertOnlyMLMHead, BertPooler
+from transformers.models.bert.modeling_bert import BertEncoder, BertOnlyMLMHead, BertPooler, BertSelfAttention
 from transformers.pytorch_utils import Conv1D
-from transformers.utils import is_flash_attn_2_available, logging
+from transformers.utils import logging
 from transformers.utils.import_utils import _is_package_available
 
-if is_flash_attn_2_available():
-    from flash_attn import flash_attn_func, flash_attn_varlen_func
-    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
-
 if _is_package_available("xformers"):
     from xformers.ops import fmha
     import xformers.ops as xops
@@ -38,29 +34,12 @@ def create_block_diagonal_mask(seqlens: torch.LongTensor) -> torch.Tensor:
     return mask  # shape: (total_len, total_len)
 
 
-class BertSelfFlashAttention(nn.Module):
-    def __init__(self, config, position_embedding_type=None):
-        super().__init__()
-        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
-            raise ValueError(
-                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
-                f"heads ({config.num_attention_heads})"
-            )
-        if not _is_package_available("xformers"):
-            raise RuntimeError("xformers is not installed for BertSelfFlashAttention")
-        LOG.info("BertSelfFlashAttention is successfully initialized")
-        self.num_attention_heads = config.num_attention_heads
-        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
-        self.all_head_size = self.num_attention_heads * self.attention_head_size
-        self.split_size = config.hidden_size
-        self.embed_dim = config.hidden_size
-        self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
-        self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
-        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+class BertSelfFlashAttention(BertSelfAttention):
 
     def split_heads(self, x: torch.Tensor) -> torch.Tensor:
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        return x.view(new_x_shape)
+        x = x.view(new_x_shape)
+        return x
 
     def forward(
         self,
@@ -72,20 +51,20 @@ def forward(
         past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         output_attentions: Optional[bool] = False,
     ) -> Tuple[torch.Tensor]:
-        query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
-        dtype = query.dtype
-        query_layer = self.split_heads(query).to(torch.bfloat16)
-        key_layer = self.split_heads(key).to(torch.bfloat16)
-        value_layer = self.split_heads(value).to(torch.bfloat16)
+
+        query_layer = self.split_heads(self.query(hidden_states)).to(torch.bfloat16)
+        key_layer = self.split_heads(self.key(hidden_states)).to(torch.bfloat16)
+        value_layer = self.split_heads(self.value(hidden_states)).to(torch.bfloat16)
         attn_dropout = self.dropout.p if self.training else 0.0
 
+        dtype = hidden_states.dtype
         attn_output = xops.memory_efficient_attention(
            query_layer, key_layer, value_layer, attn_bias=attention_mask, p=attn_dropout
         )
 
-        attn_output = attn_output.to(dtype)
-        attn_output = self.c_proj(attn_output)
-        attn_output = self.dropout(attn_output)
+        new_context_layer_shape = attn_output.size()[:-2] + (self.all_head_size,)
+        attn_output = attn_output.view(new_context_layer_shape).to(dtype)
+
         # The BertLayer expects a tuple
         return (attn_output,)
 
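For context (not part of the commit): a minimal standalone sketch of the xformers call pattern the new forward relies on, assuming xformers is installed and a CUDA device is available. The batch size, sequence lengths, and head sizes below are illustrative, and fmha.BlockDiagonalMask stands in for the block-diagonal bias that create_block_diagonal_mask suggests, so tokens of packed sequences only attend within their own sequence.

    import torch
    import xformers.ops as xops
    from xformers.ops import fmha

    heads, head_dim = 12, 64
    seqlens = [5, 3, 7]          # illustrative lengths of sequences packed into one row
    total_len = sum(seqlens)

    # xformers expects (batch, seq_len, num_heads, head_dim); bf16 keeps the fused kernels on the fast path.
    q = torch.randn(1, total_len, heads, head_dim, device="cuda", dtype=torch.bfloat16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    # Block-diagonal bias: attention is confined to each packed sequence.
    attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)

    out = xops.memory_efficient_attention(q, k, v, attn_bias=attn_bias, p=0.0)
    out = out.reshape(1, total_len, heads * head_dim)   # merge heads, as the new forward does
    print(out.shape)                                    # torch.Size([1, 15, 768])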

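Likewise hypothetical and not shown in this commit: because BertSelfFlashAttention now subclasses BertSelfAttention and reuses its query/key/value projections, one plausible way to adopt it is to swap it into a stock Hugging Face BertModel layer by layer, copying the existing projection weights over. The import path simply follows the file location in this repo.

    from transformers import BertConfig, BertModel
    # Hypothetical import of the class modified above (src/cehrbert/models/hf_models/hf_cehrbert.py).
    from cehrbert.models.hf_models.hf_cehrbert import BertSelfFlashAttention

    config = BertConfig(hidden_size=768, num_attention_heads=12, num_hidden_layers=12)
    model = BertModel(config)

    for layer in model.encoder.layer:
        flash_attn = BertSelfFlashAttention(config)
        # Same parameter names as BertSelfAttention (query/key/value), so the
        # original weights load directly into the subclass.
        flash_attn.load_state_dict(layer.attention.self.state_dict())
        layer.attention.self = flash_attn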