
Commit f42f475

Revert "simplify default_rng etc"
This reverts commit 0e396a6.
1 parent: 0e396a6 · commit: f42f475

File tree

5 files changed: +45 −15 lines


src/Flux.jl

Lines changed: 0 additions & 1 deletion
@@ -2,7 +2,6 @@ module Flux

 using Base: tail
 using LinearAlgebra, Statistics, Random # standard lib
-using Random: default_rng
 using MacroTools, Reexport, ProgressLogging, SpecialFunctions
 using MacroTools: @forward
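The only change here is dropping the `using Random: default_rng` import again. For orientation, a minimal sketch, not part of the commit, of the two spellings that remain available (assuming this commit on a recent Julia):

using Flux, Random

rng = Random.default_rng()        # always available, fully qualified
rng = Flux.default_rng_value()    # Flux's version-aware helper, restored in src/utils.jl below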

src/deprecations.jl

Lines changed: 13 additions & 2 deletions
@@ -84,6 +84,8 @@ Base.@deprecate_binding ADADelta AdaDelta
 # Remove sub-module Data, while making sure Flux.Data.DataLoader keeps working
 Base.@deprecate_binding Data Flux false "Sub-module Flux.Data has been removed. The only thing it contained may be accessed as Flux.DataLoader"

+@deprecate rng_from_array() default_rng_value()
+
 function istraining()
   Base.depwarn("Flux.istraining() is deprecated, use NNlib.within_gradient(x) instead", :istraining)
   false

@@ -183,8 +185,17 @@ function update!(opt::Optimise.AbstractOptimiser, ::Params, grads::Union{Tuple,
   """)
 end

-@deprecate rng_from_array() default_rng_value()
-@deprecate default_rng_value() Random.default_rng()
+
+function dropout(rng, x, p; dims=:, active::Bool=true)
+  if active
+    NNlib.dropout(rng, x, p; dims)
+  else
+    Base.depwarn("Flux.dropout(...; active=false) is deprecated. Please branch outside the function, or call dropout(x, 0) if you must.", :dropout)
+    return x
+  end
+end
+dropout(x, p; kwargs...) = dropout(NNlib._rng_from_array(x), x, p; kwargs...)
+

 # v0.14 deprecations
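These shims restore the old entry points: `rng_from_array()` with no argument forwards (with a warning) to `default_rng_value()`, and `dropout` keeps accepting the `active` keyword by delegating to `NNlib.dropout`. A minimal usage sketch, assuming Flux with this commit and an NNlib version that provides `NNlib.dropout`:

using Flux, Random

x = rand(Float32, 100)

Flux.rng_from_array()                            # deprecation warning, then returns default_rng_value()
y = Flux.dropout(Random.default_rng(), x, 0.5)   # active defaults to true, forwards to NNlib.dropout
# Flux.dropout(x, 0.5; active=false) would warn and return x unchanged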

src/layers/normalise.jl

Lines changed: 6 additions & 6 deletions
@@ -2,7 +2,7 @@
 _isactive(m, x) = isnothing(m.active) ? NNlib.within_gradient(x) : m.active

 """
-    Dropout(p; dims=:, rng = default_rng())
+    Dropout(p; dims=:, rng = default_rng_value())

 Layer implementing [dropout](https://arxiv.org/abs/1207.0580) with the given probability.
 This is used as a regularisation, i.e. to reduce overfitting.

@@ -61,9 +61,9 @@ mutable struct Dropout{F<:Real,D,R<:AbstractRNG}
   active::Union{Bool, Nothing}
   rng::R
 end
-Dropout(p::Real, dims, active) = Dropout(p, dims, active, default_rng())
+Dropout(p::Real, dims, active) = Dropout(p, dims, active, default_rng_value())

-function Dropout(p::Real; dims=:, rng = default_rng())
+function Dropout(p::Real; dims=:, rng = default_rng_value())
   0 ≤ p ≤ 1 || throw(ArgumentError("Dropout expects 0 ≤ p ≤ 1, got p = $p"))
   if p isa Integer  # Dropout(0)
     return p==0 ? identity : zero

@@ -92,7 +92,7 @@ function Base.show(io::IO, d::Dropout)
 end

 """
-    AlphaDropout(p; rng = default_rng())
+    AlphaDropout(p; rng = default_rng_value())

 A dropout layer. Used in
 [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515).

@@ -126,8 +126,8 @@ mutable struct AlphaDropout{F,R<:AbstractRNG}
     new{typeof(p), typeof(rng)}(p, active, rng)
   end
 end
-AlphaDropout(p, active) = AlphaDropout(p, active, default_rng())
-AlphaDropout(p; rng = default_rng()) = AlphaDropout(p, nothing, rng)
+AlphaDropout(p, active) = AlphaDropout(p, active, default_rng_value())
+AlphaDropout(p; rng = default_rng_value()) = AlphaDropout(p, nothing, rng)

 @functor AlphaDropout
 trainable(a::AlphaDropout) = (;)
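With the revert, the `Dropout` and `AlphaDropout` constructors default their `rng` field to `default_rng_value()` again while still accepting an explicit RNG. A small, hypothetical usage sketch:

using Flux, Random

d  = Dropout(0.4)                            # rng defaults to default_rng_value()
d2 = Dropout(0.4; rng = MersenneTwister(1))  # an explicit RNG is still accepted

Flux.trainmode!(d)                           # force active = true outside of training
y = d(rand(Float32, 10, 4))                  # roughly 40% of entries zeroed, the rest rescaled by 1/(1 - p)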

src/utils.jl

Lines changed: 24 additions & 4 deletions
@@ -36,12 +36,32 @@ epseltype(x) = eps(float(eltype(x)))
 """
     rng_from_array([x])

-Create an instance of the RNG most appropriate for array `x`.
-If `x isa CuArray` then this is `CUDA.default_rng()`,
-otherwise `Random.default_rng()`.
+Create an instance of the RNG most appropriate for `x`.
+The current defaults are:
+- `x isa CuArray`: `CUDA.default_rng()`, else:
+- `x isa AbstractArray`, or no `x` provided:
+  - Julia version is < 1.7: `Random.GLOBAL_RNG`
+  - Julia version is >= 1.7: `Random.default_rng()`
+"""
+rng_from_array(::AbstractArray) = default_rng_value()
+rng_from_array(::CuArray) = CUDA.default_rng()
+
+@non_differentiable rng_from_array(::Any)
+
+if VERSION >= v"1.7"
+  default_rng_value() = Random.default_rng()
+else
+  default_rng_value() = Random.GLOBAL_RNG
+end
+
 """
-rng_from_array(x::AbstractArray) = NNlib._rng_from_array(x)
+    default_rng_value()

+Create an instance of the default RNG depending on Julia's version.
+- Julia version is < 1.7: `Random.GLOBAL_RNG`
+- Julia version is >= 1.7: `Random.default_rng()`
+"""
+default_rng_value

 """
     glorot_uniform([rng = default_rng_value()], size...; gain = 1) -> Array
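The restored definitions pick the RNG by array type, and `default_rng_value()` by Julia version. A quick sketch of the CPU behaviour, assuming this commit (the `CuArray` method only applies once CUDA.jl is loaded; `cu` below is hypothetical in this context):

using Flux, Random

Flux.rng_from_array(rand(Float32, 3))   # CPU array: returns default_rng_value()
Flux.default_rng_value()                # Random.default_rng() on Julia >= 1.7, Random.GLOBAL_RNG before
# Flux.rng_from_array(cu(rand(3)))      # with CUDA.jl loaded this would return CUDA.default_rng()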

test/layers/normalisation.jl

Lines changed: 2 additions & 2 deletions
@@ -56,10 +56,10 @@ evalwgrad(f, x...) = pullback(f, x...)[1]
     y = m(x)
     @test count(a->a == 0, y) > 50

-    y = Flux.dropout(values(rng_kwargs)..., x, 0.9) # , active=true)
+    y = Flux.dropout(values(rng_kwargs)..., x, 0.9, active=true)
     @test count(a->a == 0, y) > 50

-    y = Flux.dropout(values(rng_kwargs)..., x, 0.9 * 0) #, active=false)
+    y = Flux.dropout(values(rng_kwargs)..., x, 0.9, active=false)
     @test count(a->a == 0, y) == 0

     # CPU RNGs map onto CPU ok
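Here `rng_kwargs` is splatted, presumably so the same lines run both with and without an explicit RNG argument, and only the number of zeroed entries is asserted. A standalone sketch of the two updated calls, with hypothetical data outside the test suite:

using Flux, Random, Test

x = rand(Float32, 100)
rng = MersenneTwister(0)

y = Flux.dropout(rng, x, 0.9, active=true)    # roughly 90% of entries become zero
@test count(iszero, y) > 50

y = Flux.dropout(rng, x, 0.9, active=false)   # deprecation path: x is returned unchanged
@test count(iszero, y) == 0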
