Skip to content

Commit 7a8e647

Browse files
feat: add set_state! and set_parameter!, update docs
1 parent 0166940 commit 7a8e647

File tree

6 files changed

+141
-31
lines changed

6 files changed

+141
-31
lines changed

docs/src/api.md

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# Interface Functions
22

3+
## Mandatory methods
4+
35
```@docs
46
symbolic_container
57
is_variable
@@ -11,26 +13,45 @@ parameter_symbols
1113
is_independent_variable
1214
independent_variable_symbols
1315
is_observed
14-
observed
1516
is_time_dependent
1617
constant_structure
1718
all_variable_symbols
1819
all_symbols
1920
solvedvariables
2021
allvariables
22+
```
23+
24+
## Optional Methods
25+
26+
### Observed equation handling
27+
28+
```@docs
29+
observed
30+
```
31+
32+
### Parameter indexing
33+
34+
```@docs
35+
parameter_values
36+
set_parameter!
37+
getp
38+
setp
39+
```
40+
41+
### State indexing
42+
43+
```@docs
2144
Timeseries
2245
NotTimeseries
2346
is_timeseries
2447
state_values
25-
parameter_values
48+
set_state!
2649
current_time
27-
getp
28-
setp
2950
getu
3051
setu
3152
```
3253

33-
# Traits
54+
# Symbolic Trait
3455

3556
```@docs
3657
ScalarSymbolic

docs/src/complete_sii.md

Lines changed: 71 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,9 @@ function SymbolicIndexingInterface.observed(sys::ExampleSystem, sym::Expr)
123123
end
124124
```
125125

126+
In case a type does not support such observed quantities, `is_observed` must be
127+
defined to always return `false`, and `observed` does not need to be implemented.
128+
126129
### Note about constant structure
127130

128131
Note that the method definitions are all assuming `constant_structure(p) == true`.
@@ -174,35 +177,91 @@ mutable struct ExampleIntegrator
174177
u::Vector{Float64}
175178
p::Vector{Float64}
176179
t::Float64
177-
state_index::Dict{Symbol,Int}
178-
parameter_index::Dict{Symbol,Int}
179-
independent_variable::Symbol
180+
sys::ExampleSystem
180181
end
181-
```
182182

183-
Assume that it implements the mandatory part of the interface as described above, and
184-
the following methods below:
185-
186-
```julia
183+
# define a fallback for the interface methods
184+
SymbolicIndexingInterface.symbolic_container(integ::ExampleIntegrator) = integ.sys
187185
SymbolicIndexingInterface.state_values(sys::ExampleIntegrator) = sys.u
188186
SymbolicIndexingInterface.parameter_values(sys::ExampleIntegrator) = sys.p
189187
SymbolicIndexingInterface.current_time(sys::ExampleIntegrator) = sys.t
190188
```
191189

192190
Then the following example would work:
193191
```julia
194-
integrator = ExampleIntegrator([1.0, 2.0, 3.0], [4.0, 5.0], 6.0, Dict(:x => 1, :y => 2, :z => 3), Dict(:a => 1, :b => 2), :t)
195-
getx = getu(integrator, :x)
192+
sys = ExampleSystem(Dict(:x => 1, :y => 2, :z => 3), Dict(:a => 1, :b => 2), :t, Dict())
193+
integrator = ExampleIntegrator([1.0, 2.0, 3.0], [4.0, 5.0], 6.0, sys)
194+
getx = getu(sys, :x)
196195
getx(integrator) # 1.0
197196

198-
get_expr = getu(integrator, :(x + y + t))
197+
get_expr = getu(sys, :(x + y + t))
199198
get_expr(integrator) # 13.0
200199

