@@ -21,24 +21,109 @@ contains
21
21
type(unittest_type), allocatable, intent(out) :: testsuite(:)
22
22
23
23
testsuite = [ &
24
+ new_unittest("gaussian", test_gaussian), &
25
+ new_unittest("elu", test_elu), &
26
+ new_unittest("relu", test_relu), &
27
+ new_unittest("gelu" , test_gelu), &
28
+ new_unittest("selu" , test_selu), &
24
29
new_unittest("sigmoid", test_sigmoid), &
25
- new_unittest("logsoftmax", test_logsoftmax), &
26
- new_unittest("gelu" , test_gelu ), &
30
+ new_unittest("silu" , test_silu), &
27
31
new_unittest("softmax", test_softmax) &
32
+ new_unittest("logsoftmax", test_logsoftmax), &
28
33
]
29
34
end subroutine collect_specialfunctions_activation
30
35
31
- subroutine test_sigmoid (error)
36
+ subroutine test_gaussian (error)
32
37
type(error_type), allocatable, intent(out) :: error
33
38
integer, parameter :: n = 10
34
39
real(sp) :: x(n), y(n), y_ref(n)
35
40
36
- y_ref = [0.119202919304371, 0.174285307526588, 0.247663781046867,&
37
- 0.339243650436401, 0.444671928882599, 0.555328071117401,&
38
- 0.660756349563599, 0.752336204051971, 0.825714707374573,&
39
- 0.880797028541565]
40
41
x = linspace(-2._sp, 2._sp, n)
41
- y = sigmoid( x )
42
+ y_ref = exp(-x**2)
43
+ y = gaussian( x )
44
+ call check(error, norm2(y-y_ref) < n*tol_sp )
45
+ if (allocated(error)) return
46
+
47
+ ! Derivative
48
+ y_ref = -2.0 * x * exp(-x**2)
49
+ y = gaussian_grad( x )
50
+ call check(error, norm2(y-y_ref) < n*tol_sp )
51
+ if (allocated(error)) return
52
+ end subroutine
53
+
54
+ subroutine test_elu(error)
55
+ type(error_type), allocatable, intent(out) :: error
56
+ integer, parameter :: n = 10
57
+ real(sp) :: x(n), y(n), y_ref(n), a
58
+
59
+ x = linspace(-2._sp, 2._sp, n)
60
+ a = 1.0_sp
61
+ where(x >= 0._sp)
62
+ y_ref = x
63
+ elsewhere
64
+ y_ref = a * (exp(x) - 1._sp)
65
+ end where
66
+ y = elu( x , a )
67
+ call check(error, norm2(y-y_ref) < n*tol_sp )
68
+ if (allocated(error)) return
69
+
70
+ ! Derivative
71
+ where(x >= 0._sp)
72
+ y_ref = 1.0_sp
73
+ elsewhere
74
+ y_ref = a * exp(x)
75
+ end where
76
+ y = elu_grad( x , a )
77
+ call check(error, norm2(y-y_ref) < n*tol_sp )
78
+ if (allocated(error)) return
79
+ end subroutine
80
+
81
+ subroutine test_relu(error)
82
+ type(error_type), allocatable, intent(out) :: error
83
+ integer, parameter :: n = 10
84
+ real(sp) :: x(n), y(n), y_ref(n), a
85
+
86
+ x = linspace(-2._sp, 2._sp, n)
87
+ y_ref = max(0._sp, x)
88
+ y = relu( x , a )
89
+ call check(error, norm2(y-y_ref) < n*tol_sp )
90
+ if (allocated(error)) return
91
+
92
+ ! Derivative
93
+ where(x > 0._sp)
94
+ y_ref = 1.0_sp
95
+ elsewhere
96
+ y_ref = 0.0_sp
97
+ end where
98
+ y = relu_grad( x , a )
99
+ call check(error, norm2(y-y_ref) < n*tol_sp )
100
+ if (allocated(error)) return
101
+ end subroutine
102
+
103
+ subroutine test_selu(error)
104
+ type(error_type), allocatable, intent(out) :: error
105
+ integer, parameter :: n = 10
106
+ real(sp), parameter :: scale = 1.0507009873554804934193349852946_sp
107
+ real(sp), parameter :: alpha = 1.6732632423543772848170429916717_sp
108
+ real(sp) :: x(n), y(n), y_ref(n)
109
+
110
+ x = linspace(-2._sp, 2._sp, n)
111
+ where(x >= 0._sp)
112
+ y_ref = scale * x
113
+ elsewhere
114
+ y_ref = scale * (alpha * exp(x) - alpha)
115
+ end where
116
+ y = selu( x , a )
117
+ call check(error, norm2(y-y_ref) < n*tol_sp )
118
+ if (allocated(error)) return
119
+
120
+ ! Derivative
121
+ where(x >= 0._sp)
122
+ y_ref = scale
123
+ elsewhere
124
+ y_ref = scale * alpha * exp(x)
125
+ end where
126
+ y = selu_grad( x , a )
42
127
call check(error, norm2(y-y_ref) < n*tol_sp )
43
128
if (allocated(error)) return
44
129
end subroutine
@@ -62,6 +147,40 @@ contains
62
147
if (allocated(error)) return
63
148
end subroutine
64
149
150
+ subroutine test_sigmoid(error)
151
+ type(error_type), allocatable, intent(out) :: error
152
+ integer, parameter :: n = 10
153
+ real(sp) :: x(n), y(n), y_ref(n)
154
+
155
+ y_ref = [0.119202919304371, 0.174285307526588, 0.247663781046867,&
156
+ 0.339243650436401, 0.444671928882599, 0.555328071117401,&
157
+ 0.660756349563599, 0.752336204051971, 0.825714707374573,&
158
+ 0.880797028541565]
159
+ x = linspace(-2._sp, 2._sp, n)
160
+ y = sigmoid( x )
161
+ call check(error, norm2(y-y_ref) < n*tol_sp )
162
+ if (allocated(error)) return
163
+ end subroutine
164
+
165
+ subroutine test_silu(error)
166
+ type(error_type), allocatable, intent(out) :: error
167
+ integer, parameter :: n = 10
168
+ real(sp) :: x(n), y(n), y_ref(n), a
169
+
170
+ x = linspace(-2._sp, 2._sp, n)
171
+ y_ref = x / (1._sp + exp(-x))
172
+ y = silu( x )
173
+ call check(error, norm2(y-y_ref) < n*tol_sp )
174
+ if (allocated(error)) return
175
+
176
+ ! Derivative
177
+ y_ref = (1._sp + exp(x))**2
178
+ y_ref = exp(x) * ( x + y_ref ) / y_ref
179
+ y = silu_grad( x )
180
+ call check(error, norm2(y-y_ref) < n*tol_sp )
181
+ if (allocated(error)) return
182
+ end subroutine
183
+
65
184
subroutine test_softmax(error)
66
185
type(error_type), allocatable, intent(out) :: error
67
186
0 commit comments