Skip to content

Commit a51c7c7

Browse files
committed
Add comment on complex, extend Dual number testset
1 parent cd041f4 commit a51c7c7

File tree

1 file changed

+8
-15
lines changed

1 file changed

+8
-15
lines changed

ext/NNlibCUDA/test/activations.jl

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,23 +7,16 @@
77
end
88

99
@testset "forward diff" begin
10-
f = (x) -> logσ.(x)
11-
ds = Dual.(rand(5),1)
12-
@test f(ds) collect(f(CuArray(ds)))
13-
f = (x) -> tanh.(x)
14-
ds = Dual.(rand(5),1)
15-
@test f(ds) collect(f(CuArray(ds)))
16-
f = (x) -> σ.(x)
17-
ds = Dual.(rand(5),1)
18-
@test f(ds) collect(f(CuArray(ds)))
19-
f = (x) -> elu.(x)
20-
ds = Dual.(rand(5),1)
21-
@test f(ds) collect(f(CuArray(ds)))
22-
f = (x) -> relu.(x)
23-
ds = Dual.(rand(5),1)
24-
@test f(ds) collect(f(CuArray(ds)))
10+
for f in NNlib.ACTIVATIONS
11+
if f [:rrelu]
12+
@eval gputest(x -> $f.(x), Dual.(rand(5), 1))
13+
end
14+
end
2515
end
2616

17+
# Broadcasting over complex CuArray works without NNlibCUDA, this test checks that
18+
# NNlibCUDA does not cause such operations to take a fast path which does not support
19+
# complex numbers (e.g. CUDNN)
2720
@testset "complex" begin
2821
f(x) = tanh.(x)
2922
cs = rand(ComplexF64, 5)

0 commit comments

Comments
 (0)