201-
setx! = setu(integrator, :y)
200+
setx! = setu(sys, :y)
202201
setx!(integrator, 0.0)
203202
getx(integrator) # 0.0
204203
```
205204

205+
In case a type stores timeseries data (such as solutions), then it must also implement
206+
the [`Timeseries`](@ref) trait. The type would then return a timeseries from
207+
[`state_values`](@ref) and [`current_time`](@ref) and the function returned from
208+
[`getu`](@ref) would then return a timeseries as well. For example, consider the
209+
`ExampleSolution` below:
210+
211+
```julia
212+
struct ExampleSolution
213+
u::Vector{Vector{Float64}}
214+
t::Vector{Float64}
215+
p::Vector{Float64}
216+
sys::ExampleSystem
217+
end
218+
219+
# define a fallback for the interface methods
220+
SymbolicIndexingInterface.symbolic_container(sol::ExampleSolution) = sol.sys
221+
SymbolicIndexingInterface.parameter_values(sol::ExampleSolution) = sol.p
222+
# define the trait
223+
SymbolicIndexingInterface.is_timeseries(::Type{ExampleSolution}) = Timeseries()
224+
# both state_values and current_time return a timeseries, which must be
225+
# the same length
226+
SymbolicIndexingInterface.state_values(sol::ExampleSolution) = sol.u
227+
SymbolicIndexingInterface.current_time(sol::ExampleSolution) = sol.t
228+
```
229+
230+
Then the following example would work:
231+
```julia
232+
# using the same system that the ExampleIntegrator used
233+
sol = ExampleSolution([[1.0, 2.0, 3.0], [1.5, 2.5, 3.5]], [4.0, 5.0], [6.0, 7.0], sys)
234+
getx = getu(sys, :x)
235+
getx(sol) # [1.0, 1.5]
236+
237+
get_expr = getu(sys, :(x + y + t))
238+
get_expr(sol) # [9.0, 11.0]
239+
240+
get_arr = getu(sys, [:y, :(x + a)])
241+
get_arr(sol) # [[2.0, 5.0], [2.5, 5.5]]
242+
243+
get_tuple = getu(sys, (:z, :(z * t)))
244+
get_tuple(sol) # [(3.0, 18.0), (3.5, 24.5)]
245+
```
246+
247+
Note that `setu` is not designed to work for `Timeseries` objects.
248+
249+
If a type needs to perform some additional actions when updating the state/parameters
250+
or if it is not possible to return a mutable reference to the state/parameter vector
251+
which can directly be modified, the functions [`set_state!`](@ref) and/or
252+
[`set_parameter!`](@ref) can be used. For example, suppose our `ExampleIntegrator`
253+
had an additional field `u_modified::Bool` to allow it to keep track of when a
254+
discontinuity occurs and handle it appropriately. This flag needs to be set to `true`
255+
whenever the state is modified. The `set_state!` function can then be implemented as
256+
follows:
257+
258+
```julia
259+
function SymbolicIndexingInterface.set_state!(integrator::ExampleIntegrator, val, idx)
260+
integrator.u[idx] = val
261+
integrator.u_modified = true
262+
end
263+
```
264+
206265
# Implementing the `SymbolicTypeTrait` for a type
207266

208267
The `SymbolicTypeTrait` is used to identify values that can act as symbolic variables. It

src/SymbolicIndexingInterface.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,11 @@ include("interface.jl")
1313
export SymbolCache
1414
include("symbol_cache.jl")
1515

16-
export parameter_values, getp, setp
16+
export parameter_values, set_parameter!, getp, setp
1717
include("parameter_indexing.jl")
1818

19-
export Timeseries, NotTimeseries, is_timeseries, state_values, current_time, getu, setu
19+
export Timeseries,
20+
NotTimeseries, is_timeseries, state_values, set_state!, current_time, getu, setu
2021
include("state_indexing.jl")
2122

2223
end

src/interface.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,9 @@ have the signature `(u, p) -> [values...]` where `u` and `p` is the current stat
9191
parameter vector, respectively. If `istimedependent(sys) == true`, the function should accept
9292
the current time `t` as its third parameter. If `constant_structure(sys) == false`,
9393
accept a third parameter, which can either be a vector of symbols indicating the order
94-
of states or a time index, which identifies the order of states.
94+
of states or a time index, which identifies the order of states. This function
95+
does not need to be defined if [`is_observed`](@ref) always returns `false`. Thus,
96+
it is mandatory to always check `is_observed` before using this function.
9597
9698
See also: [`is_time_dependent`](@ref), [`constant_structure`](@ref)
9799
"""

src/parameter_indexing.jl

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,20 @@ Return an indexable collection containing the value of each parameter in `p`.
55
"""
66
function parameter_values end
77

8+
"""
9+
set_parameter!(sys, val, idx)
10+
11+
Set the parameter at index `idx` to `val` for system `sys`. This defaults to modifying
12+
`parameter_values(sys)`. If any additional bookkeeping needs to be performed or the
13+
default implementation does not work for a particular type, this method needs to be
14+
defined to enable the proper functioning of [`setp`](@ref).
15+
16+
See: [`parameter_values`](@ref)
17+
"""
18+
function set_parameter!(sys, val, idx)
19+
parameter_values(sys)[idx] = val
20+
end
21+
822
"""
923
getp(sys, p)
1024
@@ -55,8 +69,9 @@ Return a function that takes an integrator of `sys` and a value, and sets
5569
the parameter `p` to that value. Note that `p` can be a direct numerical index or a
5670
symbolic value. Requires that the integrator implement [`parameter_values`](@ref) and the
5771
returned collection be a mutable reference to the parameter vector in the integrator. In
58-
case `parameter_values` cannot return such a mutable reference, `setp` needs to be
59-
implemented manually.
72+
case `parameter_values` cannot return such a mutable reference, or additional actions
73+
need to be performed when updating parameters, [`set_parameter!`](@ref) must be
74+
implemented.
6075
"""
6176
function setp(sys, p)
6277
symtype = symbolic_type(p)
@@ -70,21 +85,21 @@ end
7085

7186
function _setp(sys, ::NotSymbolic, p)
7287
return function setter!(sol, val)
73-
parameter_values(sol)[p] = val
88+
set_parameter!(sol, val, p)
7489
end
7590
end
7691

7792
function _setp(sys, ::ScalarSymbolic, p)
7893
idx = parameter_index(sys, p)
7994
return function setter!(sol, val)
80-
parameter_values(sol)[idx] = val
95+
set_parameter!(sol, val, idx)
8196
end
8297
end
8398

8499
function _setp(sys, ::ScalarSymbolic, p::Union{Tuple, AbstractArray})
85100
idxs = parameter_index.((sys,), p)
86101
return function setter!(sol, val)
87-
setindex!.((parameter_values(sol),), val, idxs)
102+
set_parameter!.((sol,), val, idxs)
88103
end
89104
end
90105

src/state_indexing.jl

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,20 @@ See: [`is_timeseries`](@ref)
4848
"""
4949
function state_values end
5050

51+
"""
52+
set_state!(sys, val, idx)
53+
54+
Set the state at index `idx` to `val` for system `sys`. This defaults to modifying
55+
`state_values(sys)`. If any additional bookkeeping needs to be performed or the
56+
default implementation does not work for a particular type, this method needs to be
57+
defined to enable the proper functioning of [`setu`](@ref).
58+
59+
See: [`state_values`](@ref)
60+
"""
61+
function set_state!(sys, val, idx)
62+
state_values(sys)[idx] = val
63+
end
64+
5165
"""
5266
current_time(p)
5367
@@ -164,12 +178,10 @@ Return a function that takes an integrator or problem of `sys` and a value, and
164178
the state `sym` to that value. Note that `sym` can be a direct numerical index, a symbolic state, or an array/tuple of the aforementioned.
165179
166180
Requires that the integrator implement [`state_values`](@ref) and the
167-
returned collection be a mutable reference to the state vector in the integrator/problem.
181+
returned collection be a mutable reference to the state vector in the integrator/problem. Alternatively, if this is not possible or additional actions need to
182+
be performed when updating state, [`set_state!`](@ref) can be defined.
168183
This function does not work on types for which [`is_timeseries`](@ref) is
169184
[`Timeseries`](@ref).
170-
171-
In case `state_values` cannot return such a mutable reference, `setu` needs to be
172-
implemented manually.
173185
"""
174186
function setu(sys, sym)
175187
symtype = symbolic_type(sym)
@@ -184,15 +196,15 @@ end
184196

185197
function _setu(sys, ::NotSymbolic, sym)
186198
return function setter!(prob, val)
187-
state_values(prob)[sym] = val
199+
set_state!(prob, val, sym)
188200
end
189201
end
190202

191203
function _setu(sys, ::ScalarSymbolic, sym)
192204
is_variable(sys, sym) || error("Invalid symbol $sym for `setu`")
193205
idx = variable_index(sys, sym)
194206
return function setter!(prob, val)
195-
state_values(prob)[idx] = val
207+
set_state!(prob, val, idx)
196208
end
197209
end
198210

0 commit comments

Comments
 (0)