Commit 186e499

e-strauss authored and mboehm7 committed
[SYSTEMDS-3917] New built-in SELU activation function
Closes #2328.
1 parent 37ac75b commit 186e499

File tree

3 files changed: +139 −0 lines changed


scripts/nn/layers/selu.dml

Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------

/*
 * SeLU (Scaled Exponential Linear Unit) nonlinearity layer.
 */

forward = function(matrix[double] X)
    return (matrix[double] out) {
  /*
   * Computes the forward pass for a SeLU nonlinearity layer.
   *
   *   selu(x) = lambda * ( x                    if x > 0
   *                        α * (exp(x) - 1)     if x <= 0 )
   *
   * Inputs:
   *  - X: Inputs, of shape (any, any).
   *
   * Outputs:
   *  - out: Outputs, of same shape as `X`.
   */
  alpha = 1.6732632423543772
  lambda = 1.0507009873554805

  out = (X > 0) * (lambda * X) +
        (X <= 0) * (lambda * alpha * (exp(X) - 1))
}

backward = function(matrix[double] dout, matrix[double] X)
    return (matrix[double] dX) {
  /*
   * Computes the backward pass for a SeLU nonlinearity layer.
   *
   * Inputs:
   *  - dout: Gradient wrt `out` from upstream, of same shape as `X`.
   *  - X: Inputs, of shape (any, any).
   *
   * Outputs:
   *  - dX: Gradient wrt `X`, of same shape as `X`.
   */
  alpha = 1.6732632423543772
  lambda = 1.0507009873554805

  dselu = (X > 0) * lambda +
          (X <= 0) * (lambda * alpha * exp(X))

  dX = dselu * dout
}
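
For reference, a minimal sketch of how the new layer can be called from another DML script; the input values and print statements below are illustrative only and not part of this commit:

source("nn/layers/selu.dml") as selu

# small 2x2 input with positive and non-positive entries
X = matrix("1. -0.5
            0.  2.", rows=2, cols=2)

# forward pass: lambda*x on the positive branch, lambda*alpha*(exp(x)-1) otherwise
out = selu::forward(X)

# backward pass: scales the upstream gradient by lambda (x > 0)
# or lambda*alpha*exp(x) (x <= 0)
dout = matrix(1, rows=2, cols=2)
dX = selu::backward(dout, X)

print(toString(out))
print(toString(dX))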

src/test/java/org/apache/sysds/test/applications/nn/NNComponentTest.java

Lines changed: 5 additions & 0 deletions
@@ -129,6 +129,11 @@ public void gelu() {
 		run("gelu.dml");
 	}
 
+	@Test
+	public void selu() {
+		run("selu.dml");
+	}
+
 	@Test
 	public void embedding() {
 		run("embedding.dml");
Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------

source("nn/layers/selu.dml") as selu
source("src/test/scripts/applications/nn/util.dml") as test_util

selu_test1 = function() {
  print("Testing SeLU, test 1")

  X = matrix("1. -0.5
              0.  2.", rows=2, cols=2)
  dout = matrix("1 1
                 1 1", rows=2, cols=2)

  # Reference from PyTorch nn.SELU
  out_expected = matrix("1.050701 -0.69175816
                         0.        2.101402", rows=2, cols=2)
  gradient_expected = matrix("1.050701  1.0663412
                              1.7580993 1.050701", rows=2, cols=2)

  out = selu::forward(X)
  test_util::check_all_close(out, out_expected, 0.00001)

  gradient = selu::backward(dout, X)
  test_util::check_all_close(gradient, gradient_expected, 0.00001)
}

selu_test2 = function() {
  print("Testing SeLU, test 2")

  X = matrix("0.5 -1.5
              1.  -2.", rows=2, cols=2)
  dout = matrix("1 1
                 1 1", rows=2, cols=2)

  # Precomputed reference from PyTorch nn.SELU
  out_expected = matrix("0.5253505 -1.3658143
                         1.050701  -1.5201665", rows=2, cols=2)
  gradient_expected = matrix("1.050701 0.392285
                              1.050701 0.23793286", rows=2, cols=2)

  out = selu::forward(X)
  test_util::check_all_close(out, out_expected, 0.00001)

  gradient = selu::backward(dout, X)
  test_util::check_all_close(gradient, gradient_expected, 0.00001)
}

selu_test1()
selu_test2()
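
As a quick sanity check, the PyTorch reference values used above follow directly from the constants in the layer. A small scalar sketch (illustrative, not part of the commit) recomputes three of the test-1 entries:

alpha = 1.6732632423543772
lambda = 1.0507009873554805

# x = 1.0 lies on the positive branch: lambda * x
print(lambda * 1.0)                      # ~1.050701

# x = -0.5 lies on the non-positive branch: lambda * alpha * (exp(x) - 1)
print(lambda * alpha * (exp(-0.5) - 1))  # ~-0.69175816

# gradient at x = 0 (non-positive branch): lambda * alpha * exp(0)
print(lambda * alpha)                    # ~1.7580993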
