Skip to content

Commit 0b925c0

Browse files
refactor: add new *_symbols methods, corresponding singletons, update tests
1 parent 2107adc commit 0b925c0

File tree

9 files changed

+185
-15
lines changed

9 files changed

+185
-15
lines changed

docs/Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
[deps]
2-
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
32
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
3+
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
4+
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
5+
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
46

57
[compat]
68
Documenter = "0.27"

docs/pages.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
pages = [
44
"Home" => "index.md",
5-
"API" => "api.md",
65
"Tutorial" => "tutorial.md",
6+
"Usage" => "usage.md",
7+
"API" => "api.md",
78
]

docs/src/api.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@ is_observed
1414
observed
1515
is_time_dependent
1616
constant_structure
17-
all_solvable_symbols
17+
all_variable_symbols
1818
all_symbols
19+
solvedvariables
20+
allvariables
1921
parameter_values
2022
getp
2123
setp

docs/usage.md

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
# Using the SymbolicIndexingInterface
2+
3+
This tutorial will cover ways to use the interface for types that implement it.
4+
Consider the following example:
5+
6+
```@example Usage
7+
using ModelingToolkit, OrdinaryDiffEq, SymbolicIndexingInterface
8+
9+
@parameters σ ρ β
10+
@variables t x(t) y(t) z(t) w(t)
11+
D = Differential(t)
12+
13+
eqs = [D(D(x)) ~ σ * (y - x),
14+
D(y) ~ x * (ρ - z) - y,
15+
D(z) ~ x * y - β * z,
16+
w ~ x + y + z]
17+
18+
@named sys = ODESystem(eqs)
19+
sys = structural_simplify(sys)
20+
```
21+
22+
The system has 4 state variables, 3 parameters and one observed variable:
23+
```@example Usage
24+
observed(sys)
25+
```
26+
27+
Solving the system,
28+
```@example Usage
29+
u0 = [D(x) => 2.0,
30+
x => 1.0,
31+
y => 0.0,
32+
z => 0.0]
33+
34+
p = [σ => 28.0,
35+
ρ => 10.0,
36+
β => 8 / 3]
37+
38+
tspan = (0.0, 100.0)
39+
prob = ODEProblem(sys, u0, tspan, p, jac = true)
40+
sol = solve(prob, Tsit5())
41+
```
42+
43+
We can obtain the timeseries of any time-dependent variable using `getindex`
44+
```@example Usage
45+
sol[x]
46+
```
47+
48+
This also works for arrays or tuples of variables, including observed quantities and
49+
independent variables:
50+
```@example Usage
51+
sol[[x, y]]
52+
```
53+
54+
```@example Usage
55+
sol[(t, w)]
56+
```
57+
58+
If necessary, `Symbol`s can be used to refer to variables. This is only valid for
59+
symbolic variables for which [`hasname`](@ref) returns `true`. The `Symbol` used must
60+
match the one returned by [`getname`](@ref) for the variable.
61+
```@example Usage
62+
hasname(x)
63+
```
64+
65+
```@example Usage
66+
getname(x)
67+
```
68+
69+
```@example Usage
70+
sol[(:x, :w)]
71+
```
72+
73+
Note how when indexing with an array, the returned type is a `Vector{Array{Float64}}`,
74+
and when using a `Tuple`, the returned type is `Vector{Tuple{Float64, Float64}}`.
75+
To obtain the value of all state variables, we can use the shorthand:
76+
```@example Usage
77+
sol[solvedvariables] # equivalent to sol[variable_symbols(sol)]
78+
```
79+
80+
This does not include the observed variable `w`. To include observed variables in the
81+
output, the following shorthand is used:
82+
```@example Usage
83+
sol[allvariables] # equivalent to sol[all_variable_symbols(sol)]
84+
```
85+
86+
Parameters cannot be obtained using this syntax, and instead require using [`getp`](@ref) and [`setp`](@ref).
87+
88+
```@example Usage
89+
σ_getter = getp(sys, σ)
90+
σ_getter(sol) # can also pass `prob`
91+
```
92+
93+
Note that this also supports arrays/tuples of parameter symbols:
94+
95+
```@example Usage
96+
σ_ρ_getter = getp(sys, (σ, ρ))
97+
σ_ρ_getter(sol)
98+
```
99+
100+
Now suppose the system has to be solved with a different value of the parameter `β`.
101+
102+
```@example Usage
103+
β_setter = setp(sys, β)
104+
β_setter(prob, 3)
105+
```
106+
107+
The updated parameter values can be checked using [`parameter_values`](@ref).
108+
109+
```@example Usage
110+
parameter_values(prob)
111+
```
112+
113+
Solving the new system, note that the parameter getter functions still work on the new
114+
solution object.
115+
116+
```@example Usage
117+
sol2 = solve(prob, Tsit5())
118+
σ_getter(sol)
119+
```
120+
121+
```@example Usage
122+
σ_ρ_getter(sol)
123+
```
124+
125+
To set the entire parameter vector at once, [`parameter_values`](@ref) can be used
126+
(note the usage of broadcasted assignment).
127+
128+
```@example Usage
129+
parameter_values(prob) .= [29.0, 11.0, 2.5]
130+
parameter_values(prob)
131+
```

