Skip to content

Commit 2c86c8c

Browse files
committed
Enable symbol based indexing of interpolated solutions by adding extra fields to DiffEqArray type
1 parent 532745d commit 2c86c8c

File tree

6 files changed

+121
-6
lines changed

6 files changed

+121
-6
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,12 @@ julia = "1.3"
2626
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
2727
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
2828
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
29+
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
2930
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
3031
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
3132
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3233
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
3334
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3435

3536
[targets]
36-
test = ["ForwardDiff", "NLsolve", "OrdinaryDiffEq", "Test", "Unitful", "Random", "StructArrays", "Zygote"]
37+
test = ["ForwardDiff", "NLsolve", "OrdinaryDiffEq", "Pkg", "Test", "Unitful", "Random", "StructArrays", "Zygote"]

src/RecursiveArrayTools.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,13 @@ using DocStringExtensions
1818
include("zygote.jl")
1919

2020
export VectorOfArray, DiffEqArray, AbstractVectorOfArray, AbstractDiffEqArray,
21-
vecarr_to_arr, vecarr_to_vectors, tuples
21+
AllObserved, vecarr_to_arr, vecarr_to_vectors, tuples
2222

2323
export recursivecopy, recursivecopy!, vecvecapply, copyat_or_push!,
2424
vecvec_to_mat, recursive_one, recursive_mean, recursive_bottom_eltype,
25-
recursive_unitless_bottom_eltype, recursive_unitless_eltype
25+
recursive_unitless_bottom_eltype, recursive_unitless_eltype, parameterless_type,
26+
issymbollike
27+
2628

2729
export ArrayPartition
2830

src/vector_of_array.jl

Lines changed: 74 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,13 @@ mutable struct VectorOfArray{T, N, A} <: AbstractVectorOfArray{T, N, A}
33
u::A # A <: AbstractVector{<: AbstractArray{T, N - 1}}
44
end
55
# VectorOfArray with an added series for time
6-
mutable struct DiffEqArray{T, N, A, B} <: AbstractDiffEqArray{T, N, A}
6+
mutable struct DiffEqArray{T, N, A, B, C, D, E, F} <: AbstractDiffEqArray{T, N, A}
77
u::A # A <: AbstractVector{<: AbstractArray{T, N - 1}}
88
t::B
9+
syms::C
10+
indepsym::D
11+
observed::E
12+
p::F
913
end
1014

1115
Base.Array(VA::AbstractVectorOfArray{T,N,A}) where {T,N,A <: AbstractVector{<:AbstractVector}} = reduce(hcat,VA.u)
@@ -20,10 +24,11 @@ VectorOfArray(vec::AbstractVector{T}, ::NTuple{N}) where {T, N} = VectorOfArray{
2024
VectorOfArray(vec::AbstractVector) = VectorOfArray(vec, (size(vec[1])..., length(vec)))
2125
VectorOfArray(vec::AbstractVector{VT}) where {T, N, VT<:AbstractArray{T, N}} = VectorOfArray{T, N+1, typeof(vec)}(vec)
2226

23-
DiffEqArray(vec::AbstractVector{T}, ts, ::NTuple{N}) where {T, N} = DiffEqArray{eltype(T), N, typeof(vec), typeof(ts)}(vec, ts)
27+
DiffEqArray(vec::AbstractVector{T}, ts, ::NTuple{N}) where {T, N} = DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), Nothing, Nothing, Nothing, Nothing}(vec, ts, nothing, nothing, nothing, nothing)
2428
# Assume that the first element is representative of all other elements
2529
DiffEqArray(vec::AbstractVector,ts::AbstractVector) = DiffEqArray(vec, ts, (size(vec[1])..., length(vec)))
26-
DiffEqArray(vec::AbstractVector{VT},ts::AbstractVector) where {T, N, VT<:AbstractArray{T, N}} = DiffEqArray{T, N+1, typeof(vec), typeof(ts)}(vec, ts)
30+
DiffEqArray(vec::AbstractVector{VT},ts::AbstractVector) where {T, N, VT<:AbstractArray{T, N}} = DiffEqArray{T, N+1, typeof(vec), typeof(ts), Nothing, Nothing, Nothing, Nothing}(vec, ts, nothing, nothing, nothing, nothing)
31+
DiffEqArray(vec::AbstractVector{VT},ts::AbstractVector, syms::Vector{Symbol}, indepsym::Symbol, observed::Function, p) where {T, N, VT<:AbstractArray{T, N}} = DiffEqArray{T, N+1, typeof(vec), typeof(ts), typeof(syms), typeof(indepsym), typeof(observed), typeof(p)}(vec, ts, syms, indepsym, observed, p)
2732

