Skip to content

Commit 57a5df9

Browse files
committed
Add support for rng_from_array and restrict to dropout(::CUDA.RNG, ::CuArray)
1 parent 817f2dc commit 57a5df9

File tree

2 files changed

+32
-8
lines changed

2 files changed

+32
-8
lines changed

src/layers/normalise.jl

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ _dropout_shape(s, dims) = tuple((i ∉ dims ? 1 : si for (i, si) ∈ enumerate(s
1010
_dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)
1111

1212
"""
13-
dropout([rng = default_rng()], x, p; dims=:, active=true)
13+
dropout([rng = rng_from_array(x)], x, p; dims=:, active=true)
1414
1515
The dropout function. If `active` is `true`,
1616
for each input, either sets that input to `0` (with probability
@@ -36,16 +36,20 @@ function dropout(rng, x, p; dims=:, active::Bool=true)
3636
y = dropout_mask(rng, x, p, dims=dims)
3737
return x .* y
3838
end
39-
dropout(x, p; kwargs...) = dropout(Random.default_rng(), x, p; kwargs...)
40-
dropout(x::CuArray, p; kwargs...) = dropout(CUDA.default_rng(), x, p; kwargs...)
39+
dropout(x, p; kwargs...) = dropout(rng_from_array(x), x, p; kwargs...)
40+
dropout(x::CuArray, p; kwargs...) = dropout(rng_from_array(x), x, p; kwargs...)
4141

4242
@adjoint function dropout(rng, x, p; dims=:, active::Bool=true)
4343
active || return x, Δ -> (Δ, nothing)
4444
y = dropout_mask(rng, x, p, dims=dims)
4545
return x .* y, Δ -> (nothing, Δ .* y, nothing)
4646
end
4747

48-
function dropout_mask(rng::AbstractRNG, x, p; dims=:)
48+
dropout_mask(rng::CUDA.RNG, x::CuArray, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)
49+
dropout_mask(rng, x::CuArray, p; kwargs...) =
50+
ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropout_mask only support CUDA.RNG for CuArrays.")
51+
dropout_mask(rng, x, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)
52+
function _dropout_mask(rng, x, p; dims=:)
4953
y = rand!(rng, similar(x, _dropout_shape(x, dims)))
5054
y .= _dropout_kernel.(y, p, 1 - p)
5155
return y
@@ -71,9 +75,9 @@ mutable struct Dropout{F,D,R<:AbstractRNG}
7175
active::Union{Bool, Nothing}
7276
rng::R
7377
end
74-
Dropout(p, dims, active) = Dropout(p, dims, active, Random.default_rng())
78+
Dropout(p, dims, active) = Dropout(p, dims, active, rng_from_array())
7579

76-
function Dropout(p; dims=:, rng = Random.default_rng())
80+
function Dropout(p; dims=:, rng = rng_from_array())
7781
@assert 0 p 1
7882
Dropout(p, dims, nothing, rng)
7983
end
@@ -115,8 +119,8 @@ mutable struct AlphaDropout{F,R<:AbstractRNG}
115119
new{typeof(p), typeof(rng)}(p, active, rng)
116120
end
117121
end
118-
AlphaDropout(p, active) = AlphaDropout(p, active, Random.default_rng())
119-
AlphaDropout(p; rng = Random.default_rng()) = AlphaDropout(p, nothing, rng)
122+
AlphaDropout(p, active) = AlphaDropout(p, active, rng_from_array())
123+
AlphaDropout(p; rng = rng_from_array()) = AlphaDropout(p, nothing, rng)
120124

121125
@functor AlphaDropout
122126

src/utils.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,26 @@ nfan(dims...) = prod(dims[1:end-2]) .* (dims[end-1], dims[end]) # In case of con
3333
ofeltype(x, y) = convert(float(eltype(x)), y)
3434
epseltype(x) = eps(float(eltype(x)))
3535

36+
"""
37+
rng_from_array()
38+
rng_from_array(x)
39+
40+
Create an instance of the RNG most appropriate for `x`.
41+
The current defaults are:
42+
- `x isa AbstractArray`
43+
- Julia version is < 1.7: `Random.GLOBAL_RNG`
44+
- Julia version is >= 1.7: `Random.default_rng()`
45+
- `x isa CuArray`: `CUDA.default_rng()`
46+
When `x` is unspecified, it is assumed to be a `AbstractArray`.
47+
"""
48+
if VERSION >= v"1.7"
49+
rng_from_array() = Random.default_rng()
50+
else
51+
rng_from_array() = Random.GLOBAL_RNG
52+
end
53+
rng_from_array(::AbstractArray) = rng_from_array()
54+
rng_from_array(::CuArray) = CUDA.default_rng()
55+
3656
"""
3757
glorot_uniform([rng=GLOBAL_RNG], dims...)
3858

0 commit comments

Comments
 (0)