 from transformers.activations import gelu_new
 from transformers.models.bert import modeling_bert
 from transformers.models.bert.modeling_bert import BertEncoder, BertOnlyMLMHead, BertPooler
+from transformers.pytorch_utils import Conv1D
 from transformers.utils import is_flash_attn_2_available, logging

 if is_flash_attn_2_available():
@@ -32,12 +33,9 @@ def __init__(self, config, position_embedding_type=None):
         self.num_attention_heads = config.num_attention_heads
         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
         self.all_head_size = self.num_attention_heads * self.attention_head_size
-
-        self.query = nn.Linear(config.hidden_size, self.all_head_size)
-        self.key = nn.Linear(config.hidden_size, self.all_head_size)
-        self.value = nn.Linear(config.hidden_size, self.all_head_size)
-
-        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+        self.split_size = config.hidden_size
+        self.c_attn = Conv1D(3 * config.hidden_size, config.hidden_size)
+        self.dropout_rate = config.attention_probs_dropout_prob

     def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
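
The fused `c_attn` uses the GPT-2-style `Conv1D` from `transformers.pytorch_utils`, which despite its name is a plain affine map `x @ W + b` with weight stored as `(nx, nf)`, i.e. transposed relative to `nn.Linear`. A minimal sketch of the fused projection, with illustrative sizes:

```python
import torch
from transformers.pytorch_utils import Conv1D

hidden_size = 768
c_attn = Conv1D(3 * hidden_size, hidden_size)  # Conv1D(nf, nx): hidden -> 3 * hidden

x = torch.randn(2, 16, hidden_size)  # (batch, seq, hidden)
qkv = c_attn(x)                      # (batch, seq, 3 * hidden)

# Equivalent to a single fused linear projection; weight has shape (hidden, 3 * hidden)
assert torch.allclose(qkv, x @ c_attn.weight + c_attn.bias)
```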
@@ -58,8 +56,9 @@ def forward(
         dtype = hidden_states.dtype

         batch_size = hidden_states.size(0)
+        query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
         # (batch, n_heads, seq_length, head_dim)
-        query_layer = self.transpose_for_scores(self.query(hidden_states))
+        query_layer = self.transpose_for_scores(query)
         # If this is instantiated as a cross-attention module, the keys
         # and values come from an encoder; the attention mask needs to be
         # such that the encoder's padding tokens are not attended to.
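
Splitting the fused output with `split(self.split_size, dim=2)` yields three `(batch, seq, hidden)` tensors, and `transpose_for_scores` then reshapes each to `(batch, n_heads, seq, head_dim)`. A shape walkthrough under assumed sizes:

```python
import torch

batch, seq, hidden, n_heads = 2, 16, 768, 12
head_dim = hidden // n_heads

qkv = torch.randn(batch, seq, 3 * hidden)
query, key, value = qkv.split(hidden, dim=2)  # three (batch, seq, hidden) tensors

# What transpose_for_scores does: expose the head axis, then move it forward
query_layer = query.view(batch, seq, n_heads, head_dim).permute(0, 2, 1, 3)
assert query_layer.shape == (batch, n_heads, seq, head_dim)
```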
@@ -71,21 +70,17 @@ def forward(
             value_layer = past_key_value[1]
             attention_mask = encoder_attention_mask
         elif is_cross_attention:
-            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
-            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
+            key_layer = self.transpose_for_scores(key)
+            value_layer = self.transpose_for_scores(value)
             attention_mask = encoder_attention_mask
         elif past_key_value is not None:
-            key_layer = self.transpose_for_scores(self.key(hidden_states))
-            value_layer = self.transpose_for_scores(self.value(hidden_states))
+            key_layer = self.transpose_for_scores(key)
+            value_layer = self.transpose_for_scores(value)
             key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
             value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
         else:
-            key_layer = self.transpose_for_scores(self.key(hidden_states))
-            value_layer = self.transpose_for_scores(self.value(hidden_states))
-
-            query_layer = query_layer.to(torch.bfloat16)
-            key_layer = key_layer.to(torch.bfloat16)
-            value_layer = value_layer.to(torch.bfloat16)
+            key_layer = self.transpose_for_scores(key)
+            value_layer = self.transpose_for_scores(value)

         # Flash Attention forward pass
         # Use the built-in scaled_dot_product_attention with Flash Attention
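
Note that every branch now derives `key` and `value` from the fused projection of `hidden_states`; the removed cross-attention path projected `encoder_hidden_states` instead, so the `is_cross_attention` branch changes meaning under this diff. In the cached branch, `torch.cat(..., dim=2)` appends along the sequence axis of the `(batch, n_heads, seq, head_dim)` layout; a minimal sketch of the cache growth:

```python
import torch

batch, n_heads, head_dim = 2, 12, 64

past_key = torch.randn(batch, n_heads, 10, head_dim)  # 10 cached positions
new_key = torch.randn(batch, n_heads, 1, head_dim)    # 1 new position

key_layer = torch.cat([past_key, new_key], dim=2)     # dim=2 is the sequence axis
assert key_layer.shape == (batch, n_heads, 11, head_dim)
```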
@@ -95,7 +90,7 @@ def forward(
             key_layer,
             value_layer,
             attn_mask=attention_mask,
-            dropout_p=self.dropout.p,
+            dropout_p=self.dropout_rate,
             is_causal=False,
             scale=None,  # Default is 1/sqrt(head_dim)
         )
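
`scaled_dot_product_attention` is a functional API, so it applies dropout whenever `dropout_p > 0` irrespective of the module's train/eval state; with a raw float like `self.dropout_rate`, the usual pattern is to gate it on `self.training`. A sketch of that pattern (an assumption about intended behavior, not what the diff passes):

```python
import torch
import torch.nn.functional as F

batch, n_heads, seq, head_dim = 2, 12, 16, 64
q, k, v = (torch.randn(batch, n_heads, seq, head_dim) for _ in range(3))

dropout_rate, training = 0.1, False  # stand-ins for self.dropout_rate / self.training

context = F.scaled_dot_product_attention(
    q, k, v,
    attn_mask=None,
    dropout_p=dropout_rate if training else 0.0,  # disable dropout at eval time
    is_causal=False,  # scale defaults to 1/sqrt(head_dim)
)
assert context.shape == (batch, n_heads, seq, head_dim)
```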
@@ -270,7 +265,7 @@ class CehrBertPreTrainedModel(PreTrainedModel):

     def _init_weights(self, module):
         """Initialize the weights."""
-        if isinstance(module, nn.Linear):
+        if isinstance(module, (nn.Linear, Conv1D)):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
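
`Conv1D` does not subclass `nn.Linear`, so without widening the `isinstance` check, `_init_weights` would leave `c_attn` at `Conv1D`'s own default initialization. It exposes the same `weight` and `bias` attributes, so the normal init applies unchanged; a small sketch, assuming `initializer_range=0.02`:

```python
import torch.nn as nn
from transformers.pytorch_utils import Conv1D

module = Conv1D(3 * 768, 768)

if isinstance(module, (nn.Linear, Conv1D)):
    module.weight.data.normal_(mean=0.0, std=0.02)  # matches config.initializer_range
    if module.bias is not None:
        module.bias.data.zero_()
```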