Skip to content

Commit 612c247

Browse files
Merge pull request #376 from SciML/ap/abstract_array
Device agnostic convert
2 parents 3957332 + 3eb80eb commit 612c247

File tree

3 files changed

+10
-1
lines changed

3 files changed

+10
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "RecursiveArrayTools"
22
uuid = "731186ca-8d62-57ce-b412-fbd966d074cd"
33
authors = ["Chris Rackauckas <[email protected]>"]
4-
version = "3.16.0"
4+
version = "3.16.1"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/vector_of_array.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,9 @@ function Base.Array{U}(VA::AbstractVectorOfArray) where {U}
135135
vecs = vec.(VA.u)
136136
Array(reshape(reduce(hcat, vecs), size(VA.u[1])..., length(VA.u)))
137137
end
138+
139+
Base.convert(::Type{AbstractArray}, VA::AbstractVectorOfArray) = stack(VA.u)
140+
138141
function Adapt.adapt_structure(to, VA::AbstractVectorOfArray)
139142
Adapt.adapt(to, Array(VA))
140143
end

test/gpu/vectorofarray_gpu.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,9 @@ function f(p)
3131
sum(CuArray(x))
3232
end
3333
Zygote.gradient(f, p)
34+
35+
# Check conversion preserves device
36+
va_cu = convert(AbstractArray, va)
37+
38+
@test va_cu isa CuArray
39+
@test size(va_cu) == size(x)

0 commit comments

Comments
 (0)