69
69
@test elu (1.0 ) == 1.0
70
70
@test gelu (1.0 ) == 0.8411919906082768
71
71
@test swish (1.0 ) == 1.0 / (1.0 + exp (- 1.0 ))
72
- @test lisht (1.0 ) ≈ 1.0 * tanh (1.0 )
72
+ @test lisht (1.0 ) ≈ 1.0 * tanh (1.0 )
73
73
@test softplus (1.0 ) ≈ log (exp (1.0 ) + 1.0 )
74
74
@test softsign (1.0 ) == 0.5
75
75
@test selu (1.0 ) == 1.0507009873554804934193349852946
@@ -126,20 +126,20 @@ end
126
126
@test typeof (relu6 (Int64 (1 ))) == Int64
127
127
@test typeof (relu6 (Int32 (1 ))) == Int32
128
128
end
129
-
129
+
130
130
@testset " hardtanh: " begin
131
131
# hardtanh doesn't have to force floating point outputs
132
132
@test typeof (hardtanh (Int64 (1 ))) == Int64
133
133
@test typeof (hardtanh (Int32 (1 ))) == Int32
134
134
end
135
-
135
+
136
136
@testset " trelu: " begin
137
137
# trelu doesn't have to force floating point outputs
138
138
@test typeof (trelu (Int64 (1 ))) == Int64
139
139
@test typeof (trelu (Int32 (1 ))) == Int32
140
140
end
141
141
end
142
-
142
+
143
143
@testset " Float gradient inference" begin
144
144
test_gradient_float_precision_preserving .(ACTIVATION_FUNCTIONS)
145
145
end
201
201
@test leakyrelu (- 0.4 ,0.3 ) ≈ - 0.12
202
202
203
203
@test relu6 (10.0 ) == 6.0
204
- @test - 0.2 <= rrelu (- 0.4 ,0.25 ,0.5 ) <= - 0.1
204
+ @test - 0.2 <= rrelu (- 0.4 ,0.25 ,0.5 ) <= - 0.1
205
205
206
206
@testset " celu" begin
207
207
@test celu (42 ) == 42
225
225
end
226
226
227
227
@test logcosh (1_000.0 ) + log (2 ) == 1_000.0
228
-
228
+
229
229
@testset " hardsigmoid" begin
230
230
@test hardsigmoid (0.3 ) == 0.56
231
231
@test hardsigmoid (- 0.3 ) == 0.44
@@ -234,14 +234,77 @@ end
234
234
@eval @test hardsigmoid .($ T[- 100_000 , 100_000. ]) ≈ $ T[0. , 1. ]
235
235
end
236
236
end
237
-
237
+
238
238
@test hardtanh (10.0 ) == 1.0
239
239
@test lisht (2.5 ) == 2.5 * tanh (2.5 )
240
-
240
+
241
241
@testset " trelu" begin
242
242
@test trelu (0.5 ) == 0.0
243
243
@test trelu (1.0 ) == 0.0
244
244
@test trelu (1.1 ) == 1.1
245
245
@test trelu (0.9 ,0.5 ) == 0.9
246
246
end
247
+
248
+ @testset " mutating softmax" begin
249
+ xs = Float64[1 2 3 ; 5 6 7 ]
250
+
251
+ out = zeros (Float64, size (xs))
252
+ NNlib. softmax! (out, xs)
253
+ @test isapprox (out, softmax (xs); rtol= 1e-6 )
254
+ NNlib. logsoftmax! (out, xs)
255
+ @test isapprox (out, logsoftmax (xs); rtol= 1e-6 )
256
+
257
+ out = ones (Float64, size (xs))
258
+ NNlib. softmax! (out, xs)
259
+ @test isapprox (out, softmax (xs); rtol= 1e-6 )
260
+ NNlib. logsoftmax! (out, xs)
261
+ @test isapprox (out, logsoftmax (xs); rtol= 1e-6 )
262
+
263
+ out = zeros (Float64, size (xs))
264
+ NNlib.∇softmax! (out, xs)
265
+ @test isapprox (out, NNlib.∇softmax (zeros (size (xs)), xs); rtol= 1e-6 )
266
+ out = zeros (Float64, size (xs))
267
+ NNlib.∇logsoftmax! (out, xs)
268
+ @test isapprox (out, NNlib.∇softmax (zeros (size (xs)), xs); rtol= 1e-6 )
269
+
270
+ out = ones (Float64, size (xs))
271
+ NNlib.∇softmax! (out, xs)
272
+ @test isapprox (out, NNlib.∇softmax (ones (size (xs)), xs); rtol= 1e-6 )
273
+ out = ones (Float64, size (xs))
274
+ NNlib.∇logsoftmax! (out, xs)
275
+ @test isapprox (out, NNlib.∇softmax (ones (size (xs)), xs); rtol= 1e-6 )
276
+
277
+ xs = [
278
+ - 0.238639 0.748142 - 0.283194 - 0.525461 - 1.5348 - 0.797842 ;
279
+ 0.690384 0.211427 0.254794 - 0.213572 - 0.314174 - 0.372663 ;
280
+ - 1.146370 - 0.577988 0.718952 0.919720 - 0.620773 0.929977
281
+ ]
282
+
283
+ out = zeros (Float64, size (xs))
284
+ NNlib. softmax! (out, xs)
285
+ @test isapprox (out, softmax (xs); rtol= 1e-6 )
286
+ NNlib. logsoftmax! (out, xs)
287
+ @test isapprox (out, logsoftmax (xs); rtol= 1e-6 )
288
+
289
+ out = ones (Float64, size (xs))
290
+ NNlib. softmax! (out, xs)
291
+ @test isapprox (out, softmax (xs); rtol= 1e-6 )
292
+ NNlib. logsoftmax! (out, xs)
293
+ @test isapprox (out, logsoftmax (xs); rtol= 1e-6 )
294
+
295
+ out = zeros (Float64, size (xs))
296
+ NNlib.∇softmax! (out, xs)
297
+ @test isapprox (out, NNlib.∇softmax (zeros (size (xs)), xs); rtol= 1e-6 )
298
+ out = zeros (Float64, size (xs))
299
+ NNlib.∇logsoftmax! (out, xs)
300
+ @test isapprox (out, NNlib.∇softmax (zeros (size (xs)), xs); rtol= 1e-6 )
301
+
302
+ out = ones (Float64, size (xs))
303
+ NNlib.∇softmax! (out, xs)
304
+ @test isapprox (out, NNlib.∇softmax (ones (size (xs)), xs); rtol= 1e-6 )
305
+ out = ones (Float64, size (xs))
306
+ NNlib.∇logsoftmax! (out, xs)
307
+ @test isapprox (out, NNlib.∇softmax (ones (size (xs)), xs); rtol= 1e-6 )
308
+ end
309
+
247
310
end
0 commit comments