Skip to content

Commit 0ad8caa

Browse files
Merge pull request #18 from SciML/as/all-symbols
feat: add `all_solvable_symbols` and `all_symbols`
2 parents 0e67d45 + ae50f9d commit 0ad8caa

File tree

11 files changed

+246
-5
lines changed

11 files changed

+246
-5
lines changed

docs/Project.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
[deps]
2-
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
32
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
3+
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
4+
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
5+
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
6+
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
47

58
[compat]
69
Documenter = "0.27"

docs/pages.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
pages = [
44
"Home" => "index.md",
5-
"API" => "api.md",
65
"Tutorial" => "tutorial.md",
6+
"Usage" => "usage.md",
7+
"API" => "api.md",
78
]

docs/src/api.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ is_observed
1414
observed
1515
is_time_dependent
1616
constant_structure
17+
all_variable_symbols
18+
all_symbols
19+
solvedvariables
20+
allvariables
1721
parameter_values
1822
getp
1923
setp

docs/src/tutorial.md

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ struct ExampleSolution
4141
state_index::Dict{Symbol,Int}
4242
parameter_index::Dict{Symbol,Int}
4343
independent_variable::Union{Symbol,Nothing}
44+
# mapping from observed variable to Expr to calculate its value
45+
observed::Dict{Symbol,Expr}
4446
u::Vector{Vector{Float64}}
4547
p::Vector{Float64}
4648
t::Vector{Float64}
@@ -86,9 +88,9 @@ function SymbolicIndexingInterface.independent_variable_symbols(sys::ExampleSolu
8688
sys.independent_variable === nothing ? [] : [sys.independent_variable]
8789
end
8890

89-
# this types accepts `Expr` for observed expressions involving state/parameter
91+
# this type accepts `Expr` for observed expressions involving state/parameter/observed
9092
# variables
91-
SymbolicIndexingInterface.is_observed(sys::ExampleSolution, sym) = sym isa Expr
93+
SymbolicIndexingInterface.is_observed(sys::ExampleSolution, sym) = sym isa Expr || sym isa Symbol && haskey(sys.observed, sym)
9294

9395
function SymbolicIndexingInterface.observed(sys::ExampleSolution, sym::Expr)
9496
if is_time_dependent(sys)
@@ -109,6 +111,21 @@ function SymbolicIndexingInterface.is_time_dependent(sys::ExampleSolution)
109111
end
110112

111113
SymbolicIndexingInterface.constant_structure(::ExampleSolution) = true
114+
115+
function SymbolicIndexingInterface.all_solvable_symbols(sys::ExampleSolution)
116+
return vcat(
117+
collect(keys(sys.state_index)),
118+
collect(keys(sys.observed)),
119+
)
120+
end
121+
122+
function SymbolicIndexingInterface.all_symbols(sys::ExampleSolution)
123+
return vcat(
124+
all_solvable_symbols(sys),
125+
collect(keys(sys.parameter_index)),
126+
sys.independent_variable === nothing ? Symbol[] : sys.independent_variable
127+
)
128+
end
112129
```
113130

114131
Note that the method definitions are all assuming `constant_structure(p) == true`.

docs/src/usage.md

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
# Using the SymbolicIndexingInterface
2+
3+
This tutorial will cover ways to use the interface for types that implement it.
4+
Consider the following example:
5+
6+
```@example Usage
7+
using ModelingToolkit, OrdinaryDiffEq, SymbolicIndexingInterface, Plots
8+
9+
@parameters σ ρ β
10+
@variables t x(t) y(t) z(t) w(t)
11+
D = Differential(t)
12+
13+
eqs = [D(D(x)) ~ σ * (y - x),
14+
D(y) ~ x * (ρ - z) - y,
15+
D(z) ~ x * y - β * z,
16+
w ~ x + y + z]
17+
18+
@named sys = ODESystem(eqs)
19+
sys = structural_simplify(sys)
20+
```
21+
22+
The system has 4 state variables, 3 parameters and one observed variable:
23+
```@example Usage
24+
observed(sys)
25+
```
26+
27+
Solving the system,
28+
```@example Usage
29+
u0 = [D(x) => 2.0,
30+
x => 1.0,
31+
y => 0.0,
32+
z => 0.0]
33+
34+
p = [σ => 28.0,
35+
ρ => 10.0,
36+
β => 8 / 3]
37+
38+
tspan = (0.0, 100.0)
39+
prob = ODEProblem(sys, u0, tspan, p, jac = true)
40+
sol = solve(prob, Tsit5())
41+
```
42+
43+
We can obtain the timeseries of any time-dependent variable using `getindex`
44+
```@example Usage
45+
sol[x]
46+
```
47+
48+
This also works for arrays or tuples of variables, including observed quantities and
49+
independent variables, for interpolating solutions, and plotting:
50+
```@example Usage
51+
sol[[x, y]]
52+
```
53+
54+
```@example Usage
55+
sol[(t, w)]
56+
```
57+
58+
```@example Usage
59+
sol(1.3, idxs=x)
60+
```
61+
62+
```@example Usage
63+
sol(1.3, idxs=[x, w])
64+
```
65+
66+
```@example Usage
67+
sol(1.3, idxs=[:y, :z])
68+
```
69+
70+
```@example Usage
71+
plot(sol, idxs=x)
72+
```
73+
74+
If necessary, `Symbol`s can be used to refer to variables. This is only valid for
75+
symbolic variables for which [`hasname`](@ref) returns `true`. The `Symbol` used must
76+
match the one returned by [`getname`](@ref) for the variable.
77+
```@example Usage
78+
hasname(x)
79+
```
80+
81+
```@example Usage
82+
getname(x)
83+
```
84+
85+
```@example Usage
86+
sol[(:x, :w)]
87+
```
88+
89+
Note how when indexing with an array, the returned type is a `Vector{Array{Float64}}`,
90+
and when using a `Tuple`, the returned type is `Vector{Tuple{Float64, Float64}}`.
91+
To obtain the value of all state variables, we can use the shorthand:
92+
```@example Usage
93+
sol[solvedvariables] # equivalent to sol[variable_symbols(sol)]
94+
```
95+
96+
This does not include the observed variable `w`. To include observed variables in the
97+
output, the following shorthand is used:
98+
```@example Usage
99+
sol[allvariables] # equivalent to sol[all_variable_symbols(sol)]
100+
```
101+
102+
Parameters cannot be obtained using this syntax, and instead require using [`getp`](@ref) and [`setp`](@ref).
103+
104+
```@example Usage
105+
σ_getter = getp(sys, σ)
106+
σ_getter(sol) # can also pass `prob`
107+
```
108+
109+
Note that this also supports arrays/tuples of parameter symbols:
110+
111+
```@example Usage
112+
σ_ρ_getter = getp(sys, (σ, ρ))
113+
σ_ρ_getter(sol)
114+
```
115+
116+
Now suppose the system has to be solved with a different value of the parameter `β`.
117+
118+
```@example Usage
119+
β_setter = setp(sys, β)
120+
β_setter(prob, 3)
121+
```
122+
123+
The updated parameter values can be checked using [`parameter_values`](@ref).
124+
125+
```@example Usage
126+
parameter_values(prob)
127+
```
128+
129+
Solving the new system, note that the parameter getter functions still work on the new
130+
solution object.
131+
132+
```@example Usage
133+
sol2 = solve(prob, Tsit5())
134+
σ_getter(sol)
135+
```
136+
137+
```@example Usage
138+
σ_ρ_getter(sol)
139+
```
140+
141+
To set the entire parameter vector at once, [`parameter_values`](@ref) can be used
142+
(note the usage of broadcasted assignment).
143+
144+
```@example Usage
145+
parameter_values(prob) .= [29.0, 11.0, 2.5]
146+
parameter_values(prob)
147+
```

src/SymbolicIndexingInterface.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ include("trait.jl")
55

66
export is_variable, variable_index, variable_symbols, is_parameter, parameter_index,
77
parameter_symbols, is_independent_variable, independent_variable_symbols, is_observed,
8-
observed, is_time_dependent, constant_structure, symbolic_container
8+
observed, is_time_dependent, constant_structure, symbolic_container, all_variable_symbols,
9+
all_symbols, solvedvariables, allvariables
910
include("interface.jl")
1011

1112
export SymbolCache

src/interface.jl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@ variable_index(sys, sym, i) = variable_index(symbolic_container(sys), sym, i)
3333
Return a vector of the symbolic variables being solved for in the system `sys`. If
3434
`constant_structure(sys) == false` this accepts an additional parameter indicating
3535
the current time index. The returned vector should not be mutated.
36+
37+
For types that implement `Base.getindex` with symbolic indices using this interface,
38+
The shorthand `sys[solvedvariables]` can be used as shorthand for
39+
`sys[variable_symbols(sys)]`. See: [`solvedvariables`](@ref).
3640
"""
3741
variable_symbols(sys) = variable_symbols(symbolic_container(sys))
3842
variable_symbols(sys, i) = variable_symbols(symbolic_container(sys), i)
@@ -110,3 +114,46 @@ Check if `sys` has a constant structure. Constant structure systems do not chang
110114
number of variables or parameters over time.
111115
"""
112116
constant_structure(sys) = constant_structure(symbolic_container(sys))
117+
118+
"""
119+
all_variable_symbols(sys)
120+
121+
Return a vector of variable symbols in the system, including observed quantities.
122+
123+
For types that implement `Base.getindex` with symbolic indices using this interface,
124+
The shorthand `sys[allvariables]` can be used as shorthand for
125+
`sys[all_variable_symbols(sys)]`. See: [`allvariables`](@ref).
126+
"""
127+
all_variable_symbols(sys) = all_variable_symbols(symbolic_container(sys))
128+
129+
"""
130+
all_symbols(sys)
131+
132+
Return an array of all symbols in the system. This includes parameters and independent
133+
variables.
134+
"""
135+
all_symbols(sys) = all_symbols(symbolic_container(sys))
136+
137+
struct SolvedVariables end
138+
139+
"""
140+
const solvedvariables = SolvedVariables()
141+
142+
This singleton is used as a shortcut to allow indexing all solution variables
143+
(excluding observed quantities). It has a [`symbolic_type`](@ref) of
144+
[`ScalarSymbolic`](@ref). See: [`variable_symbols`](@ref).
145+
"""
146+
const solvedvariables = SolvedVariables()
147+
symbolic_type(::Type{SolvedVariables}) = ScalarSymbolic()
148+
149+
struct AllVariables end
150+
151+
"""
152+
const allvariables = AllVariables()
153+
154+
This singleton is used as a shortcut to allow indexing all solution variables
155+
(including observed quantities). It has a [`symbolic_type`](@ref) of
156+
[`ScalarSymbolic`](@ref). See [`all_variable_symbols`](@ref).
157+
"""
158+
const allvariables = AllVariables()
159+
symbolic_type(::Type{AllVariables}) = ScalarSymbolic()

src/symbol_cache.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ function is_time_dependent(sc::SymbolCache)
7171
end
7272
end
7373
constant_structure(::SymbolCache) = true
74+
all_variable_symbols(sc::SymbolCache) = variable_symbols(sc)
75+
all_symbols(sc::SymbolCache) = vcat(variable_symbols(sc), parameter_symbols(sc), independent_variable_symbols(sc))
7476

7577
function Base.copy(sc::SymbolCache)
7678
return SymbolCache(sc.variables === nothing ? nothing : copy(sc.variables),

test/example_test.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,10 @@ function SymbolicIndexingInterface.observed(sys::SystemMockup, sym, states = not
5353
end
5454
SymbolicIndexingInterface.is_time_dependent(sys::SystemMockup) = isequal(sys.indepvar, :t)
5555
SymbolicIndexingInterface.constant_structure(sys::SystemMockup) = sys.static
56+
SymbolicIndexingInterface.all_variable_symbols(sys::SystemMockup) = sys.vars
57+
function SymbolicIndexingInterface.all_symbols(sys::SystemMockup)
58+
vcat(sys.vars, sys.params, independent_variable_symbols(sys))
59+
end
5660

5761
sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], :t)
5862

@@ -79,6 +83,8 @@ sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], :t)
7983
@test variable_symbols(sys) == [:x, :y, :z]
8084
@test parameter_symbols(sys) == [:a, :b, :c]
8185
@test independent_variable_symbols(sys) == [:t]
86+
@test all_variable_symbols(sys) == [:x, :y, :z]
87+
@test sort(all_symbols(sys)) == [:a, :b, :c, :t, :x, :y, :z]
8288

8389
sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], nothing)
8490

@@ -93,6 +99,8 @@ sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], nothing)
9399
@test variable_symbols(sys) == [:x, :y, :z]
94100
@test parameter_symbols(sys) == [:a, :b, :c]
95101
@test independent_variable_symbols(sys) == []
102+
@test all_variable_symbols(sys) == [:x, :y, :z]
103+
@test sort(all_symbols(sys)) == [:a, :b, :c, :x, :y, :z]
96104

97105
sys = SystemMockup(false, [:x, :y, :z], [:a, :b, :c], :t)
98106
@test !constant_structure(sys)
@@ -115,3 +123,5 @@ end
115123
@test variable_symbols(sys, 3) == [:x, :y, :z]
116124
@test parameter_symbols(sys) == [:a, :b, :c]
117125
@test independent_variable_symbols(sys) == [:t]
126+
@test all_variable_symbols(sys) == [:x, :y, :z]
127+
@test sort(all_symbols(sys)) == [:a, :b, :c, :t, :x, :y, :z]

test/fallback_test.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,5 @@ all_syms = [:x, :y, :z, :a, :b, :t]
2222
@test variable_symbols(sys) == variable_symbols(sc)
2323
@test parameter_symbols(sys) == parameter_symbols(sc)
2424
@test independent_variable_symbols(sys) == independent_variable_symbols(sc)
25+
@test all_variable_symbols(sys) == variable_symbols(sc)
26+
@test all_symbols(sys) == all_symbols(sc)

0 commit comments

Comments
 (0)