@@ -77,13 +77,11 @@ def __init__(self, dim):
         self.c_fc = Conv1D(dim, config['n_embd'])
         self.c_proj = Conv1D(config['n_embd'], dim)
         self.act = nn.functional.gelu
-        self.dropout = nn.Dropout(config['resid_pdrop'])
 
     def forward(self, hidden_states):
         hidden_states = self.c_fc(hidden_states)
         hidden_states = self.act(hidden_states)
         hidden_states = self.c_proj(hidden_states)
-        hidden_states = self.dropout(hidden_states)
         return hidden_states
 
 
@@ -98,9 +96,6 @@ def __init__(self):
         self.c_att = Conv1D(config['n_embd'] * 3, config['n_embd'])
         self.c_proj = Conv1D(config['n_embd'], config['n_embd'])
 
-        self.resid_dropout = nn.Dropout(config['resid_pdrop'])
-        self.attn_dropout = nn.Dropout(config['attn_pdrop'])
-
     def _split_heads(self, tensor, num_heads, attn_head_size):
         """
         Splits hidden_size dim into attn_head_size and num_heads
@@ -123,7 +118,7 @@ def forward(self, hidden_states):
             key,
             value,
             attn_mask=None,
-            dropout_p=self.attn_dropout.p if self.training else 0.0,
+            dropout_p=0.0,
             is_causal=True,  # for the triangular mask
         )
 
@@ -132,7 +127,6 @@ def forward(self, hidden_states):
         attn_output = attn_output.view(batch_size, seq_length, self.embed_dim)
 
         attn_output = self.c_proj(attn_output)
-        attn_output = self.resid_dropout(attn_output)
 
         return attn_output
 
@@ -168,8 +162,6 @@ def __init__(self):
         self.token_embedding = nn.Embedding(config['vocab_size'], config['n_embd'])
         self.position_embedding = nn.Embedding(config['n_positions'], config['n_embd'])
 
-        self.dropout = nn.Dropout(p=config['embd_pdrop'], inplace=False)
-
         self.blocks = nn.ModuleList([Block() for _ in range(config['n_layer'])])
 
         self.final_norm = nn.LayerNorm(config['n_embd'], eps=config['layer_norm_epsilon'])
@@ -183,9 +175,7 @@ def forward(self, input_ids):
         position_ids = torch.arange(input_shape)  # T C
         position_embeddings = self.position_embedding(position_ids)  # B T C
 
-        embeddings = token_embeddings + position_embeddings
-
-        hidden_states = self.dropout(embeddings)
+        hidden_states = token_embeddings + position_embeddings
 
         for block in self.blocks:
             hidden_states = block(hidden_states)
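
For reference, removing these layers is behavior-preserving for inference-only use: nn.Dropout is a no-op once a module is in eval() mode, and scaled_dot_product_attention with dropout_p=0.0 applies no dropout. A minimal sketch, not part of the commit, with illustrative tensor shapes:

import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(2, 8, 16)  # (batch, seq, n_embd), illustrative sizes

# In eval mode, Dropout passes its input through unchanged.
drop = nn.Dropout(p=0.1)
drop.eval()
assert torch.equal(drop(x), x)

# With dropout_p=0.0, scaled_dot_product_attention applies no dropout and is deterministic.
q = k = v = torch.randn(2, 4, 8, 16)  # (batch, heads, seq, head_dim)
out1 = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)
out2 = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)
assert torch.equal(out1, out2)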