Skip to content

Commit 97ec9f1

Browse files
committed
TaskLocalRNG and GLOBAL_RNG are not the same
1 parent 3c8cceb commit 97ec9f1

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

src/functor.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ end
9696
struct FluxCUDAAdaptor end
9797
adapt_storage(to::FluxCUDAAdaptor, x) = CUDA.cu(x)
9898
adapt_storage(to::FluxCUDAAdaptor, x::Zygote.FillArrays.AbstractFill) = CUDA.cu(collect(x))
99-
adapt_storage(to::FluxCUDAAdaptor, x::Random._GLOBAL_RNG) = CUDA.default_rng()
99+
adapt_storage(to::FluxCUDAAdaptor, x::Random.TaskLocalRNG) = CUDA.default_rng()
100100
adapt_storage(to::FluxCUDAAdaptor, x::CUDA.RNG) = x
101101
adapt_storage(to::FluxCUDAAdaptor, x::AbstractRNG) =
102102
error("Cannot map RNG of type $(typeof(x)) to GPU. GPU execution only supports Random.default_rng().")

src/layers/normalise.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ function dropout(rng, x, p; dims=:, active::Bool=true)
3737
return x .* y
3838
end
3939
dropout(x, p; kwargs...) = dropout(Random.default_rng(), x, p; kwargs...)
40-
dropout(x::CuArray, p; kwargs...) = dropout(CUDA.CURAND.default_rng(), x, p; kwargs...)
40+
dropout(x::CuArray, p; kwargs...) = dropout(CUDA.default_rng(), x, p; kwargs...)
4141

4242
@adjoint function dropout(rng, x, p; dims=:, active::Bool=true)
4343
active || return x, Δ -> (Δ, nothing)
@@ -106,11 +106,12 @@ mutable struct AlphaDropout{F,R<:AbstractRNG}
106106
p::F
107107
active::Union{Bool, Nothing}
108108
rng::R
109-
function AlphaDropout(p, active = nothing, rng = Random.default_rng())
109+
function AlphaDropout(p, active, rng)
110110
@assert 0 p 1
111-
new{typeof(p)}(p, active, rng)
111+
new{typeof(p), typeof(rng)}(p, active, rng)
112112
end
113113
end
114+
AlphaDropout(p, active) = AlphaDropout(p, active, Random.default_rng())
114115
AlphaDropout(p; rng = Random.default_rng()) = AlphaDropout(p, nothing, rng)
115116

116117
function (a::AlphaDropout)(x::AbstractArray{T}) where T

0 commit comments

Comments
 (0)