Skip to content

Commit 726580b

Browse files
Merge pull request #202 from SciML/gpu
add GPUArrays and GPU differentiation test
2 parents a4688c2 + 61d8be0 commit 726580b

File tree

3 files changed

+20
-5
lines changed

3 files changed

+20
-5
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
99
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1010
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1111
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
12+
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
1213
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1314
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
1415
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
@@ -21,6 +22,7 @@ ArrayInterface = "2.7, 3.0, 4, 5"
2122
ChainRulesCore = "0.10.7, 1"
2223
DocStringExtensions = "0.8"
2324
FillArrays = "0.11, 0.12, 0.13"
25+
GPUArrays = "8"
2426
RecipesBase = "0.7, 0.8, 1.0"
2527
StaticArrays = "0.12, 1.0"
2628
ZygoteRules = "0.2"

src/RecursiveArrayTools.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@ include("zygote.jl")
2424

2525
Base.show(io::IO, x::Union{ArrayPartition,AbstractVectorOfArray}) = invoke(show, Tuple{typeof(io), Any}, io, x)
2626

27+
import GPUArrays
28+
Base.convert(T::Type{<:GPUArrays.AbstractGPUArray}, VA::AbstractVectorOfArray) = T(VA)
29+
ChainRulesCore.rrule(T::Type{<:GPUArrays.AbstractGPUArray}, xs::AbstractVectorOfArray) = T(xs), ȳ -> (NoTangent(),ȳ)
30+
2731
export VectorOfArray, DiffEqArray, AbstractVectorOfArray, AbstractDiffEqArray,
2832
AllObserved, vecarr_to_arr, vecarr_to_vectors, tuples
2933

test/gpu/vectorofarray_gpu.jl

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
1-
using RecursiveArrayTools, CUDA, Test
1+
using RecursiveArrayTools, CUDA, Test, Zygote
22
CUDA.allowscalar(false)
33

44
# Test indexing with colon
55
x = zeros(5)
6-
y = VectorOfArray([x,x,x])
7-
y[:,:]
6+
y = VectorOfArray([x, x, x])
7+
y[:, :]
88

99
x = CUDA.zeros(5)
10-
y = VectorOfArray([x,x,x])
11-
y[:,:]
10+
y = VectorOfArray([x, x, x])
11+
y[:, :]
1212

1313
# Test indexing with boolean masks and colon
1414
nx, ny, nt = 3, 4, 5
@@ -22,3 +22,12 @@ va = VectorOfArray([slice for slice in eachslice(x, dims=3)])
2222
xc = Array(x)
2323
mc = Array(m)
2424
@test xc[mc, :] Array(va[m, :])
25+
26+
# Check differentiation with GPUs
27+
28+
p = cu([1.0, 2.0])
29+
function f(p)
30+
x = VectorOfArray([p, p])
31+
sum(CuArray(x))
32+
end
33+
Zygote.gradient(f, p)

0 commit comments

Comments
 (0)