Skip to content

Commit 6ec82b5

Browse files
Merge pull request #247 from SciML/weakdep
Add ReverseDiff weak dep
2 parents 2e3e3d4 + 2746e50 commit 6ec82b5

File tree

7 files changed

+66
-19
lines changed

7 files changed

+66
-19
lines changed

Project.toml

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1414
IteratorInterfaceExtensions = "82899510-4779-5014-852e-03e436cf321d"
1515
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1616
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
17+
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1718
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
1819
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1920
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
@@ -30,12 +31,16 @@ FillArrays = "0.11, 0.12, 0.13"
3031
GPUArraysCore = "0.1"
3132
IteratorInterfaceExtensions = "1"
3233
RecipesBase = "0.7, 0.8, 1.0"
34+
Requires = "1.0"
3335
StaticArraysCore = "1.1"
3436
SymbolicIndexingInterface = "0.1, 0.2"
3537
Tables = "1"
3638
ZygoteRules = "0.2"
3739
julia = "1.6"
3840

41+
[extensions]
42+
RecursiveArrayToolsTrackerExt = "Tracker"
43+
3944
[extras]
4045
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
4146
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
@@ -44,11 +49,16 @@ NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
4449
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
4550
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
4651
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
52+
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
4753
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
4854
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
4955
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
56+
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
5057
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
5158
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
5259

5360
[targets]
54-
test = ["Aqua", "ForwardDiff", "LabelledArrays", "NLsolve", "OrdinaryDiffEq", "Pkg", "Test", "Unitful", "Random", "StaticArrays", "StructArrays", "Zygote"]
61+
test = ["SafeTestsets", "Aqua", "ForwardDiff", "LabelledArrays", "NLsolve", "OrdinaryDiffEq", "Pkg", "Test", "Unitful", "Random", "StaticArrays", "StructArrays", "Zygote"]
62+
63+
[weakdeps]
64+
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
module RecursiveArrayToolsTrackerExt
2+
3+
import RecursiveArrayTools
4+
isdefined(Base, :get_extension) ? (import Tracker) : (import ..Tracker)
5+
6+
function RecursiveArrayTools.recursivecopy!(b::AbstractArray{T, N},
7+
a::AbstractArray{T2, N}) where {
8+
T <:
9+
Tracker.TrackedArray,
10+
T2 <:
11+
Tracker.TrackedArray,
12+
N}
13+
@inbounds for i in eachindex(a)
14+
b[i] = copy(a[i])
15+
end
16+
end
17+
18+
end

src/RecursiveArrayTools.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,13 @@ function ChainRulesCore.rrule(T::Type{<:GPUArraysCore.AbstractGPUArray},
4242
T(xs), ȳ -> (NoTangent(), ȳ)
4343
end
4444

45+
import Requires
46+
@static if !isdefined(Base, :get_extension)
47+
function __init__()
48+
Requires.@require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin include("../ext/RecursiveArrayToolsTrackerExt.jl") end
49+
end
50+
end
51+
4552
export VectorOfArray, DiffEqArray, AbstractVectorOfArray, AbstractDiffEqArray,
4653
AllObserved, vecarr_to_vectors, tuples
4754

src/tabletraits.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@ function Tables.rows(A::AbstractDiffEqArray)
77
N = length(A.u[1])
88
names = [
99
:timestamp,
10-
(A.sc !== nothing && A.sc.syms !== nothing ? (A.sc.syms[i] for i in 1:N) :
10+
(!(A.sc isa SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}) ? (states(A.sc)[i] for i in 1:N) :
1111
(Symbol("value", i) for i in 1:N))...,
1212
]
1313
types = Type[eltype(A.t), (eltype(A.u[1]) for _ in 1:N)...]
1414
else
15-
names = [:timestamp, A.sc !== nothing && A.sc.syms !== nothing ? A.sc.syms[1] : :value]
15+
names = [:timestamp, !(A.sc isa SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}) ? states(A.sc)[1] : :value]
1616
types = Type[eltype(A.t), VT]
1717
end
1818
return AbstractDiffEqArrayRows(names, types, A.t, A.u)
@@ -31,8 +31,8 @@ struct AbstractDiffEqArrayRows{T, U}
3131
u::U
3232
end
3333
function AbstractDiffEqArrayRows(names, types, t, u)
34-
AbstractDiffEqArrayRows(names, types,
35-
Dict(nm => i for (i, nm) in enumerate(names)), t, u)
34+
AbstractDiffEqArrayRows(Symbol.(names), types,
35+
Dict(Symbol(nm) => i for (i, nm) in enumerate(names)), t, u)
3636
end
3737

