Commit 8b22fde

test using CONV1D
1 parent 877d924 commit 8b22fde

File tree

1 file changed: +14 -19 lines changed


src/cehrbert/models/hf_models/hf_cehrbert.py

Lines changed: 14 additions & 19 deletions
@@ -8,6 +8,7 @@
 from transformers.activations import gelu_new
 from transformers.models.bert import modeling_bert
 from transformers.models.bert.modeling_bert import BertEncoder, BertOnlyMLMHead, BertPooler
+from transformers.pytorch_utils import Conv1D
 from transformers.utils import is_flash_attn_2_available, logging

 if is_flash_attn_2_available():
@@ -32,12 +33,9 @@ def __init__(self, config, position_embedding_type=None):
         self.num_attention_heads = config.num_attention_heads
         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
         self.all_head_size = self.num_attention_heads * self.attention_head_size
-
-        self.query = nn.Linear(config.hidden_size, self.all_head_size)
-        self.key = nn.Linear(config.hidden_size, self.all_head_size)
-        self.value = nn.Linear(config.hidden_size, self.all_head_size)
-
-        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+        self.split_size = config.hidden_size
+        self.c_attn = Conv1D(3 * config.hidden_size, config.hidden_size)
+        self.dropout_rate = config.attention_probs_dropout_prob

     def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
@@ -58,8 +56,9 @@ def forward(
         dtype = hidden_states.dtype

         batch_size = hidden_states.size(0)
+        query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
         # (batch, n_heads, seq_length, head_dim)
-        query_layer = self.transpose_for_scores(self.query(hidden_states))
+        query_layer = self.transpose_for_scores(query)
         # If this is instantiated as a cross-attention module, the keys
         # and values come from an encoder; the attention mask needs to be
         # such that the encoder's padding tokens are not attended to.
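
For context, transformers' Conv1D(nf, nx) is the GPT-2-style linear map from nx to nf features, so the single c_attn projection above stands in for the three separate nn.Linear layers that were removed. A minimal standalone sketch of the fused projection and split follows; the sizes and tensors are assumed example values, not taken from the commit:

import torch
from transformers.pytorch_utils import Conv1D

hidden_size, num_heads = 768, 12                 # assumed example sizes
head_size = hidden_size // num_heads
c_attn = Conv1D(3 * hidden_size, hidden_size)    # one matmul produces q, k and v

hidden_states = torch.randn(2, 16, hidden_size)  # (batch, seq, hidden)
query, key, value = c_attn(hidden_states).split(hidden_size, dim=2)

# Reshape each (batch, seq, hidden) chunk to (batch, n_heads, seq, head_dim),
# mirroring what transpose_for_scores does in the diff above.
def transpose_for_scores(x: torch.Tensor) -> torch.Tensor:
    new_shape = x.size()[:-1] + (num_heads, head_size)
    return x.view(new_shape).permute(0, 2, 1, 3)

print(transpose_for_scores(query).shape)  # torch.Size([2, 12, 16, 64])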
@@ -71,21 +70,17 @@ def forward(
             value_layer = past_key_value[1]
             attention_mask = encoder_attention_mask
         elif is_cross_attention:
-            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
-            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
+            key_layer = self.transpose_for_scores(key)
+            value_layer = self.transpose_for_scores(value)
             attention_mask = encoder_attention_mask
         elif past_key_value is not None:
-            key_layer = self.transpose_for_scores(self.key(hidden_states))
-            value_layer = self.transpose_for_scores(self.value(hidden_states))
+            key_layer = self.transpose_for_scores(key)
+            value_layer = self.transpose_for_scores(value)
             key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
             value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
         else:
-            key_layer = self.transpose_for_scores(self.key(hidden_states))
-            value_layer = self.transpose_for_scores(self.value(hidden_states))
-
-            query_layer = query_layer.to(torch.bfloat16)
-            key_layer = key_layer.to(torch.bfloat16)
-            value_layer = value_layer.to(torch.bfloat16)
+            key_layer = self.transpose_for_scores(key)
+            value_layer = self.transpose_for_scores(value)

         # Flash Attention forward pass
         # Use the built-in scaled_dot_product_attention with Flash Attention
@@ -95,7 +90,7 @@ def forward(
             key_layer,
             value_layer,
             attn_mask=attention_mask,
-            dropout_p=self.dropout.p,
+            dropout_p=self.dropout_rate,
             is_causal=False,
             scale=None,  # Default is 1/sqrt(head_dim)
         )
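
The dropout change works because torch.nn.functional.scaled_dot_product_attention takes dropout_p as a plain float, so the module only needs to keep the rate rather than an nn.Dropout whose .p it reads back. A small sketch of the call with assumed tensor shapes and an assumed dropout rate:

import torch
import torch.nn.functional as F

# Assumed shapes: (batch, n_heads, seq_len, head_dim)
q = torch.randn(2, 12, 16, 64)
k = torch.randn(2, 12, 16, 64)
v = torch.randn(2, 12, 16, 64)
dropout_rate = 0.1  # stands in for config.attention_probs_dropout_prob

out = F.scaled_dot_product_attention(
    q, k, v,
    attn_mask=None,
    dropout_p=dropout_rate,  # a float, matching self.dropout_rate above
    is_causal=False,
    scale=None,              # default is 1/sqrt(head_dim)
)
print(out.shape)  # torch.Size([2, 12, 16, 64])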
@@ -270,7 +265,7 @@ class CehrBertPreTrainedModel(PreTrainedModel):

     def _init_weights(self, module):
         """Initialize the weights."""
-        if isinstance(module, nn.Linear):
+        if isinstance(module, (nn.Linear, Conv1D)):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
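
Conv1D stores its parameters as a weight matrix and bias, much like nn.Linear, so it has to be added to the isinstance check for the normal-initialization branch to reach the new c_attn layer. A minimal sketch of that branch written as a free function; the initializer_range value and the bias handling are assumptions (the hunk is truncated), following the usual Hugging Face pattern:

import torch.nn as nn
from transformers.pytorch_utils import Conv1D

def init_weights(module, initializer_range: float = 0.02):  # 0.02 is an assumed example value
    # Both nn.Linear and Conv1D expose .weight and .bias, so one branch covers them.
    if isinstance(module, (nn.Linear, Conv1D)):
        module.weight.data.normal_(mean=0.0, std=initializer_range)
        if module.bias is not None:
            module.bias.data.zero_()

init_weights(Conv1D(3 * 768, 768))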
