Skip to content

Commit b04514c

Browse files
committed
TaskLocalRNG is 1.7+
1 parent b8d155b commit b04514c

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

src/functor.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,11 @@ end
9696
struct FluxCUDAAdaptor end
9797
adapt_storage(to::FluxCUDAAdaptor, x) = CUDA.cu(x)
9898
adapt_storage(to::FluxCUDAAdaptor, x::Zygote.FillArrays.AbstractFill) = CUDA.cu(collect(x))
99-
adapt_storage(to::FluxCUDAAdaptor, x::Random.TaskLocalRNG) = CUDA.default_rng()
99+
if VERSION >= v"1.7"
100+
adapt_storage(to::FluxCUDAAdaptor, x::Random.TaskLocalRNG) = CUDA.default_rng()
101+
else
102+
adapt_storage(to::FluxCUDAAdaptor, x::Random._GLOBAL_RNG) = CUDA.default_rng()
103+
end
100104
adapt_storage(to::FluxCUDAAdaptor, x::CUDA.RNG) = x
101105
adapt_storage(to::FluxCUDAAdaptor, x::AbstractRNG) =
102106
error("Cannot map RNG of type $(typeof(x)) to GPU. GPU execution only supports Random.default_rng().")

0 commit comments

Comments
 (0)