@@ -121,14 +121,31 @@ adapt_storage(to::FluxCPUAdaptor, x::AbstractSparseArray) = x
121
121
adapt_storage (to:: FluxCPUAdaptor , x:: CUDA.RNG ) = Random. default_rng ()
122
122
adapt_storage (to:: FluxCPUAdaptor , x:: AbstractRNG ) = x
123
123
124
+ # PIRACY, should be defined in CUDA.jl
124
125
function ChainRulesCore. rrule (:: Type{Array} , x:: CUDA.CuArray )
125
- Array (x), dx -> (NoTangent (), CUDA. cu (unthunk (dx)), )
126
+ Array (x), dx -> (NoTangent (), CUDA. cu (unthunk (dx)))
126
127
end
127
128
128
129
function ChainRulesCore. rrule (:: typeof (Adapt. adapt_storage), to:: FluxCPUAdaptor , x:: CUDA.AbstractGPUArray )
129
- adapt_storage (to, x), dx -> (NoTangent (), NoTangent (), adapt_storage (FluxCUDAAdaptor (), unthunk (dx)), )
130
+ adapt_storage (to, x), dx -> (NoTangent (), NoTangent (), adapt_storage (FluxCUDAAdaptor (), unthunk (dx)))
130
131
end
131
132
133
+ # The following rrules for adapt are here to avoid double wrapping issues
134
+ # as seen in https://github.com/FluxML/Flux.jl/pull/2117#discussion_r1027321801
135
+
136
+ ChainRulesCore. rrule (:: typeof (adapt), a:: FluxCPUAdaptor , x:: AnyCuArray ) =
137
+ adapt (a, x), Δ -> (NoTangent (), NoTangent (), adapt (FluxCUDAAdaptor (), unthunk (Δ)))
138
+
139
+ ChainRulesCore. rrule (:: typeof (adapt), a:: FluxCPUAdaptor , x:: AbstractArray ) =
140
+ adapt (a, x), Δ -> (NoTangent (), NoTangent (), Δ)
141
+
142
+ ChainRulesCore. rrule (:: typeof (adapt), a:: FluxCUDAAdaptor , x:: AnyCuArray ) =
143
+ adapt (a, x), Δ -> (NoTangent (), NoTangent (), Δ)
144
+
145
+ ChainRulesCore. rrule (:: typeof (adapt), a:: FluxCUDAAdaptor , x:: AbstractArray ) =
146
+ adapt (a, x), Δ -> (NoTangent (), NoTangent (), adapt (FluxCPUAdaptor (), unthunk (Δ)))
147
+
148
+
132
149
# CPU/GPU movement conveniences
133
150
134
151
"""
@@ -154,7 +171,7 @@ julia> typeof(m_cpu.W)
154
171
Matrix{Float32}
155
172
```
156
173
"""
157
- cpu (x) = fmap (x -> adapt (FluxCPUAdaptor (), x), x)
174
+ cpu (x) = fmap (x -> adapt (FluxCPUAdaptor (), x), x, exclude = _isleaf )
158
175
159
176
_isbitsarray (:: AbstractArray{<:Number} ) = true
160
177
_isbitsarray (:: AbstractArray{T} ) where T = isbitstype (T)
0 commit comments