
Commit 2855c5b

Added more thorough tests of the gradients of softmax and logsoftmax.
1 parent e7c1611 commit 2855c5b

File tree: 1 file changed (+9 -1 lines)


test/runtests.jl

Lines changed: 9 additions & 1 deletion
@@ -1,4 +1,4 @@
-using NNlib, Test
+using NNlib, Test, Flux
 
 @testset "NNlib" begin
 
@@ -29,6 +29,14 @@ xs = Float32[1, 2, 3000.]
 
 xs = Float32[1 2 3; 1000 2000 3000]
 @test logsoftmax(xs) ≈ [-999 -1998 -2997; 0 0 0.]
+
 @test NNlib.∇logsoftmax(ones(size(xs)), xs) ≈ zeros(Float32, size(xs))
 @test NNlib.∇softmax(ones(size(xs)), xs) ≈ zeros(Float32, size(xs))
+gradtest(f, xs::AbstractArray...) = Flux.Tracker.gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
+@test gradtest(logsoftmax, xs)
+@test gradtest(softmax, xs)
+
+xs = randn(5, 10)
+@test gradtest(logsoftmax, xs)
+@test gradtest(softmax, xs)
 end
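
The new gradtest helper delegates to Flux.Tracker.gradcheck, which checks the tracked (reverse-mode) gradient of the scalar loss sum(sin.(f(xs...))) against a finite-difference estimate; wrapping f in sum(sin.(...)) turns it into a scalar whose gradient is non-trivial at every element. The sketch below spells out the same idea by hand, comparing NNlib's analytic vector-Jacobian products against central differences. It assumes only the NNlib.∇softmax(Δ, xs) / NNlib.∇logsoftmax(Δ, xs) signatures already used in the test; numgrad, softmax_grad and logsoftmax_grad are hypothetical helper names, and this is not Flux's actual gradcheck implementation.

# Hand-rolled sketch of a gradcheck-style test (an illustration of the idea,
# not Flux's implementation): compare NNlib's analytic vector-Jacobian
# products against central finite differences of the loss sum(sin.(f(x))).
using NNlib, Test

# Central-difference gradient of x -> sum(sin.(f(x))).
function numgrad(f, x; h = cbrt(eps(eltype(x))))
    loss(y) = sum(sin.(f(y)))
    g = similar(x)
    for i in eachindex(x)
        tmp = x[i]
        x[i] = tmp + h; fp = loss(x)
        x[i] = tmp - h; fm = loss(x)
        x[i] = tmp                      # restore the perturbed entry
        g[i] = (fp - fm) / (2h)
    end
    return g
end

# Analytic gradient of the same loss via the chain rule:
# d/dx sum(sin.(f(x))) = ∇f(cos.(f(x)), x), where ∇f(Δ, x) is NNlib's VJP.
softmax_grad(x)    = NNlib.∇softmax(cos.(softmax(x)), x)
logsoftmax_grad(x) = NNlib.∇logsoftmax(cos.(logsoftmax(x)), x)

xs = randn(5, 10)
@test isapprox(softmax_grad(xs),    numgrad(softmax, xs);    rtol = 1e-5, atol = 1e-5)
@test isapprox(logsoftmax_grad(xs), numgrad(logsoftmax, xs); rtol = 1e-5, atol = 1e-5)

A check of this kind catches gradient rules that return the right shape but the wrong values, which simple fixed-expectation tests such as the zeros comparisons above can miss.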

0 commit comments