We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent b8d155b commit b04514cCopy full SHA for b04514c
src/functor.jl
@@ -96,7 +96,11 @@ end
96
struct FluxCUDAAdaptor end
97
adapt_storage(to::FluxCUDAAdaptor, x) = CUDA.cu(x)
98
adapt_storage(to::FluxCUDAAdaptor, x::Zygote.FillArrays.AbstractFill) = CUDA.cu(collect(x))
99
-adapt_storage(to::FluxCUDAAdaptor, x::Random.TaskLocalRNG) = CUDA.default_rng()
+if VERSION >= v"1.7"
100
+ adapt_storage(to::FluxCUDAAdaptor, x::Random.TaskLocalRNG) = CUDA.default_rng()
101
+else
102
+ adapt_storage(to::FluxCUDAAdaptor, x::Random._GLOBAL_RNG) = CUDA.default_rng()
103
+end
104
adapt_storage(to::FluxCUDAAdaptor, x::CUDA.RNG) = x
105
adapt_storage(to::FluxCUDAAdaptor, x::AbstractRNG) =
106
error("Cannot map RNG of type $(typeof(x)) to GPU. GPU execution only supports Random.default_rng().")
0 commit comments