@@ -28,8 +28,8 @@ contains
28
28
new_unittest("selu" , test_selu), &
29
29
new_unittest("sigmoid", test_sigmoid), &
30
30
new_unittest("silu" , test_silu), &
31
- new_unittest("softmax", test_softmax) &
32
- new_unittest("logsoftmax", test_logsoftmax), &
31
+ new_unittest("softmax", test_softmax), &
32
+ new_unittest("logsoftmax", test_logsoftmax) &
33
33
]
34
34
end subroutine collect_specialfunctions_activation
35
35
@@ -81,11 +81,11 @@ contains
81
81
subroutine test_relu(error)
82
82
type(error_type), allocatable, intent(out) :: error
83
83
integer, parameter :: n = 10
84
- real(sp) :: x(n), y(n), y_ref(n), a
84
+ real(sp) :: x(n), y(n), y_ref(n)
85
85
86
86
x = linspace(-2._sp, 2._sp, n)
87
87
y_ref = max(0._sp, x)
88
- y = relu( x , a )
88
+ y = relu( x )
89
89
call check(error, norm2(y-y_ref) < n*tol_sp )
90
90
if (allocated(error)) return
91
91
@@ -95,7 +95,7 @@ contains
95
95
elsewhere
96
96
y_ref = 0.0_sp
97
97
end where
98
- y = relu_grad( x , a )
98
+ y = relu_grad( x )
99
99
call check(error, norm2(y-y_ref) < n*tol_sp )
100
100
if (allocated(error)) return
101
101
end subroutine
@@ -113,7 +113,7 @@ contains
113
113
elsewhere
114
114
y_ref = scale * (alpha * exp(x) - alpha)
115
115
end where
116
- y = selu( x , a )
116
+ y = selu( x )
117
117
call check(error, norm2(y-y_ref) < n*tol_sp )
118
118
if (allocated(error)) return
119
119
@@ -123,7 +123,7 @@ contains
123
123
elsewhere
124
124
y_ref = scale * alpha * exp(x)
125
125
end where
126
- y = selu_grad( x , a )
126
+ y = selu_grad( x )
127
127
call check(error, norm2(y-y_ref) < n*tol_sp )
128
128
if (allocated(error)) return
129
129
end subroutine
0 commit comments