using NNlib, Test, Zygote

- ACTIVATION_FUNCTIONS = [σ, relu, leakyrelu, relu6, rrelu, elu, gelu, celu, swish, selu, softplus, softsign, logcosh, mish, tanhshrink, softshrink];
+ ACTIVATION_FUNCTIONS = [σ, hardσ, hardtanh, relu, leakyrelu, relu6, rrelu, elu, gelu, celu, swish, lisht, selu, trelu, softplus, softsign, logcosh, mish, tanhshrink, softshrink];
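
# Sketch (assumption, not part of this commit): reference formulas that the new
# assertions below appear to check against. The _ref_* helpers are hypothetical
# and exist only to make the expected behaviour of the added activations explicit.
_ref_hardσ(x) = max(zero(x), min(one(x), 0.2 * x + 0.5))
_ref_hardtanh(x) = max(-one(x), min(one(x), x))
_ref_lisht(x) = x * tanh(x)
_ref_trelu(x, θ = one(x)) = x > θ ? x : zero(x)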

function test_value_float_precision_preserving(a)
@testset "$(a): " begin
⋮
@testset " Activation Functions" begin
39
39
@test σ (0.0 ) == 0.5
40
+ @test hardσ (0.0 ) == 0.5
41
+ @test hardtanh (0.0 ) == 0.0
40
42
@test relu (0.0 ) == 0.0
41
43
@test leakyrelu (0.0 ) == 0.0
42
44
@test relu6 (0.0 ) == 0.0
43
45
@test rrelu (0.0 ) == 0.0
44
46
@test elu (0.0 ) == 0.0
45
47
@test gelu (0.0 ) == 0.0
46
48
@test swish (0.0 ) == 0.0
49
+ @test lisht (0.0 ) == 0.0
47
50
@test softplus (0.0 ) ≈ log (2.0 )
48
51
@test softplus (1e8 ) ≈ 1e8
49
52
@test softplus (- 1e8 ) ≈ 0.0
50
53
@test softsign (0.0 ) == 0.0
51
54
@test selu (0.0 ) == 0.0
52
55
@test celu (0.0 ) == 0.0
56
+ @test trelu (0.0 ) == 0.0
53
57
@test logcosh (0.0 ) == log (cosh (0.0 ))
54
58
@test mish (0.0 ) == 0.0
55
59
@test tanhshrink (0.0 ) == 0.0
56
60
@test softshrink (0.0 ) == 0.0

@test σ(1.0) == 1.0 / (1.0 + exp(-1.0))
+ @test hardσ(1.0) == max(0, min(1, 0.2 * 1.0 + 0.5))
+ @test hardtanh(1.0) == 1.0
@test relu(1.0) == 1.0
@test leakyrelu(1.0) == 1.0
@test relu6(1.0) == 1.0
@test rrelu(1.0) == 1.0
@test elu(1.0) == 1.0
@test gelu(1.0) == 0.8411919906082768
@test swish(1.0) == 1.0 / (1.0 + exp(-1.0))
+ @test lisht(1.0) ≈ 1.0 * tanh(1.0)
@test softplus(1.0) ≈ log(exp(1.0) + 1.0)
@test softsign(1.0) == 0.5
@test selu(1.0) == 1.0507009873554804934193349852946
@test celu(1.0) == 1.0
+ @test trelu(1.0) == 0.0
@test logcosh(1.0) ≈ log(cosh(1.0))
@test mish(1.0) ≈ tanh(log(1.0 + exp(1.0)))
@test tanhshrink(1.0) ≈ 0.23840584404423515
@test softshrink(1.0) == 0.5

@test σ(-1.0) == 1.0 / (1.0 + exp(1.0))
+ @test hardσ(-1.0) == max(0, min(1, 0.2 * -1.0 + 0.5))
+ @test hardtanh(-1.0) == -1.0
@test relu(-1.0) == 0.0
@test leakyrelu(-1.0) == -0.01
@test relu6(-1.0) == 0.0
@test -1/3.0 <= rrelu(-1.0) <= -1/8.0
@test elu(-1.0) == exp(-1.0) - 1.0
@test gelu(-1.0) == -0.15880800939172324
@test swish(-1.0) == -1.0 / (1.0 + exp(1.0))
+ @test lisht(-1.0) ≈ -1.0 * tanh(-1.0)
@test softplus(-1.0) ≈ log(exp(-1.0) + 1.0)
@test softsign(-1.0) == -0.5
@test selu(-1.0) == 1.0507009873554804934193349852946 * 1.6732632423543772848170429916717 * (exp(-1.0) - 1.0)
@test celu(-1.0) == exp(-1.0) - 1
+ @test trelu(-1.0) == 0.0
@test logcosh(-1.0) ≈ log(cosh(-1.0))
@test mish(-1.0) ≈ -tanh(log(1.0 + exp(-1.0)))
@test tanhshrink(-1.0) ≈ -0.23840584404423515
⋮
end

@testset "Test Integer64 and Integer32 inputs will force Float64 outputs" begin
- test_value_int_input_forces_float64.(filter(x -> (x != relu && x != relu6), ACTIVATION_FUNCTIONS))
+ test_value_int_input_forces_float64.(filter(x -> (x != relu && x != relu6 && x != hardtanh && x != trelu), ACTIVATION_FUNCTIONS))

@testset "relu: " begin
# relu doesn't have to force floating point outputs
⋮
@test typeof(relu6(Int64(1))) == Int64
@test typeof(relu6(Int32(1))) == Int32
end
-
+
+ @testset "hardtanh: " begin
+ # hardtanh doesn't have to force floating point outputs
+ @test typeof(hardtanh(Int64(1))) == Int64
+ @test typeof(hardtanh(Int32(1))) == Int32
+ end
+
+ @testset "trelu: " begin
+ # trelu doesn't have to force floating point outputs
+ @test typeof(trelu(Int64(1))) == Int64
+ @test typeof(trelu(Int32(1))) == Int32
+ end
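
# Sketch (assumption): relu, relu6, hardtanh and trelu can keep integer inputs in
# their integer type because they are built only from max/min/comparisons, unlike
# the exp/log based activations. A minimal illustration of that property:
let x = Int32(2)
    @test max(zero(x), x) isa Int32              # relu-style clamp preserves the type
    @test clamp(x, -one(x), one(x)) isa Int32    # hardtanh-style clamp preserves the type
end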
end

@testset "Float gradient inference" begin
⋮
end

@test logcosh(1_000.0) + log(2) == 1_000.0
- end
+
+ @testset "hardsigmoid" begin
+ @test hardsigmoid(0.3) == 0.56
+ @test hardsigmoid(-0.3) == 0.44
+ @test hardsigmoid(0.1, 0.5) == 0.55
+ for T in [:Float32, :Float64]
+ @eval @test hardsigmoid.($T[-100_000, 100_000.]) ≈ $T[0., 1.]
+ end
+ end
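
# Sketch (assumption inferred from the assertion above, not from a documented API):
# the two-argument form appears to take a slope a, i.e. hardσ(x, a) = max(0, min(1, a * x + 0.5)),
# which reproduces the tested value for hardsigmoid(0.1, 0.5):
@test max(0, min(1, 0.5 * 0.1 + 0.5)) ≈ 0.55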
+
+ @test hardtanh(10.0) == 1.0
+ @test lisht(2.5) == 2.5 * tanh(2.5)
+
+ @testset "trelu" begin
+ @test trelu(0.5) == 0.0
+ @test trelu(1.0) == 0.0
+ @test trelu(1.1) == 1.1
+ @test trelu(0.9, 0.5) == 0.9
+ end
+ end