2833
# Interface for the linear indexing. This is just a view of the underlying nested structure
2934
@inline Base.firstindex(VA::AbstractVectorOfArray) = firstindex(VA.u)
@@ -38,6 +43,72 @@ Base.@propagate_inbounds Base.getindex(VA::AbstractVectorOfArray{T, N}, I::Int)
3843
Base.@propagate_inbounds Base.getindex(VA::AbstractVectorOfArray{T, N}, I::Colon) where {T, N} = VA.u[I]
3944
Base.@propagate_inbounds Base.getindex(VA::AbstractVectorOfArray{T, N}, I::AbstractArray{Int}) where {T, N} = VectorOfArray(VA.u[I])
4045
Base.@propagate_inbounds Base.getindex(VA::AbstractDiffEqArray{T, N}, I::AbstractArray{Int}) where {T, N} = DiffEqArray(VA.u[I],VA.t[I])
46+
#
47+
Base.@pure __parameterless_type(T) = Base.typename(T).wrapper
48+
parameterless_type(x) = parameterless_type(typeof(x))
49+
parameterless_type(x::Type) = __parameterless_type(x)
50+
51+
### Abstract Interface
52+
struct AllObserved
53+
end
54+
issymbollike(x) = x isa Symbol ||
55+
x isa AllObserved ||
56+
Symbol(parameterless_type(typeof(x))) == :Operation ||
57+
Symbol(parameterless_type(typeof(x))) == :Variable ||
58+
Symbol(parameterless_type(typeof(x))) == :Sym ||
59+
Symbol(parameterless_type(typeof(x))) == :Num ||
60+
Symbol(parameterless_type(typeof(x))) == :Term
61+
62+
Base.@propagate_inbounds function Base.getindex(A::AbstractDiffEqArray{T, N},sym) where {T, N}
63+
if issymbollike(sym)
64+
i = findfirst(isequal(Symbol(sym)),A.syms)
65+
else
66+
i = sym
67+
end
68+
69+
if i === nothing
70+
# TODO: Check if system actually has a indepsym
71+
if issymbollike(i) && Symbol(i) == A.indepsym
72+
A.t
73+
else
74+
observed(A,sym,:)
75+
end
76+
else
77+
A[i,:]
78+
end
79+
end
80+
Base.@propagate_inbounds function Base.getindex(A::AbstractDiffEqArray{T, N},sym,args...) where {T, N}
81+
if issymbollike(sym)
82+
i = findfirst(isequal(Symbol(sym)),A.syms)
83+
else
84+
i = sym
85+
end
86+
87+
if i === nothing
88+
# TODO: Check if system actually has a indepsym
89+
if issymbollike(i) && Symbol(i) == A.indepsym
90+
A.t[args...]
91+
else
92+
observed(A,sym,args...)
93+
end
94+
else
95+
A[i,args...]
96+
end
97+
end
98+
Base.@propagate_inbounds Base.getindex(A::AbstractDiffEqArray{T, N} where {T, N}, i::Int64, ::Colon) = Base.getindex.(A.u, i)
99+
100+
function observed(A::AbstractDiffEqArray{T, N},sym,i::Int) where {T, N}
101+
A.observed(sym,A.u[i],A.p,A.t[i])
102+
end
103+
104+
function observed(A::AbstractDiffEqArray{T, N},sym,i::AbstractArray{Int}) where {T, N}
105+
A.observed.((sym,),A.u[i],(A.p,),A.t[i])
106+
end
107+
108+
function observed(A::AbstractDiffEqArray{T, N},sym,::Colon) where {T, N}
109+
A.observed.((sym,),A.u,(A.p,),A.t)
110+
end
111+
#
41112
Base.@propagate_inbounds Base.getindex(VA::AbstractVectorOfArray{T, N}, i::Int,::Colon) where {T, N} = [VA.u[j][i] for j in 1:length(VA)]
42113
Base.@propagate_inbounds function Base.getindex(VA::AbstractVectorOfArray{T,N}, ii::CartesianIndex) where {T, N}
43114
ti = Tuple(ii)

test/downstream/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[deps]
2+
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
3+
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"

test/downstream/symbol_indexing.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
using RecursiveArrayTools, ModelingToolkit, OrdinaryDiffEq, Test
2+
3+
@variables t x(t) # independent and dependent variables
4+
@parameters τ # parameters
5+
D = Differential(t) # define an operator for the differentiation w.r.t. time
6+
@variables RHS(t)
7+
@named fol_separate = ODESystem([ RHS ~ (1 - x)/τ,
8+
D(x) ~ RHS ])
9+
fol_simplified = structural_simplify(fol_separate)
10+
11+
prob = ODEProblem(fol_simplified, [x => 0.0], (0.0,10.0), [τ => 3.0])
12+
sol = solve(prob, Tsit5())
13+
14+
sol_new = DiffEqArray(
15+
sol.u[1:10],
16+
sol.t[1:10],
17+
sol.prob.f.syms,
18+
sol.prob.f.indepsym,
19+
sol.prob.f.observed,
20+
sol.prob.p
21+
)
22+
23+
@test sol_new[RHS] (1 .- sol_new[x])./3.0

test/runtests.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,16 @@
1+
using Pkg
12
using RecursiveArrayTools
23
using Test
34

5+
const GROUP = get(ENV, "GROUP", "All")
6+
const is_APPVEYOR = ( Sys.iswindows() && haskey(ENV,"APPVEYOR") )
7+
8+
function activate_downstream_env()
9+
Pkg.activate("downstream")
10+
Pkg.develop(PackageSpec(path=dirname(@__DIR__)))
11+
Pkg.instantiate()
12+
end
13+
414
@time begin
515
@time @testset "Utils Tests" begin include("utils_test.jl") end
616
@time @testset "Partitions Tests" begin include("partitions_test.jl") end
@@ -10,4 +20,9 @@ using Test
1020
@time @testset "Linear Algebra Tests" begin include("linalg.jl") end
1121
@time @testset "Upstream Tests" begin include("upstream.jl") end
1222
@time @testset "Adjoint Tests" begin include("adjoints.jl") end
23+
24+
if !is_APPVEYOR && GROUP == "Downstream"
25+
activate_downstream_env()
26+
@time @testset "DiffEqArray Indexing Tests" begin include("downstream/symbol_indexing.jl") end
27+
end
1328
end

0 commit comments

Comments
 (0)