Skip to content

Commit dfa6732

Browse files
Merge pull request #430 from JuliaArrays/aos_to_soa_ReverseDiff
Specialize Array of Structs to Struct of Array for ReverseDiff
2 parents 7c5d183 + 80e220d commit dfa6732

File tree

4 files changed

+56
-8
lines changed

4 files changed

+56
-8
lines changed

Project.toml

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ArrayInterface"
22
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
3-
version = "7.7.1"
3+
version = "7.8.0"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -14,6 +14,7 @@ BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
1414
BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0"
1515
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
1616
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
17+
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
1718
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
1819
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
1920

@@ -22,16 +23,17 @@ ArrayInterfaceBandedMatricesExt = "BandedMatrices"
2223
ArrayInterfaceBlockBandedMatricesExt = "BlockBandedMatrices"
2324
ArrayInterfaceCUDAExt = "CUDA"
2425
ArrayInterfaceGPUArraysCoreExt = "GPUArraysCore"
26+
ArrayInterfaceReverseDiffExt = "ReverseDiff"
2527
ArrayInterfaceStaticArraysCoreExt = "StaticArraysCore"
2628
ArrayInterfaceTrackerExt = "Tracker"
2729

2830
[compat]
29-
Adapt = "3, 4"
30-
LinearAlgebra = "1.9"
31+
Adapt = "4"
32+
LinearAlgebra = "1.10"
3133
Requires = "1"
32-
SparseArrays = "1.9"
33-
SuiteSparse = "1.9"
34-
julia = "1.9"
34+
SparseArrays = "1.10"
35+
SuiteSparse = "1.10"
36+
julia = "1.10"
3537

3638
[extras]
3739
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
@@ -41,6 +43,7 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
4143
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
4244
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
4345
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
46+
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
4447
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
4548
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
4649
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
@@ -50,4 +53,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
5053
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
5154

5255
[targets]
53-
test = ["SafeTestsets", "Pkg", "Test", "Aqua", "Random", "SparseArrays", "SuiteSparse", "BandedMatrices", "BlockBandedMatrices", "GPUArraysCore", "StaticArrays", "StaticArraysCore", "Tracker"]
56+
test = ["SafeTestsets", "Pkg", "Test", "Aqua", "Random", "SparseArrays", "SuiteSparse", "BandedMatrices", "BlockBandedMatrices", "GPUArraysCore", "StaticArrays", "StaticArraysCore", "Tracker", "ReverseDiff"]

ext/ArrayInterfaceReverseDiffExt.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
module ArrayInterfaceReverseDiffExt
2+
3+
if isdefined(Base, :get_extension)
4+
using ArrayInterface
5+
import ReverseDiff
6+
else
7+
using ..ArrayInterface
8+
import ..ReverseDiff
9+
end
10+
11+
ArrayInterface.ismutable(::Type{<:ReverseDiff.TrackedArray}) = false
12+
ArrayInterface.ismutable(T::Type{<:ReverseDiff.TrackedReal}) = false
13+
ArrayInterface.can_setindex(::Type{<:ReverseDiff.TrackedArray}) = false
14+
ArrayInterface.fast_scalar_indexing(::Type{<:ReverseDiff.TrackedArray}) = false
15+
function ArrayInterface.aos_to_soa(x::AbstractArray{<:ReverseDiff.TrackedReal,N}) where {N}
16+
if length(x) > 1
17+
reduce(vcat,x)
18+
else
19+
@show "here?"
20+
reduce(vcat,[x[1],x[1]])[1:1]
21+
end
22+
end
23+
24+
end # module

test/ad.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
using ArrayInterface, ReverseDiff, Tracker, Test
2+
x = ReverseDiff.track([4.0])
3+
@test ArrayInterface.aos_to_soa(x) isa ReverseDiff.TrackedArray
4+
x = reduce(vcat, ReverseDiff.track([4.0,4.0]))
5+
@test ArrayInterface.aos_to_soa(x) isa ReverseDiff.TrackedArray
6+
x = [ReverseDiff.track([4.0])[1]]
7+
@test ArrayInterface.aos_to_soa(x) isa ReverseDiff.TrackedArray
8+
x = reduce(vcat, ReverseDiff.track([4.0,4.0]))
9+
x = [x[1],x[2]]
10+
@test ArrayInterface.aos_to_soa(x) isa ReverseDiff.TrackedArray
11+
12+
x = Tracker.TrackedArray([4.0])
13+
@test ArrayInterface.aos_to_soa(x) isa Tracker.TrackedArray
14+
x = [Tracker.TrackedArray([4.0])[1]]
15+
@test ArrayInterface.aos_to_soa(x) isa Tracker.TrackedArray
16+
x = Tracker.TrackedArray([4.0,4.0])
17+
@test ArrayInterface.aos_to_soa(x) isa Tracker.TrackedArray
18+
x = reduce(vcat, Tracker.TrackedArray([4.0,4.0]))
19+
x = [x[1],x[2]]
20+
@test ArrayInterface.aos_to_soa(x) isa Tracker.TrackedArray

test/runtests.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,12 @@ end
1313
@time @safetestset "BandedMatrices" begin include("bandedmatrices.jl") end
1414
@time @safetestset "BlockBandedMatrices" begin include("blockbandedmatrices.jl") end
1515
@time @safetestset "Core" begin include("core.jl") end
16+
@time @safetestset "AD Integration" begin include("ad.jl") end
1617
@time @safetestset "StaticArraysCore" begin include("staticarrayscore.jl") end
1718
end
1819

1920
if GROUP == "GPU"
2021
activate_gpu_env()
2122
@time @safetestset "CUDA" begin include("gpu/cuda.jl") end
2223
end
23-
end
24+
end

0 commit comments

Comments
 (0)