Skip to content

Commit d57f502

Browse files
Merge pull request #318 from AayushSabharwal/as/gpu
fix: GPU tests, CuArray conversion, autodiff
2 parents 87ef7d5 + c1249b3 commit d57f502

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

ext/RecursiveArrayToolsZygoteExt.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,17 +95,17 @@ end
9595
VectorOfArray(u),
9696
y -> begin
9797
y isa Ref && (y = VectorOfArray(y[].u))
98-
(VectorOfArray([y[ntuple(x -> Colon(), ndims(y.u) - 1)..., i]
99-
for i in 1:size(y.u)[end]]),)
98+
(VectorOfArray([y[ntuple(x -> Colon(), ndims(y) - 1)..., i]
99+
for i in 1:size(y)[end]]),)
100100
end
101101
end
102102

103103
@adjoint function DiffEqArray(u, t)
104104
DiffEqArray(u, t),
105105
y -> begin
106106
y isa Ref && (y = VectorOfArray(y[].u))
107-
(DiffEqArray([y[ntuple(x -> Colon(), ndims(y.u) - 1)..., i]
108-
for i in 1:size(y.u)[end]],
107+
(DiffEqArray([y[ntuple(x -> Colon(), ndims(y) - 1)..., i]
108+
for i in 1:size(y)[end]],
109109
t), nothing)
110110
end
111111
end

src/RecursiveArrayTools.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ end
2828

2929
import GPUArraysCore
3030
Base.convert(T::Type{<:GPUArraysCore.AnyGPUArray}, VA::AbstractVectorOfArray) = stack(VA.u)
31+
(T::Type{<:GPUArraysCore.AnyGPUArray})(VA::AbstractVectorOfArray) = T(Array(VA))
3132

3233
import Requires
3334
@static if !isdefined(Base, :get_extension)

0 commit comments

Comments
 (0)