@@ -10,7 +10,7 @@ _dropout_shape(s, dims) = tuple((i ∉ dims ? 1 : si for (i, si) ∈ enumerate(s
10
10
_dropout_kernel (y:: T , p, q) where {T} = y > p ? T (1 / q) : T (0 )
11
11
12
12
"""
13
- dropout([rng = default_rng( )], x, p; dims=:, active=true)
13
+ dropout([rng = rng_from_array(x )], x, p; dims=:, active=true)
14
14
15
15
The dropout function. If `active` is `true`,
16
16
for each input, either sets that input to `0` (with probability
@@ -36,16 +36,20 @@ function dropout(rng, x, p; dims=:, active::Bool=true)
36
36
y = dropout_mask (rng, x, p, dims= dims)
37
37
return x .* y
38
38
end
39
- dropout (x, p; kwargs... ) = dropout (Random . default_rng ( ), x, p; kwargs... )
40
- dropout (x:: CuArray , p; kwargs... ) = dropout (CUDA . default_rng ( ), x, p; kwargs... )
39
+ dropout (x, p; kwargs... ) = dropout (rng_from_array (x ), x, p; kwargs... )
40
+ dropout (x:: CuArray , p; kwargs... ) = dropout (rng_from_array (x ), x, p; kwargs... )
41
41
42
42
@adjoint function dropout (rng, x, p; dims= :, active:: Bool = true )
43
43
active || return x, Δ -> (Δ, nothing )
44
44
y = dropout_mask (rng, x, p, dims= dims)
45
45
return x .* y, Δ -> (nothing , Δ .* y, nothing )
46
46
end
47
47
48
- function dropout_mask (rng:: AbstractRNG , x, p; dims= :)
48
+ dropout_mask (rng:: CUDA.RNG , x:: CuArray , p; kwargs... ) = _dropout_mask (rng, x, p; kwargs... )
49
+ dropout_mask (rng, x:: CuArray , p; kwargs... ) =
50
+ ArgumentError (" x isa CuArray, but rng isa $(typeof (rng)) . dropout_mask only support CUDA.RNG for CuArrays." )
51
+ dropout_mask (rng, x, p; kwargs... ) = _dropout_mask (rng, x, p; kwargs... )
52
+ function _dropout_mask (rng, x, p; dims= :)
49
53
y = rand! (rng, similar (x, _dropout_shape (x, dims)))
50
54
y .= _dropout_kernel .(y, p, 1 - p)
51
55
return y
@@ -71,9 +75,9 @@ mutable struct Dropout{F,D,R<:AbstractRNG}
71
75
active:: Union{Bool, Nothing}
72
76
rng:: R
73
77
end
74
- Dropout (p, dims, active) = Dropout (p, dims, active, Random . default_rng ())
78
+ Dropout (p, dims, active) = Dropout (p, dims, active, rng_from_array ())
75
79
76
- function Dropout (p; dims= :, rng = Random . default_rng ())
80
+ function Dropout (p; dims= :, rng = rng_from_array ())
77
81
@assert 0 ≤ p ≤ 1
78
82
Dropout (p, dims, nothing , rng)
79
83
end
@@ -115,8 +119,8 @@ mutable struct AlphaDropout{F,R<:AbstractRNG}
115
119
new {typeof(p), typeof(rng)} (p, active, rng)
116
120
end
117
121
end
118
- AlphaDropout (p, active) = AlphaDropout (p, active, Random . default_rng ())
119
- AlphaDropout (p; rng = Random . default_rng ()) = AlphaDropout (p, nothing , rng)
122
+ AlphaDropout (p, active) = AlphaDropout (p, active, rng_from_array ())
123
+ AlphaDropout (p; rng = rng_from_array ()) = AlphaDropout (p, nothing , rng)
120
124
121
125
@functor AlphaDropout
122
126
0 commit comments