Skip to content

Commit 45180b7

Browse files
authored
update softmax tests to match NNlib 0.8.3 (#44)
* update tests for NNlib 0.8.3 * lower precision
1 parent 78b09f5 commit 45180b7

File tree

2 files changed

+14
-6
lines changed

2 files changed

+14
-6
lines changed

ext/NNlibCUDA/Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "NNlibCUDA"
22
uuid = "a00861dc-f156-4864-bf3c-e6376f28a68d"
3-
version = "0.2.1"
3+
version = "0.2.2"
44

55
[deps]
66
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
@@ -11,7 +11,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1111

1212
[compat]
1313
CUDA = "3.3.1"
14-
NNlib = "0.8"
14+
NNlib = "0.8.3"
1515
julia = "1.6"
1616

1717
[extras]

ext/NNlibCUDA/test/softmax.jl

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,20 @@
11
@testset "softmax" begin
22
for (sz, dims) in [((5,), :), ((5,), 1), ((5,5), :), ((5,5), 1), ((5,5), 2), ((5,5,5,5), (2,3)), ((5,5,5,5), (2,4))]
33
x = randn(Float64, sz)
4-
y = softmax(x, dims=dims)
54
dy = randn(Float64, sz)
5+
6+
y = softmax(x, dims=dims)
67
gputest(softmax, x, dims=dims)
7-
gputest(∇softmax, dy, x, y, dims=dims, checkgrad=false)
8-
y = logsoftmax(x, dims=dims)
8+
gputest(NNlib.∇softmax_data, dy, y; dims=dims)
9+
10+
y2 = logsoftmax(x, dims=dims)
911
gputest(logsoftmax, x, dims=dims)
10-
gputest(∇logsoftmax, dy, x, y, dims=dims, checkgrad=false)
12+
gputest(NNlib.∇logsoftmax_data, dy, y2; dims=dims)
13+
14+
# From NNlib 0.8.3, ∇softmax! is not used in the gradient.
15+
# But NNlibCUDA still knows how to call CUDNN routines, let's test they agree:
16+
@test NNlib.∇softmax_data(dy, y; dims=dims) collect(∇softmax!(similar(cu(x)), cu(dy), cu(x), cu(y); dims=dims)) atol=1e-4
17+
@test NNlib.∇logsoftmax_data(dy, y2; dims=dims) collect(∇logsoftmax!(similar(cu(x)), cu(dy), cu(x), cu(y2); dims=dims)) atol=1e-4
18+
# (Note that ∇softmax! does not depend on x, it's just there to disambiguate from an even older signature.)
1119
end
1220
end

0 commit comments

Comments
 (0)