
Commit f8522a7

MaximilianSchreff authored and phaniarnab committed
[SYSTEMDS-3821] Add GELU Activation Function (Approximation)
This patch introduces the Gaussian Error Linear Unit (GELU) activation function to SystemDS as a built-in operation. The implementation uses the widely adopted approximate (tanh) formulation (https://arxiv.org/abs/1606.08415). This patch is part of a series of commits to support popular Transformer architectures in SystemDS; GELU is one of the most commonly used activation functions in models like BERT and GPT. Closes #2177
1 parent e32c323 commit f8522a7
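
For context, the new layer follows the forward/backward convention of the other scripts/nn layers. A minimal usage sketch in DML (the input shape and random data below are illustrative, not part of the patch):

source("nn/layers/gelu.dml") as gelu

# illustrative 4x8 activation matrix
X = rand(rows=4, cols=8, min=-2, max=2)

# forward pass: element-wise GELU via the tanh approximation
out = gelu::forward(X)

# backward pass: propagate an upstream gradient (all ones here) through the layer
dout = matrix(1, rows=4, cols=8)
dX = gelu::backward(dout, X)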

File tree

3 files changed: +141, -0 lines changed


scripts/nn/layers/gelu.dml

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------

/*
 * Gaussian Error Linear Unit (GELU) nonlinearity layer.
 */

source("nn/layers/tanh.dml") as tanh

forward = function(matrix[double] X)
    return (matrix[double] out) {
  /*
   * Computes the forward pass for a GELU nonlinearity layer, via
   * its tanh approximation.
   *
   * Performs an element-wise evaluation of
   * `GELU(x) = x * CDF(x)`,
   * where CDF is the cumulative distribution function of the
   * standard normal distribution:
   * `CDF(x) = 0.5 * (1 + erf(x/sqrt(2)))`.
   * This implementation uses the tanh approximation:
   * `CDF(x) ~= 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715*x^3)))`.
   *
   * Inputs:
   *  - X: Inputs, of shape (any, any).
   *
   * Outputs:
   *  - out: Outputs, of same shape as `X`.
   */
  cdf = 0.5 * (1 + tanh(sqrt(2 / pi) * (X + 0.044715 * X^3)))
  out = cdf * X
}

backward = function(matrix[double] dout, matrix[double] X)
    return (matrix[double] dX) {
  /*
   * Computes the backward pass for a GELU nonlinearity layer, via
   * its tanh approximation.
   *
   * Inputs:
   *  - dout: Gradient wrt `out` from upstream, of same shape as `X`.
   *  - X: Previous input data matrix, of shape (any, any).
   *
   * Outputs:
   *  - dX: Gradient wrt `X`, of same shape as `X`.
   */
  a = sqrt(2 / pi)
  b = 0.044715
  T = tanh(a * (X + b * X^3))
  dT = 1 - T^2
  dX = dout * (0.5 * (1 + T) + 0.5 * X * dT * a * (1 + 3 * b * X^2))
}
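
For reference, the expression computed in `backward` is simply the derivative of the tanh approximation used in `forward`; a short derivation in the script's notation (a = sqrt(2/pi), b = 0.044715):

\[
\mathrm{GELU}(x) \approx f(x) = \tfrac{1}{2}\, x \left(1 + \tanh\!\big(a(x + b x^3)\big)\right)
\]

With \(T = \tanh\big(a(x + b x^3)\big)\), \(\frac{d}{dx}\tanh(u) = (1 - T^2)\, u'(x)\), and \(u'(x) = a(1 + 3 b x^2)\), the product rule gives

\[
f'(x) = \tfrac{1}{2}(1 + T) + \tfrac{1}{2}\, x\, (1 - T^2)\, a\, (1 + 3 b x^2),
\]

which is exactly the element-wise factor that `dX = dout * (...)` multiplies into the upstream gradient.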

src/test/java/org/apache/sysds/test/applications/nn/NNComponentTest.java

Lines changed: 5 additions & 0 deletions
@@ -124,6 +124,11 @@ public void resnet() {
		run("resnet_bottleneck.dml");
	}

+	@Test
+	public void gelu() {
+		run("gelu.dml");
+	}
+
	@Override
	protected void run(String name) {
		super.run("component/" + name);
src/test/scripts/applications/nn/component/gelu.dml

Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------

source("nn/layers/gelu.dml") as gelu
source("src/test/scripts/applications/nn/util.dml") as test_util

gelu_test1 = function() {
  print("Testing GELU, test 1")

  X = matrix("1. -0.5
              0.  2.", rows=2, cols=2)
  dout = matrix("1 1
                 1 1", rows=2, cols=2)
  out_expected = matrix("0.841192 -0.154286
                         0.        1.9545977", rows=2, cols=2)
  gradient_expected = matrix("1.0829641 0.13263011
                              0.5       1.0860993", rows=2, cols=2)

  out = gelu::forward(X)

  test_util::check_all_close(out, out_expected, 0.00001)

  gradient = gelu::backward(dout, X)
  test_util::check_all_close(gradient, gradient_expected, 0.00001)
}

gelu_test2 = function() {
  print("Testing GELU, test 2")

  X = matrix("0.5 -1.5
              1.  -2.", rows=2, cols=2)
  dout = matrix("1 1
                 1 1", rows=2, cols=2)
  out_expected = matrix("0.345714 -0.10042843
                         0.841192 -0.04540229", rows=2, cols=2)
  gradient_expected = matrix("0.8673699 -0.1277108
                              1.0829641 -0.08609922", rows=2, cols=2)

  out = gelu::forward(X)

  test_util::check_all_close(out, out_expected, 0.00001)

  gradient = gelu::backward(dout, X)
  test_util::check_all_close(gradient, gradient_expected, 0.00001)
}

gelu_test1()
gelu_test2()
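
As a quick sanity check on the expected values above, evaluating the approximation by hand at two of the test inputs:

x = 1:  a * (x + b*x^3) = 0.7978846 * 1.044715 ≈ 0.833562
        tanh(0.833562) ≈ 0.68238
        GELU(1) ≈ 0.5 * 1 * (1 + 0.68238) ≈ 0.84119   (out_expected[1,1] = 0.841192 in test 1)

x = 0:  T = tanh(0) = 0, dT = 1
        f'(0) = 0.5 * (1 + 0) + 0.5 * 0 * 1 * a * 1 = 0.5   (gradient_expected[2,1] = 0.5 in test 1)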
