Commit 8c3e994

Added tests for softmax variants which mutate their inputs.

1 parent: bf027df

1 file changed: +71 -8 lines

test/activation.jl: 71 additions & 8 deletions
@@ -69,7 +69,7 @@ end
     @test elu(1.0) == 1.0
     @test gelu(1.0) == 0.8411919906082768
     @test swish(1.0) == 1.0 / (1.0 + exp(-1.0))
-    @test lisht(1.0) ≈ 1.0 * tanh(1.0)
+    @test lisht(1.0) ≈ 1.0 * tanh(1.0)
     @test softplus(1.0) ≈ log(exp(1.0) + 1.0)
     @test softsign(1.0) == 0.5
     @test selu(1.0) == 1.0507009873554804934193349852946
@@ -126,20 +126,20 @@ end
         @test typeof(relu6(Int64(1))) == Int64
         @test typeof(relu6(Int32(1))) == Int32
     end
-
+
     @testset "hardtanh: " begin
         # hardtanh doesn't have to force floating point outputs
         @test typeof(hardtanh(Int64(1))) == Int64
         @test typeof(hardtanh(Int32(1))) == Int32
     end
-
+
     @testset "trelu: " begin
         # trelu doesn't have to force floating point outputs
         @test typeof(trelu(Int64(1))) == Int64
         @test typeof(trelu(Int32(1))) == Int32
     end
 end
-
+
 @testset "Float gradient inference" begin
     test_gradient_float_precision_preserving.(ACTIVATION_FUNCTIONS)
 end
@@ -201,7 +201,7 @@ end
     @test leakyrelu(-0.4,0.3) ≈ -0.12

     @test relu6(10.0) == 6.0
-    @test -0.2 <= rrelu(-0.4,0.25,0.5) <= -0.1
+    @test -0.2 <= rrelu(-0.4,0.25,0.5) <= -0.1

     @testset "celu" begin
         @test celu(42) == 42
@@ -225,7 +225,7 @@ end
     end

     @test logcosh(1_000.0) + log(2) == 1_000.0
-
+
     @testset "hardsigmoid" begin
         @test hardsigmoid(0.3) == 0.56
         @test hardsigmoid(-0.3) == 0.44
@@ -234,14 +234,77 @@ end
             @eval @test hardsigmoid.($T[-100_000, 100_000.]) ≈ $T[0., 1.]
         end
     end
-
+
     @test hardtanh(10.0) == 1.0
     @test lisht(2.5) == 2.5*tanh(2.5)
-
+
     @testset "trelu" begin
         @test trelu(0.5) == 0.0
         @test trelu(1.0) == 0.0
         @test trelu(1.1) == 1.1
         @test trelu(0.9,0.5) == 0.9
     end
+
+    @testset "mutating softmax" begin
+        xs = Float64[1 2 3; 5 6 7]
+
+        out = zeros(Float64, size(xs))
+        NNlib.softmax!(out, xs)
+        @test isapprox(out, softmax(xs); rtol=1e-6)
+        NNlib.logsoftmax!(out, xs)
+        @test isapprox(out, logsoftmax(xs); rtol=1e-6)
+
+        out = ones(Float64, size(xs))
+        NNlib.softmax!(out, xs)
+        @test isapprox(out, softmax(xs); rtol=1e-6)
+        NNlib.logsoftmax!(out, xs)
+        @test isapprox(out, logsoftmax(xs); rtol=1e-6)
+
+        out = zeros(Float64, size(xs))
+        NNlib.∇softmax!(out, xs)
+        @test isapprox(out, NNlib.∇softmax(zeros(size(xs)), xs); rtol=1e-6)
+        out = zeros(Float64, size(xs))
+        NNlib.∇logsoftmax!(out, xs)
+        @test isapprox(out, NNlib.∇logsoftmax(zeros(size(xs)), xs); rtol=1e-6)
+
+        out = ones(Float64, size(xs))
+        NNlib.∇softmax!(out, xs)
+        @test isapprox(out, NNlib.∇softmax(ones(size(xs)), xs); rtol=1e-6)
+        out = ones(Float64, size(xs))
+        NNlib.∇logsoftmax!(out, xs)
+        @test isapprox(out, NNlib.∇logsoftmax(ones(size(xs)), xs); rtol=1e-6)
+
+        xs = [
+            -0.238639   0.748142  -0.283194  -0.525461  -1.5348    -0.797842;
+             0.690384   0.211427   0.254794  -0.213572  -0.314174  -0.372663;
+            -1.146370  -0.577988   0.718952   0.919720  -0.620773   0.929977
+        ]
+
+        out = zeros(Float64, size(xs))
+        NNlib.softmax!(out, xs)
+        @test isapprox(out, softmax(xs); rtol=1e-6)
+        NNlib.logsoftmax!(out, xs)
+        @test isapprox(out, logsoftmax(xs); rtol=1e-6)
+
+        out = ones(Float64, size(xs))
+        NNlib.softmax!(out, xs)
+        @test isapprox(out, softmax(xs); rtol=1e-6)
+        NNlib.logsoftmax!(out, xs)
+        @test isapprox(out, logsoftmax(xs); rtol=1e-6)
+
+        out = zeros(Float64, size(xs))
+        NNlib.∇softmax!(out, xs)
+        @test isapprox(out, NNlib.∇softmax(zeros(size(xs)), xs); rtol=1e-6)
+        out = zeros(Float64, size(xs))
+        NNlib.∇logsoftmax!(out, xs)
+        @test isapprox(out, NNlib.∇logsoftmax(zeros(size(xs)), xs); rtol=1e-6)
+
+        out = ones(Float64, size(xs))
+        NNlib.∇softmax!(out, xs)
+        @test isapprox(out, NNlib.∇softmax(ones(size(xs)), xs); rtol=1e-6)
+        out = ones(Float64, size(xs))
+        NNlib.∇logsoftmax!(out, xs)
+        @test isapprox(out, NNlib.∇logsoftmax(ones(size(xs)), xs); rtol=1e-6)
+    end
+
 end
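For orientation, here is a minimal sketch of how the in-place API exercised by the new testset is called. It assumes the convention the tests themselves suggest: softmax!(out, xs) and logsoftmax!(out, xs) overwrite out regardless of its initial contents, while in the two-argument gradient form ∇softmax!(Δ, xs) the buffer doubles as the incoming gradient Δ and is overwritten with the pullback. The comment formulas show the usual column-wise softmax pullback, not necessarily NNlib's exact implementation.

using NNlib

xs = Float64[1 2 3; 5 6 7]     # one sample per column

# Forward passes write into a preallocated buffer, so repeated calls
# reuse memory instead of allocating fresh arrays.
out = similar(xs)
NNlib.softmax!(out, xs)        # out[i, j] = exp(xs[i, j]) / sum(exp, xs[:, j])
NNlib.logsoftmax!(out, xs)     # numerically stable log.(softmax(xs))

# Gradient: the buffer holds the incoming gradient Δ on entry and the
# pullback on exit, the in-place analogue of ∇softmax(Δ, xs).
Δ = ones(Float64, size(xs))
NNlib.∇softmax!(Δ, xs)         # Δ .= sf .* (Δ .- sum(Δ .* sf; dims=1)) with sf = softmax(xs)

Seeding out with both zeros and ones separates the two behaviours: the forward results must be independent of the buffer's initial contents, whereas the gradient results must depend on them, which is why the tests compare against ∇softmax(zeros(size(xs)), xs) and ∇softmax(ones(size(xs)), xs) respectively.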
