
Commit 106e726

remove dropout layers

1 parent b3aedf3 commit 106e726

File tree

1 file changed: +2, -12 lines changed


docs/transformers/LoRA/GPT2.py

Lines changed: 2 additions & 12 deletions
@@ -77,13 +77,11 @@ def __init__(self, dim):
         self.c_fc = Conv1D(dim, config['n_embd'])
         self.c_proj = Conv1D(config['n_embd'], dim)
         self.act = nn.functional.gelu
-        self.dropout = nn.Dropout(config['resid_pdrop'])

     def forward(self, hidden_states):
         hidden_states = self.c_fc(hidden_states)
         hidden_states = self.act(hidden_states)
         hidden_states = self.c_proj(hidden_states)
-        hidden_states = self.dropout(hidden_states)
         return hidden_states

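For context, a minimal sketch of what this feed-forward block looks like after the hunk above. The class name MLP, the sample config values, and the Conv1D import from transformers.pytorch_utils are assumptions for illustration; the diff only shows the two method bodies.

# Sketch of the feed-forward block after this commit (dropout removed).
# Assumes Conv1D is transformers.pytorch_utils.Conv1D and `config` carries 'n_embd'.
import torch
import torch.nn as nn
from transformers.pytorch_utils import Conv1D

config = {'n_embd': 768}  # hypothetical config values for the sketch

class MLP(nn.Module):  # hypothetical class name
    def __init__(self, dim):
        super().__init__()
        self.c_fc = Conv1D(dim, config['n_embd'])    # n_embd -> dim
        self.c_proj = Conv1D(config['n_embd'], dim)  # dim -> n_embd
        self.act = nn.functional.gelu

    def forward(self, hidden_states):
        hidden_states = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.c_proj(hidden_states)
        return hidden_states

# usage: MLP(4 * config['n_embd'])(torch.randn(1, 8, config['n_embd']))
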
@@ -98,9 +96,6 @@ def __init__(self):
         self.c_att = Conv1D(config['n_embd'] * 3, config['n_embd'])
         self.c_proj = Conv1D(config['n_embd'], config['n_embd'])

-        self.resid_dropout = nn.Dropout(config['resid_pdrop'])
-        self.attn_dropout = nn.Dropout(config['attn_pdrop'])
-
     def _split_heads(self, tensor, num_heads, attn_head_size):
         """
         Splits hidden_size dim into attn_head_size and num_heads
@@ -123,7 +118,7 @@ def forward(self, hidden_states):
             key,
             value,
             attn_mask=None,
-            dropout_p=self.attn_dropout.p if self.training else 0.0,
+            dropout_p=0.0,
             is_causal=True,  # for the triangular mask
         )

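The dropout_p argument above belongs to torch.nn.functional.scaled_dot_product_attention; 0.0 is also that function's default, so attention dropout is now disabled regardless of training mode. A minimal standalone sketch of the call, with hypothetical tensor shapes (assumes PyTorch >= 2.0):

# Sketch of the attention call after this commit.
import torch
import torch.nn.functional as F

query = torch.randn(1, 12, 8, 64)  # hypothetical (B, heads, T, head_dim)
key = torch.randn(1, 12, 8, 64)
value = torch.randn(1, 12, 8, 64)

attn_output = F.scaled_dot_product_attention(
    query,
    key,
    value,
    attn_mask=None,
    dropout_p=0.0,   # no attention dropout, in train or eval
    is_causal=True,  # triangular (causal) mask
)
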
@@ -132,7 +127,6 @@ def forward(self, hidden_states):
         attn_output = attn_output.view(batch_size, seq_length, self.embed_dim)

         attn_output = self.c_proj(attn_output)
-        attn_output = self.resid_dropout(attn_output)

         return attn_output

@@ -168,8 +162,6 @@ def __init__(self):
         self.token_embedding = nn.Embedding(config['vocab_size'], config['n_embd'])
         self.position_embedding = nn.Embedding(config['n_positions'], config['n_embd'])

-        self.dropout = nn.Dropout(p=config['embd_pdrop'], inplace=False)
-
         self.blocks = nn.ModuleList([Block() for _ in range(config['n_layer'])])

         self.final_norm = nn.LayerNorm(config['n_embd'], eps=config['layer_norm_epsilon'])
@@ -183,9 +175,7 @@ def forward(self, input_ids):
         position_ids = torch.arange(input_shape) # T C
         position_embeddings = self.position_embedding(position_ids) # B T C

-        embeddings = token_embeddings + position_embeddings
-
-        hidden_states = self.dropout(embeddings)
+        hidden_states = token_embeddings + position_embeddings

         for block in self.blocks:
             hidden_states = block(hidden_states)

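For reference, a minimal sketch of the embedding prologue after the last hunk. The config values, batch/sequence sizes, and the standalone layer construction are assumptions for illustration only; the actual file builds these inside the model's __init__ as shown above.

# Sketch of the model's forward prologue after this commit (embedding dropout removed).
import torch
import torch.nn as nn

config = {'vocab_size': 50257, 'n_positions': 1024, 'n_embd': 768}  # hypothetical

token_embedding = nn.Embedding(config['vocab_size'], config['n_embd'])
position_embedding = nn.Embedding(config['n_positions'], config['n_embd'])

input_ids = torch.randint(0, config['vocab_size'], (2, 16))  # B=2, T=16
token_embeddings = token_embedding(input_ids)                # B T C
position_ids = torch.arange(input_ids.shape[-1])             # T
position_embeddings = position_embedding(position_ids)       # T C, broadcasts to B T C

# Previously: hidden_states = dropout(token_embeddings + position_embeddings)
hidden_states = token_embeddings + position_embeddings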