Commit d1ef6e8

simplify default_rng etc
1 parent 4ab93b3 · commit d1ef6e8

File tree

4 files changed: +13 -43 lines

src/Flux.jl

Lines changed: 1 addition & 0 deletions
```diff
@@ -2,6 +2,7 @@ module Flux
 
 using Base: tail
 using LinearAlgebra, Statistics, Random # standard lib
+using Random: default_rng
 using MacroTools, Reexport, ProgressLogging, SpecialFunctions
 using MacroTools: @forward
 
```
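
A note on the new import: `Random.default_rng()` returns Julia's task-local default RNG (a `TaskLocalRNG` on Julia ≥ 1.7), and importing it unqualified lets the rest of the package write `default_rng()` directly. A minimal sketch of what the import provides:

```julia
using Random: default_rng

rng = default_rng()   # Julia's task-local default RNG
rand(rng, 3)          # draws from the same stream as a plain rand(3)
```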

src/deprecations.jl

Lines changed: 2 additions & 13 deletions
```diff
@@ -84,8 +84,6 @@ Base.@deprecate_binding ADADelta AdaDelta
 # Remove sub-module Data, while making sure Flux.Data.DataLoader keeps working
 Base.@deprecate_binding Data Flux false "Sub-module Flux.Data has been removed. The only thing it contained may be accessed as Flux.DataLoader"
 
-@deprecate rng_from_array() default_rng_value()
-
 function istraining()
   Base.depwarn("Flux.istraining() is deprecated, use NNlib.within_gradient(x) instead", :istraining)
   false
@@ -185,17 +183,8 @@ function update!(opt::Optimise.AbstractOptimiser, ::Params, grads::Union{Tuple,
   """)
 end
 
-
-function dropout(rng, x, p; dims=:, active::Bool=true)
-  if active
-    NNlib.dropout(rng, x, p; dims)
-  else
-    Base.depwarn("Flux.dropout(...; active=false) is deprecated. Please branch outside the function, or call dropout(x, 0) if you must.", :dropout)
-    return x
-  end
-end
-dropout(x, p; kwargs...) = dropout(NNlib._rng_from_array(x), x, p; kwargs...)
-
+@deprecate rng_from_array() default_rng_value()
+@deprecate default_rng_value() Random.default_rng()
 
 # v0.14 deprecations
```
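
The two `@deprecate` lines chain the old entry points to the new one: `rng_from_array()` forwards to `default_rng_value()`, which in turn forwards to `Random.default_rng()`, each step warning when depwarns are enabled. A rough sketch of what the macro generates (approximate, not the exact expansion):

```julia
using Random

# `@deprecate old() new()` roughly defines the old name as a method
# that emits a deprecation warning and forwards to its replacement.
function rng_from_array()
    Base.depwarn("`rng_from_array()` is deprecated, use `default_rng_value()` instead.", :rng_from_array)
    default_rng_value()
end

function default_rng_value()
    Base.depwarn("`default_rng_value()` is deprecated, use `Random.default_rng()` instead.", :default_rng_value)
    Random.default_rng()
end

rng_from_array() === Random.default_rng()  # true (after two depwarns, if enabled)
```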

src/layers/normalise.jl

Lines changed: 6 additions & 6 deletions
```diff
@@ -2,7 +2,7 @@
 _isactive(m, x) = isnothing(m.active) ? NNlib.within_gradient(x) : m.active
 
 """
-    Dropout(p; dims=:, rng = default_rng_value())
+    Dropout(p; dims=:, rng = default_rng())
 
 Layer implementing [dropout](https://arxiv.org/abs/1207.0580) with the given probability.
 This is used as a regularisation, i.e. to reduce overfitting.
@@ -61,9 +61,9 @@ mutable struct Dropout{F<:Real,D,R<:AbstractRNG}
   active::Union{Bool, Nothing}
   rng::R
 end
-Dropout(p::Real, dims, active) = Dropout(p, dims, active, default_rng_value())
+Dropout(p::Real, dims, active) = Dropout(p, dims, active, default_rng())
 
-function Dropout(p::Real; dims=:, rng = default_rng_value())
+function Dropout(p::Real; dims=:, rng = default_rng())
   0 ≤ p ≤ 1 || throw(ArgumentError("Dropout expexts 0 ≤ p ≤ 1, got p = $p"))
   if p isa Integer # Dropout(0)
     return p==0 ? identity : zero
@@ -92,7 +92,7 @@ function Base.show(io::IO, d::Dropout)
 end
 
 """
-    AlphaDropout(p; rng = default_rng_value())
+    AlphaDropout(p; rng = default_rng())
 
 A dropout layer. Used in
 [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515).
@@ -126,8 +126,8 @@ mutable struct AlphaDropout{F,R<:AbstractRNG}
     new{typeof(p), typeof(rng)}(p, active, rng)
   end
 end
-AlphaDropout(p, active) = AlphaDropout(p, active, default_rng_value())
-AlphaDropout(p; rng = default_rng_value()) = AlphaDropout(p, nothing, rng)
+AlphaDropout(p, active) = AlphaDropout(p, active, default_rng())
+AlphaDropout(p; rng = default_rng()) = AlphaDropout(p, nothing, rng)
 
 @functor AlphaDropout
 trainable(a::AlphaDropout) = (;)
```
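
For layer users the change is only the default value of `rng`: `Dropout(0.5)` now captures the task-local RNG, and passing an explicit RNG for reproducibility still works. A small usage sketch, assuming Flux at this commit on the CPU:

```julia
using Flux, Random

d = Dropout(0.5)
d.rng === Random.default_rng()                 # true: the new default

d42 = Dropout(0.5; rng = MersenneTwister(42))  # explicit RNG, as before
Flux.trainmode!(d42)
y = d42(ones(Float32, 4))                      # mask drawn from the seeded RNG
```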

src/utils.jl

Lines changed: 4 additions & 24 deletions
```diff
@@ -36,32 +36,12 @@ epseltype(x) = eps(float(eltype(x)))
 """
     rng_from_array([x])
 
-Create an instance of the RNG most appropriate for `x`.
-The current defaults are:
-  - `x isa CuArray`: `CUDA.default_rng()`, else:
-  - `x isa AbstractArray`, or no `x` provided:
-    - Julia version is < 1.7: `Random.GLOBAL_RNG`
-    - Julia version is >= 1.7: `Random.default_rng()`
-"""
-rng_from_array(::AbstractArray) = default_rng_value()
-rng_from_array(::CuArray) = CUDA.default_rng()
-
-@non_differentiable rng_from_array(::Any)
-
-if VERSION >= v"1.7"
-  default_rng_value() = Random.default_rng()
-else
-  default_rng_value() = Random.GLOBAL_RNG
-end
-
+Create an instance of the RNG most appropriate for array `x`.
+If `x isa CuArray` then this is `CUDA.default_rng()`,
+otherwise `Random.default_rng()`.
 """
-    default_rng_value()
+rng_from_array(x::AbstractArray = 1:0) = NNlib._rng_from_array(x)
 
-Create an instance of the default RNG depending on Julia's version.
-- Julia version is < 1.7: `Random.GLOBAL_RNG`
-- Julia version is >= 1.7: `Random.default_rng()`
-"""
-default_rng_value
 
 """
     glorot_uniform([rng = default_rng_value()], size...; gain = 1) -> Array
```
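
After this change `rng_from_array` is a thin wrapper over `NNlib._rng_from_array`, with the empty range `1:0` standing in for "no `x` provided". A sketch of the CPU behaviour the new docstring describes (the `CuArray` branch needs CUDA loaded):

```julia
using Flux, Random

Flux.rng_from_array(rand(Float32, 8)) === Random.default_rng()  # CPU array: task-local RNG
Flux.rng_from_array() === Random.default_rng()                  # no argument: same, via x = 1:0
```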
