
Commit b3aedf3

remove gelu custom impl and use pytorch impl
1 parent: cbc38bb

1 file changed (+1, -15 lines)

docs/transformers/LoRA/GPT2.py

Lines changed: 1 addition & 15 deletions
@@ -44,9 +44,6 @@
     "vocab_size": 50257
 }
 
-import math
-from torch import Tensor
-
 
 # from transformers
 class Conv1D(nn.Module):
@@ -74,23 +71,12 @@ def forward(self, x):
         return x
 
 
-# from transformers
-class NewGELUActivation(nn.Module):
-    """
-    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
-    the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
-    """
-
-    def forward(self, input: Tensor) -> Tensor:
-        return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
-
-
 class HeadFFN(nn.Module):  # todo rename
     def __init__(self, dim):
         super().__init__()
         self.c_fc = Conv1D(dim, config['n_embd'])
         self.c_proj = Conv1D(config['n_embd'], dim)
-        self.act = NewGELUActivation()
+        self.act = nn.functional.gelu
         self.dropout = nn.Dropout(config['resid_pdrop'])
 
     def forward(self, hidden_states):
