Skip to content

Commit 4c1afde

Browse files
committed
add tests
1 parent 1914e78 commit 4c1afde

File tree

2 files changed

+128
-8
lines changed

2 files changed

+128
-8
lines changed

test/specialfunctions/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ set(fppFiles
99
fypp_f90("${fyppFlags}" "${fppFiles}" outFiles)
1010

1111
ADDTEST(specialfunctions_gamma)
12+
ADDTEST(specialfunctions_activations)

test/specialfunctions/test_specialfunctions_activations.fypp

Lines changed: 127 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,24 +21,109 @@ contains
2121
type(unittest_type), allocatable, intent(out) :: testsuite(:)
2222

2323
testsuite = [ &
24+
new_unittest("gaussian", test_gaussian), &
25+
new_unittest("elu", test_elu), &
26+
new_unittest("relu", test_relu), &
27+
new_unittest("gelu" , test_gelu), &
28+
new_unittest("selu" , test_selu), &
2429
new_unittest("sigmoid", test_sigmoid), &
25-
new_unittest("logsoftmax", test_logsoftmax), &
26-
new_unittest("gelu" , test_gelu ), &
30+
new_unittest("silu" , test_silu), &
2731
new_unittest("softmax", test_softmax) &
32+
new_unittest("logsoftmax", test_logsoftmax), &
2833
]
2934
end subroutine collect_specialfunctions_activation
3035

31-
subroutine test_sigmoid(error)
36+
subroutine test_gaussian(error)
3237
type(error_type), allocatable, intent(out) :: error
3338
integer, parameter :: n = 10
3439
real(sp) :: x(n), y(n), y_ref(n)
3540

36-
y_ref = [0.119202919304371, 0.174285307526588, 0.247663781046867,&
37-
0.339243650436401, 0.444671928882599, 0.555328071117401,&
38-
0.660756349563599, 0.752336204051971, 0.825714707374573,&
39-
0.880797028541565]
4041
x = linspace(-2._sp, 2._sp, n)
41-
y = sigmoid( x )
42+
y_ref = exp(-x**2)
43+
y = gaussian( x )
44+
call check(error, norm2(y-y_ref) < n*tol_sp )
45+
if (allocated(error)) return
46+
47+
! Derivative
48+
y_ref = -2.0 * x * exp(-x**2)
49+
y = gaussian_grad( x )
50+
call check(error, norm2(y-y_ref) < n*tol_sp )
51+
if (allocated(error)) return
52+
end subroutine
53+
54+
subroutine test_elu(error)
55+
type(error_type), allocatable, intent(out) :: error
56+
integer, parameter :: n = 10
57+
real(sp) :: x(n), y(n), y_ref(n), a
58+
59+
x = linspace(-2._sp, 2._sp, n)
60+
a = 1.0_sp
61+
where(x >= 0._sp)
62+
y_ref = x
63+
elsewhere
64+
y_ref = a * (exp(x) - 1._sp)
65+
end where
66+
y = elu( x , a )
67+
call check(error, norm2(y-y_ref) < n*tol_sp )
68+
if (allocated(error)) return
69+
70+
! Derivative
71+
where(x >= 0._sp)
72+
y_ref = 1.0_sp
73+
elsewhere
74+
y_ref = a * exp(x)
75+
end where
76+
y = elu_grad( x , a )
77+
call check(error, norm2(y-y_ref) < n*tol_sp )
78+
if (allocated(error)) return
79+
end subroutine
80+
81+
subroutine test_relu(error)
82+
type(error_type), allocatable, intent(out) :: error
83+
integer, parameter :: n = 10
84+
real(sp) :: x(n), y(n), y_ref(n), a
85+
86+
x = linspace(-2._sp, 2._sp, n)
87+
y_ref = max(0._sp, x)
88+
y = relu( x , a )
89+
call check(error, norm2(y-y_ref) < n*tol_sp )
90+
if (allocated(error)) return
91+
92+
! Derivative
93+
where(x > 0._sp)
94+
y_ref = 1.0_sp
95+
elsewhere
96+
y_ref = 0.0_sp
97+
end where
98+
y = relu_grad( x , a )
99+
call check(error, norm2(y-y_ref) < n*tol_sp )
100+
if (allocated(error)) return
101+
end subroutine
102+
103+
subroutine test_selu(error)
104+
type(error_type), allocatable, intent(out) :: error
105+
integer, parameter :: n = 10
106+
real(sp), parameter :: scale = 1.0507009873554804934193349852946_sp
107+
real(sp), parameter :: alpha = 1.6732632423543772848170429916717_sp
108+
real(sp) :: x(n), y(n), y_ref(n)
109+
110+
x = linspace(-2._sp, 2._sp, n)
111+
where(x >= 0._sp)
112+
y_ref = scale * x
113+
elsewhere
114+
y_ref = scale * (alpha * exp(x) - alpha)
115+
end where
116+
y = selu( x , a )
117+
call check(error, norm2(y-y_ref) < n*tol_sp )
118+
if (allocated(error)) return
119+
120+
! Derivative
121+
where(x >= 0._sp)
122+
y_ref = scale
123+
elsewhere
124+
y_ref = scale * alpha * exp(x)
125+
end where
126+
y = selu_grad( x , a )
42127
call check(error, norm2(y-y_ref) < n*tol_sp )
43128
if (allocated(error)) return
44129
end subroutine
@@ -62,6 +147,40 @@ contains
62147
if (allocated(error)) return
63148
end subroutine
64149

150+
subroutine test_sigmoid(error)
151+
type(error_type), allocatable, intent(out) :: error
152+
integer, parameter :: n = 10
153+
real(sp) :: x(n), y(n), y_ref(n)
154+
155+
y_ref = [0.119202919304371, 0.174285307526588, 0.247663781046867,&
156+
0.339243650436401, 0.444671928882599, 0.555328071117401,&
157+
0.660756349563599, 0.752336204051971, 0.825714707374573,&
158+
0.880797028541565]
159+
x = linspace(-2._sp, 2._sp, n)
160+
y = sigmoid( x )
161+
call check(error, norm2(y-y_ref) < n*tol_sp )
162+
if (allocated(error)) return
163+
end subroutine
164+
165+
subroutine test_silu(error)
166+
type(error_type), allocatable, intent(out) :: error
167+
integer, parameter :: n = 10
168+
real(sp) :: x(n), y(n), y_ref(n), a
169+
170+
x = linspace(-2._sp, 2._sp, n)
171+
y_ref = x / (1._sp + exp(-x))
172+
y = silu( x )
173+
call check(error, norm2(y-y_ref) < n*tol_sp )
174+
if (allocated(error)) return
175+
176+
! Derivative
177+
y_ref = (1._sp + exp(x))**2
178+
y_ref = exp(x) * ( x + y_ref ) / y_ref
179+
y = silu_grad( x )
180+
call check(error, norm2(y-y_ref) < n*tol_sp )
181+
if (allocated(error)) return
182+
end subroutine
183+
65184
subroutine test_softmax(error)
66185
type(error_type), allocatable, intent(out) :: error
67186

0 commit comments

Comments
 (0)