@@ -28,8 +28,8 @@ contains
2828 new_unittest("selu" , test_selu), &
2929 new_unittest("sigmoid", test_sigmoid), &
3030 new_unittest("silu" , test_silu), &
31- new_unittest("softmax", test_softmax) &
32- new_unittest("logsoftmax", test_logsoftmax), &
31+ new_unittest("softmax", test_softmax), &
32+ new_unittest("logsoftmax", test_logsoftmax) &
3333 ]
3434 end subroutine collect_specialfunctions_activation
3535
@@ -81,11 +81,11 @@ contains
8181 subroutine test_relu(error)
8282 type(error_type), allocatable, intent(out) :: error
8383 integer, parameter :: n = 10
84- real(sp) :: x(n), y(n), y_ref(n), a
84+ real(sp) :: x(n), y(n), y_ref(n)
8585
8686 x = linspace(-2._sp, 2._sp, n)
8787 y_ref = max(0._sp, x)
88- y = relu( x , a )
88+ y = relu( x )
8989 call check(error, norm2(y-y_ref) < n*tol_sp )
9090 if (allocated(error)) return
9191
@@ -95,7 +95,7 @@ contains
9595 elsewhere
9696 y_ref = 0.0_sp
9797 end where
98- y = relu_grad( x , a )
98+ y = relu_grad( x )
9999 call check(error, norm2(y-y_ref) < n*tol_sp )
100100 if (allocated(error)) return
101101 end subroutine
@@ -113,7 +113,7 @@ contains
113113 elsewhere
114114 y_ref = scale * (alpha * exp(x) - alpha)
115115 end where
116- y = selu( x , a )
116+ y = selu( x )
117117 call check(error, norm2(y-y_ref) < n*tol_sp )
118118 if (allocated(error)) return
119119
@@ -123,7 +123,7 @@ contains
123123 elsewhere
124124 y_ref = scale * alpha * exp(x)
125125 end where
126- y = selu_grad( x , a )
126+ y = selu_grad( x )
127127 call check(error, norm2(y-y_ref) < n*tol_sp )
128128 if (allocated(error)) return
129129 end subroutine
0 commit comments