3838
Base.length(x::AbstractDiffEqArrayRows) = length(x.u)

test/downstream/Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
[deps]
22
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
33
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
4+
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
45

56
[compat]
67
ModelingToolkit = "8.33"
7-
OrdinaryDiffEq = "6.31"
8+
OrdinaryDiffEq = "6.31"
9+
Tracker = "0.2"

test/downstream/TrackerExt.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
using RecursiveArrayTools, Tracker, Test
2+
3+
x = [5.0]
4+
a = [Tracker.TrackedArray(x)]
5+
b = [Tracker.TrackedArray(copy([5.2]))]
6+
RecursiveArrayTools.recursivecopy!(a,b)
7+
@test a[1][1] == 5.2

test/runtests.jl

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ using Pkg
22
using RecursiveArrayTools
33
using Test
44
using Aqua
5+
using SafeTestsets
6+
57
Aqua.test_all(RecursiveArrayTools, ambiguities = false)
68
@test_broken isempty(Test.detect_ambiguities(RecursiveArrayTools))
79
const GROUP = get(ENV, "GROUP", "All")
@@ -21,26 +23,27 @@ end
2123

2224
@time begin
2325
if GROUP == "Core" || GROUP == "All"
24-
@time @testset "Utils Tests" begin include("utils_test.jl") end
25-
@time @testset "Partitions Tests" begin include("partitions_test.jl") end
26-
@time @testset "VecOfArr Indexing Tests" begin include("basic_indexing.jl") end
27-
@time @testset "SymbolicIndexingInterface API test" begin include("symbolic_indexing_interface_test.jl") end
28-
@time @testset "VecOfArr Interface Tests" begin include("interface_tests.jl") end
29-
@time @testset "Table traits" begin include("tabletraits.jl") end
30-
@time @testset "StaticArrays Tests" begin include("copy_static_array_test.jl") end
31-
@time @testset "Linear Algebra Tests" begin include("linalg.jl") end
32-
@time @testset "Upstream Tests" begin include("upstream.jl") end
33-
@time @testset "Adjoint Tests" begin include("adjoints.jl") end
26+
@time @safetestset "Utils Tests" begin include("utils_test.jl") end
27+
@time @safetestset "Partitions Tests" begin include("partitions_test.jl") end
28+
@time @safetestset "VecOfArr Indexing Tests" begin include("basic_indexing.jl") end
29+
@time @safetestset "SymbolicIndexingInterface API test" begin include("symbolic_indexing_interface_test.jl") end
30+
@time @safetestset "VecOfArr Interface Tests" begin include("interface_tests.jl") end
31+
@time @safetestset "Table traits" begin include("tabletraits.jl") end
32+
@time @safetestset "StaticArrays Tests" begin include("copy_static_array_test.jl") end
33+
@time @safetestset "Linear Algebra Tests" begin include("linalg.jl") end
34+
@time @safetestset "Upstream Tests" begin include("upstream.jl") end
35+
@time @safetestset "Adjoint Tests" begin include("adjoints.jl") end
3436
end
3537

3638
if !is_APPVEYOR && GROUP == "Downstream"
3739
activate_downstream_env()
38-
@time @testset "DiffEqArray Indexing Tests" begin include("downstream/symbol_indexing.jl") end
39-
@time @testset "Event Tests with ArrayPartition" begin include("downstream/downstream_events.jl") end
40+
@time @safetestset "DiffEqArray Indexing Tests" begin include("downstream/symbol_indexing.jl") end
41+
@time @safetestset "Event Tests with ArrayPartition" begin include("downstream/downstream_events.jl") end
42+
@time @safetestset "TrackerExt" begin include("downstream/TrackerExt.jl") end
4043
end
4144

4245
if !is_APPVEYOR && GROUP == "GPU"
4346
activate_gpu_env()
44-
@time @testset "VectorOfArray GPU" begin include("gpu/vectorofarray_gpu.jl") end
47+
@time @safetestset "VectorOfArray GPU" begin include("gpu/vectorofarray_gpu.jl") end
4548
end
4649
end

0 commit comments

Comments
 (0)