src/interface.jl

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@ variable_index(sys, sym, i) = variable_index(symbolic_container(sys), sym, i)
3333
Return a vector of the symbolic variables being solved for in the system `sys`. If
3434
`constant_structure(sys) == false` this accepts an additional parameter indicating
3535
the current time index. The returned vector should not be mutated.
36+
37+
For types that implement `Base.getindex` with symbolic indices using this interface,
38+
The shorthand `sys[solvedvariables]` can be used as shorthand for
39+
`sys[variable_symbols(sys)]`. See: [`solvedvariables`](@ref).
3640
"""
3741
variable_symbols(sys) = variable_symbols(symbolic_container(sys))
3842
variable_symbols(sys, i) = variable_symbols(symbolic_container(sys), i)
@@ -112,12 +116,16 @@ number of variables or parameters over time.
112116
constant_structure(sys) = constant_structure(symbolic_container(sys))
113117

114118
"""
115-
all_solvable_symbols(sys)
119+
all_variables(sys)
120+
121+
Return a vector of pairs, where the first element of each pair is a symbolic variable
122+
and the second is its initial value. This includes observed quantities.
116123
117-
Return an array of all symbols in the system that can be solved for. This includes
118-
observed variables, but not parameters or independent variables.
124+
For types that implement `Base.getindex` with symbolic indices using this interface,
125+
The shorthand `sys[allvariables]` can be used as shorthand for
126+
`sys[all_variable_symbols(sys)]`. See: [`allvariables`](@ref).
119127
"""
120-
all_solvable_symbols(sys) = all_solvable_symbols(symbolic_container(sys))
128+
all_variable_symbols(sys) = all_variable_symbols(symbolic_container(sys))
121129

122130
"""
123131
all_symbols(sys)
@@ -126,3 +134,27 @@ Return an array of all symbols in the system. This includes parameters and indep
126134
variables.
127135
"""
128136
all_symbols(sys) = all_symbols(symbolic_container(sys))
137+
138+
struct SolvedVariables end
139+
140+
"""
141+
const solvedvariables = SolvedVariables()
142+
143+
This singleton is used as a shortcut to allow indexing all solution variables
144+
(excluding observed quantities). It has a [`symbolic_type`](@ref) of
145+
[`ScalarSymbolic`](@ref). See: [`variable_symbols`](@ref).
146+
"""
147+
const solvedvariables = SolvedVariables()
148+
symbolic_type(::Type{AllVariables}) = ScalarSymbolic()
149+
150+
struct AllVariables end
151+
152+
"""
153+
const allvariables = AllVariables()
154+
155+
This singleton is used as a shortcut to allow indexing all solution variables
156+
(including observed quantities). It has a [`symbolic_type`](@ref) of
157+
[`ScalarSymbolic`](@ref). See [`all_variable_symbols`](@ref).
158+
"""
159+
const allvariables = AllVariables()
160+
symbolic_type(::Type{AllVariables}) = ScalarSymbolic()

src/symbol_cache.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ function is_time_dependent(sc::SymbolCache)
7171
end
7272
end
7373
constant_structure(::SymbolCache) = true
74-
all_solvable_symbols(sc::SymbolCache) = variable_symbols(sc)
74+
all_variable_symbols(sc::SymbolCache) = variable_symbols(sc)
7575
all_symbols(sc::SymbolCache) = vcat(variable_symbols(sc), parameter_symbols(sc), independent_variable_symbols(sc))
7676

7777
function Base.copy(sc::SymbolCache)

test/example_test.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ function SymbolicIndexingInterface.observed(sys::SystemMockup, sym, states = not
5353
end
5454
SymbolicIndexingInterface.is_time_dependent(sys::SystemMockup) = isequal(sys.indepvar, :t)
5555
SymbolicIndexingInterface.constant_structure(sys::SystemMockup) = sys.static
56-
SymbolicIndexingInterface.all_solvable_symbols(sys::SystemMockup) = sys.vars
56+
SymbolicIndexingInterface.all_variable_symbols(sys::SystemMockup) = sys.vars
5757
function SymbolicIndexingInterface.all_symbols(sys::SystemMockup)
5858
vcat(sys.vars, sys.params, independent_variable_symbols(sys))
5959
end
@@ -83,7 +83,7 @@ sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], :t)
8383
@test variable_symbols(sys) == [:x, :y, :z]
8484
@test parameter_symbols(sys) == [:a, :b, :c]
8585
@test independent_variable_symbols(sys) == [:t]
86-
@test all_solvable_symbols(sys) == [:x, :y, :z]
86+
@test all_variable_symbols(sys) == [:x, :y, :z]
8787
@test sort(all_symbols(sys)) == [:a, :b, :c, :t, :x, :y, :z]
8888

8989
sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], nothing)
@@ -99,7 +99,7 @@ sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], nothing)
9999
@test variable_symbols(sys) == [:x, :y, :z]
100100
@test parameter_symbols(sys) == [:a, :b, :c]
101101
@test independent_variable_symbols(sys) == []
102-
@test all_solvable_symbols(sys) == [:x, :y, :z]
102+
@test all_variable_symbols(sys) == [:x, :y, :z]
103103
@test sort(all_symbols(sys)) == [:a, :b, :c, :x, :y, :z]
104104

105105
sys = SystemMockup(false, [:x, :y, :z], [:a, :b, :c], :t)
@@ -123,5 +123,5 @@ end
123123
@test variable_symbols(sys, 3) == [:x, :y, :z]
124124
@test parameter_symbols(sys) == [:a, :b, :c]
125125
@test independent_variable_symbols(sys) == [:t]
126-
@test all_solvable_symbols(sys) == [:x, :y, :z]
126+
@test all_variable_symbols(sys) == [:x, :y, :z]
127127
@test sort(all_symbols(sys)) == [:a, :b, :c, :t, :x, :y, :z]

test/fallback_test.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,5 @@ all_syms = [:x, :y, :z, :a, :b, :t]
2222
@test variable_symbols(sys) == variable_symbols(sc)
2323
@test parameter_symbols(sys) == parameter_symbols(sc)
2424
@test independent_variable_symbols(sys) == independent_variable_symbols(sc)
25+
@test all_variable_symbols(sys) == variable_symbols(sc)
26+
@test all_symbols(sys) == all_symbols(sc)

test/symbol_cache_test.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ sc = SymbolCache([:x, :y, :z], [:a, :b], [:t])
1414
@test variable_symbols(sc) == [:x, :y, :z]
1515
@test parameter_symbols(sc) == [:a, :b]
1616
@test independent_variable_symbols(sc) == [:t]
17-
@test all_solvable_symbols(sc) == [:x, :y, :z]
17+
@test all_variable_symbols(sc) == [:x, :y, :z]
1818
@test sort(all_symbols(sc)) == [:a, :b, :t, :x, :y, :z]
1919

2020
sc = SymbolCache([:x, :y], [:a, :b])
@@ -33,15 +33,15 @@ sc = SymbolCache()
3333
@test all(.!is_independent_variable.((sc,), [:x, :y, :a, :b, :t]))
3434
@test independent_variable_symbols(sc) == []
3535
@test !is_time_dependent(sc)
36-
@test all_solvable_symbols(sc) == []
36+
@test all_variable_symbols(sc) == []
3737
@test all_symbols(sc) == []
3838

3939
sc = SymbolCache(nothing, nothing, :t)
4040
@test all(.!is_independent_variable.((sc,), [:x, :y, :a, :b]))
4141
@test is_independent_variable(sc, :t)
4242
@test independent_variable_symbols(sc) == [:t]
4343
@test is_time_dependent(sc)
44-
@test all_solvable_symbols(sc) == []
44+
@test all_variable_symbols(sc) == []
4545
@test all_symbols(sc) == [:t]
4646

4747
sc2 = copy(sc)

0 commit comments

Comments
 (0)