
Commit 3c8cceb

Let AlphaDropout join the RNG fun
1 parent 742341c commit 3c8cceb

3 files changed: +47 -34 lines changed

src/layers/normalise.jl

Lines changed: 7 additions & 5 deletions
```diff
@@ -93,7 +93,7 @@ function Base.show(io::IO, d::Dropout)
 end
 
 """
-    AlphaDropout(p)
+    AlphaDropout(p; rng = default_rng())
 
 A dropout layer. Used in
 [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515).
@@ -102,14 +102,16 @@ remain the same as before.
 
 Does nothing to the input once [`testmode!`](@ref) is true.
 """
-mutable struct AlphaDropout{F}
+mutable struct AlphaDropout{F,R<:AbstractRNG}
   p::F
   active::Union{Bool, Nothing}
-  function AlphaDropout(p, active = nothing)
+  rng::R
+  function AlphaDropout(p, active = nothing, rng = Random.default_rng())
     @assert 0 ≤ p ≤ 1
-    new{typeof(p)}(p, active)
+    new{typeof(p)}(p, active, rng)
   end
 end
+AlphaDropout(p; rng = Random.default_rng()) = AlphaDropout(p, nothing, rng)
 
 function (a::AlphaDropout)(x::AbstractArray{T}) where T
   _isactive(a) || return x
@@ -121,7 +123,7 @@ function (a::AlphaDropout)(x::AbstractArray{T}) where T
   A = T(inv(sqrt((1 - p) * (1 + p * α′^2))))
   B = T(-A * α′ * p)
 
-  noise = rand!(similar(x))
+  noise = rand!(a.rng, similar(x))
   return A .* ifelse.(noise .> p, x, α′) .+ B
 end
```
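With this change, `AlphaDropout` stores an RNG and draws its noise from `a.rng` rather than the global default, mirroring what `Dropout` already supports. A minimal usage sketch (the seed is arbitrary; `trainmode!` is used to force the dropout path outside of a training loop):

```julia
using Flux, Random

# Without the keyword, behaviour is unchanged: noise comes from
# Random.default_rng().
m = AlphaDropout(0.5)

# With an explicit RNG, the noise mask is reproducible.
m1 = AlphaDropout(0.5; rng = MersenneTwister(123))
m2 = AlphaDropout(0.5; rng = MersenneTwister(123))
trainmode!(m1); trainmode!(m2)  # activate dropout outside of training

x = randn(Float32, 10)
m1(x) == m2(x)  # true: identical seeds draw identical noise masks
```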

test/cuda/layers.jl

Lines changed: 6 additions & 4 deletions
```diff
@@ -282,8 +282,10 @@ end
 end
 
 @testset "Dropout RNGs" begin
-  m = Dropout(0.1; rng = MersenneTwister(123))
-  @test_throws ErrorException gpu(m)
-  m = Dropout(0.1; rng = CUDA.default_rng())
-  @test gpu(m).rng === CUDA.default_rng()
+  @testset for layer in (Dropout, AlphaDropout)
+    m = layer(0.1; rng = MersenneTwister(123))
+    @test_throws ErrorException gpu(m)
+    m = layer(0.1; rng = CUDA.default_rng())
+    @test gpu(m).rng === CUDA.default_rng()
+  end
 end
```
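The testset above encodes the intended `gpu` semantics for RNG-carrying layers: a CPU-only RNG such as `MersenneTwister` cannot generate noise on the device, so moving the layer raises an error, while a CUDA-native RNG is carried across by reference. A sketch of the same behaviour outside the test harness (assumes a CUDA-capable setup):

```julia
using Flux, CUDA, Random

# A CPU-only RNG is rejected when the layer is moved to the GPU.
m_cpu = AlphaDropout(0.1; rng = MersenneTwister(123))
# gpu(m_cpu)  # throws ErrorException, as the testset checks

# A CUDA-native RNG survives the move unchanged.
m_gpu = AlphaDropout(0.1; rng = CUDA.default_rng())
gpu(m_gpu).rng === CUDA.default_rng()  # true
```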

test/layers/normalisation.jl

Lines changed: 34 additions & 25 deletions
```diff
@@ -67,31 +67,40 @@ evalwgrad(f, x...) = pullback(f, x...)[1]
 end
 
 @testset "AlphaDropout" begin
-  x = [1., 2., 3.]
-  @test x == AlphaDropout(0.1)(x)
-  @test x == evalwgrad(AlphaDropout(0), x)
-  @test zero(x) == evalwgrad(AlphaDropout(1), x)
-
-  x = randn(1000) # large enough to prevent flaky test
-  m = AlphaDropout(0.5)
-
-  y = evalwgrad(m, x)
-  # Should preserve unit mean and variance
-  @test mean(y) ≈ 0 atol=0.1
-  @test var(y) ≈ 1 atol=0.1
-
-  testmode!(m, true) # should override istraining
-  @test evalwgrad(m, x) == x
-
-  testmode!(m, false)
-  y = evalwgrad(m, x)
-  @test mean(y) ≈ 0 atol=0.1
-  @test var(y) ≈ 1 atol=0.1
-
-  # Known good value ranges
-  # Values taken from https://github.com/pytorch/pytorch/blob/v1.10.0/test/cpp/api/modules.cpp#L1337-L1338
-  x = ones(100)
-  @test 40 < sum(evalwgrad(m, x)) < 130
+  @testset for rng_kwargs in ((), (; rng = MersenneTwister(123)))
+    x = [1., 2., 3.]
+    @test x == AlphaDropout(0.1; rng_kwargs...)(x)
+    @test x == evalwgrad(AlphaDropout(0; rng_kwargs...), x)
+    @test zero(x) == evalwgrad(AlphaDropout(1; rng_kwargs...), x)
+
+    x = randn(1000) # large enough to prevent flaky test
+    m = AlphaDropout(0.5; rng_kwargs...)
+
+    y = evalwgrad(m, x)
+    # Should preserve unit mean and variance
+    @test mean(y) ≈ 0 atol=0.1
+    @test var(y) ≈ 1 atol=0.1
+
+    testmode!(m, true) # should override istraining
+    @test evalwgrad(m, x) == x
+
+    testmode!(m, false)
+    y = evalwgrad(m, x)
+    @test mean(y) ≈ 0 atol=0.1
+    @test var(y) ≈ 1 atol=0.1
+
+    # Known good value ranges
+    # Values taken from https://github.com/pytorch/pytorch/blob/v1.10.0/test/cpp/api/modules.cpp#L1337-L1338
+    x = ones(100)
+    @test 40 < sum(evalwgrad(m, x)) < 130
+
+    # CPU RNGs map onto CPU ok
+    if isempty(rng_kwargs)
+      @test cpu(m).rng === Random.default_rng()
+    else
+      @test cpu(m).rng === only(values(rng_kwargs))
+    end
+  end
 end
 
 @testset "BatchNorm" begin
```
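One idiom worth noting in the updated testset: iterating over `((), (; rng = MersenneTwister(123)))` and splatting `rng_kwargs...` into each constructor runs the whole body twice, once with no keyword arguments (default RNG) and once with an explicit seed, while `only(values(rng_kwargs))` recovers the RNG that was passed in. A standalone illustration of the splatting trick, with hypothetical names (`f`, `kwargs_a`, `kwargs_b` are not from the test file):

```julia
using Random

f(; rng = Random.default_rng()) = rng  # stand-in for a keyword-taking constructor

kwargs_a = ()                            # splats to no keyword arguments
kwargs_b = (; rng = MersenneTwister(1))  # splats to rng = MersenneTwister(1)

f(; kwargs_a...) === Random.default_rng()   # true: the default applies
f(; kwargs_b...) isa MersenneTwister        # true: the seeded RNG is passed through
only(values(kwargs_b)) isa MersenneTwister  # true: recovers the single value
```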
