Skip to content

Commit eb040ed

Browse files
authored
Merge pull request #48 from DomCRose/activation_eltype_restriction
Restrict element type of activation overrides to CUDNN datatypes
2 parents 45180b7 + a51c7c7 commit eb040ed

File tree

2 files changed

+16
-5
lines changed

2 files changed

+16
-5
lines changed

ext/NNlibCUDA/src/cudnn/activations.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,13 @@ for (f, op) in [
1818
@eval begin
1919
# in-place
2020
function Base.materialize!(dst::DenseCuArray{<:CUDNNFloat},
21-
bc::Broadcast.Broadcasted{<:Any,<:Any,typeof($f),<:Tuple{DenseCuArray}})
21+
bc::Broadcast.Broadcasted{<:Any,<:Any,typeof($f),<:Tuple{DenseCuArray{<:CUDNNFloat}}})
2222
$op(bc.args[1], dst)
2323
return dst
2424
end
2525

2626
# out of place
27-
function Base.materialize(bc::Broadcast.Broadcasted{<:Any,<:Any,typeof($f),<:Tuple{DenseCuArray}})
27+
function Base.materialize(bc::Broadcast.Broadcasted{<:Any,<:Any,typeof($f),<:Tuple{DenseCuArray{<:CUDNNFloat}}})
2828
ElType = Broadcast.combine_eltypes(bc.f, bc.args)
2929
dst = similar(bc, ElType)
3030
$op(bc.args[1], dst)

ext/NNlibCUDA/test/activations.jl

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,20 @@
77
end
88

99
@testset "forward diff" begin
10-
f(x) = logσ.(x)
11-
ds = Dual.(rand(5),1)
12-
@test f(ds) ≈ collect(f(CuArray(ds)))
10+
for f in NNlib.ACTIVATIONS
11+
if f ∉ [:rrelu]
12+
@eval gputest(x -> $f.(x), Dual.(rand(5), 1))
13+
end
14+
end
15+
end
16+
17+
# Broadcasting over complex CuArray works without NNlibCUDA, this test checks that
18+
# NNlibCUDA does not cause such operations to take a fast path which does not support
19+
# complex numbers (e.g. CUDNN)
20+
@testset "complex" begin
21+
f(x) = tanh.(x)
22+
cs = rand(ComplexF64, 5)
23+
@test f(cs) ≈ collect(f(CuArray(cs)))
1324
end
1425

1526
@testset "softplus" begin

0 commit comments

Comments
 (0)