Skip to content

Commit b02b11c

Browse files
author
Joseph Suarez
committed
Major bug fix in puffernet.h: small but significant discrepancy in puffer vs. torch gelu
1 parent 8eaab1c commit b02b11c

File tree

3 files changed

+18
-3
lines changed

3 files changed

+18
-3
lines changed

pufferlib/extensions/puffernet.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ void _relu(float* input, float* output, int size) {
7676

7777
void _gelu(float* input, float* output, int size) {
7878
for (int i = 0; i < size; i++) {
79-
output[i] = 0.5f*input[i]*(1 + tanhf(0.6628526501011142 * (input[i] + 0.044715f*input[i]*input[i]*input[i])));
79+
output[i] = 0.5f*input[i]*(1 + tanhf(0.7978845608028654 * (input[i] + 0.044715f*input[i]*input[i]*input[i])));
8080
}
8181
}
8282

pufferlib/extensions/puffernet.pyx

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ cdef extern from "puffernet.h":
1818
void _linear(float* input, float* weights, float* bias, float* output,
1919
int batch_size, int input_dim, int output_dim)
2020
void _relu(float* input, float* output,int size)
21+
void _gelu(float* input, float* output, int size)
2122
float _sigmoid(float x)
2223
void _conv2d(float* input, float* weights, float* bias,
2324
float* output, int batch_size, int in_width, int in_height,
@@ -47,6 +48,9 @@ def puf_linear_layer(cnp.ndarray input, cnp.ndarray weights, cnp.ndarray bias, c
4748
def puf_relu(cnp.ndarray input, cnp.ndarray output, int size):
4849
_relu(<float*> input.data, <float*> output.data, size)
4950

51+
def puf_gelu(cnp.ndarray input, cnp.ndarray output, int size):
52+
_gelu(<float*> input.data, <float*> output.data, size)
53+
5054
def puf_sigmoid(float x):
5155
return _sigmoid(x)
5256

tests/test_puffernet.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,17 @@ def test_puffernet_relu(batch_size=16, input_size=128):
4242

4343
assert_near(input_puffer, output_torch.numpy())
4444

45+
def test_puffernet_gelu(batch_size=16, input_size=128):
46+
input_puffer = make_dummy_data(batch_size, input_size)
47+
48+
input_torch = torch.from_numpy(input_puffer)
49+
output_torch = torch.nn.functional.gelu(input_torch, approximate='tanh').detach()
50+
51+
# PufferNet done second because it is in-place on the input
52+
puffernet.puf_gelu(input_puffer, input_puffer, batch_size*input_size)
53+
54+
assert_near(input_puffer, output_torch.numpy())
55+
4556
def test_puffernet_sigmoid(n=1024, epsilon=1e-4):
4657
input_np = make_dummy_data(n)
4758

@@ -247,9 +258,8 @@ def test_nmmo3(batch_size=1, input_size=512, hidden_size=512):
247258
pass
248259

249260
if __name__ == '__main__':
250-
test_nmmo3()
251-
exit()
252261
test_puffernet_relu()
262+
test_puffernet_gelu()
253263
test_puffernet_sigmoid()
254264
test_puffernet_linear_layer()
255265
test_puffernet_convolution_layer()
@@ -260,3 +270,4 @@ def test_nmmo3(batch_size=1, input_size=512, hidden_size=512):
260270
test_puffernet_one_hot()
261271
test_puffernet_cat_dim1()
262272
test_puffernet_argmax_multidiscrete()
273+
#test_nmmo3()

0 commit comments

Comments
 (0)