
Commit 0f0c7b0

Add fix and tests
1 parent 45180b7 commit 0f0c7b0

2 files changed (+21, -3 lines)

ext/NNlibCUDA/src/cudnn/activations.jl (2 additions, 2 deletions)

@@ -18,13 +18,13 @@ for (f, op) in [
     @eval begin
         # in-place
         function Base.materialize!(dst::DenseCuArray{<:CUDNNFloat},
-                                    bc::Broadcast.Broadcasted{<:Any,<:Any,typeof($f),<:Tuple{DenseCuArray}})
+                                    bc::Broadcast.Broadcasted{<:Any,<:Any,typeof($f),<:Tuple{DenseCuArray{<:CUDNNFloat}}})
             $op(bc.args[1], dst)
             return dst
         end

         # out of place
-        function Base.materialize(bc::Broadcast.Broadcasted{<:Any,<:Any,typeof($f),<:Tuple{DenseCuArray}})
+        function Base.materialize(bc::Broadcast.Broadcasted{<:Any,<:Any,typeof($f),<:Tuple{DenseCuArray{<:CUDNNFloat}}})
             ElType = Broadcast.combine_eltypes(bc.f, bc.args)
             dst = similar(bc, ElType)
             $op(bc.args[1], dst)
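The change above narrows the broadcast fast path so that only arrays whose element type is a cuDNN-supported float hit the cuDNN activation kernels; arrays of other element types (e.g. ForwardDiff.Dual or Complex) no longer match these methods and fall back to CUDA.jl's generic broadcasting. A minimal dispatch sketch of the idea, not the package's actual code, using MyCUDNNFloat as a stand-in for NNlibCUDA's CUDNNFloat union (assumed here to be Float16/Float32/Float64):

using ForwardDiff: Dual

# Stand-in for NNlibCUDA's CUDNNFloat union (assumption).
const MyCUDNNFloat = Union{Float16, Float32, Float64}

# Narrow signature: only plain float arrays take the cuDNN-style fast path.
activation_path(x::AbstractArray{<:MyCUDNNFloat}) = "cudnn fast path"
# Everything else (Dual, Complex, ...) falls through to generic broadcasting.
activation_path(x::AbstractArray) = "generic broadcast"

activation_path(rand(Float32, 5))      # "cudnn fast path"
activation_path(Dual.(rand(5), 1))     # "generic broadcast"
activation_path(rand(ComplexF64, 5))   # "generic broadcast"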

ext/NNlibCUDA/test/activations.jl (19 additions, 1 deletion)

@@ -7,9 +7,27 @@
 end

 @testset "forward diff" begin
-    f(x) = logσ.(x)
+    f = (x) -> logσ.(x)
     ds = Dual.(rand(5),1)
     @test f(ds) ≈ collect(f(CuArray(ds)))
+    f = (x) -> tanh.(x)
+    ds = Dual.(rand(5),1)
+    @test f(ds) ≈ collect(f(CuArray(ds)))
+    f = (x) -> σ.(x)
+    ds = Dual.(rand(5),1)
+    @test f(ds) ≈ collect(f(CuArray(ds)))
+    f = (x) -> elu.(x)
+    ds = Dual.(rand(5),1)
+    @test f(ds) ≈ collect(f(CuArray(ds)))
+    f = (x) -> relu.(x)
+    ds = Dual.(rand(5),1)
+    @test f(ds) ≈ collect(f(CuArray(ds)))
+end
+
+@testset "complex" begin
+    f = (x) -> tanh.(x)
+    cs = rand(ComplexF64, 5)
+    @test f(cs) ≈ collect(f(CuArray(cs)))
 end

 @testset "softplus" begin
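The new tests check that broadcasting these activations over CuArrays of Dual and Complex numbers matches the CPU result, since such element types now bypass the cuDNN path. A rough sketch of what they exercise, assuming a CUDA-capable device and that CUDA.jl, NNlibCUDA and ForwardDiff are available:

using CUDA, NNlibCUDA
using ForwardDiff: Dual

# Dual numbers: exercised by the "forward diff" testset.
ds  = Dual.(rand(5), 1)
cpu = tanh.(ds)                  # ordinary CPU broadcast
gpu = tanh.(CuArray(ds))         # no longer routed to the float-only cuDNN kernel
@assert cpu ≈ collect(gpu)

# Complex numbers: exercised by the new "complex" testset.
cs = rand(ComplexF64, 5)
@assert tanh.(cs) ≈ collect(tanh.(CuArray(cs)))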
