Skip to content

Commit 4da53e6

Browse files
Merge pull request #25 from SciML/as/getu-sol
feat: add `IsTimeseriesTrait`, support timeseries objects in `getu`
2 parents 9c5e3b5 + 29d692b commit 4da53e6

File tree

10 files changed

+272
-51
lines changed

10 files changed

+272
-51
lines changed

.github/workflows/Downstream.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ jobs:
1919
package:
2020
- {user: SciML, repo: RecursiveArrayTools.jl, group: All}
2121
- {user: JuliaSymbolics, repo: Symbolics.jl, group: SymbolicIndexingInterface}
22+
- {user: SciML, repo: SciMLBase.jl, group: SymbolicIndexingInterface}
2223
steps:
2324
- uses: actions/checkout@v4
2425
- uses: julia-actions/setup-julia@v1

docs/src/api.md

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# Interface Functions
22

3+
## Mandatory methods
4+
35
```@docs
46
symbolic_container
57
is_variable
@@ -11,23 +13,45 @@ parameter_symbols
1113
is_independent_variable
1214
independent_variable_symbols
1315
is_observed
14-
observed
1516
is_time_dependent
1617
constant_structure
1718
all_variable_symbols
1819
all_symbols
1920
solvedvariables
2021
allvariables
21-
state_values
22+
```
23+
24+
## Optional Methods
25+
26+
### Observed equation handling
27+
28+
```@docs
29+
observed
30+
```
31+
32+
### Parameter indexing
33+
34+
```@docs
2235
parameter_values
23-
current_time
36+
set_parameter!
2437
getp
2538
setp
39+
```
40+
41+
### State indexing
42+
43+
```@docs
44+
Timeseries
45+
NotTimeseries
46+
is_timeseries
47+
state_values
48+
set_state!
49+
current_time
2650
getu
2751
setu
2852
```
2953

30-
# Traits
54+
# Symbolic Trait
3155

3256
```@docs
3357
ScalarSymbolic

docs/src/complete_sii.md

Lines changed: 71 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,9 @@ function SymbolicIndexingInterface.observed(sys::ExampleSystem, sym::Expr)
123123
end
124124
```
125125

126+
In case a type does not support such observed quantities, `is_observed` must be
127+
defined to always return `false`, and `observed` does not need to be implemented.
128+
126129
### Note about constant structure
127130

128131
Note that the method definitions are all assuming `constant_structure(p) == true`.
@@ -174,35 +177,91 @@ mutable struct ExampleIntegrator
174177
u::Vector{Float64}
175178
p::Vector{Float64}
176179
t::Float64
177-
state_index::Dict{Symbol,Int}
178-
parameter_index::Dict{Symbol,Int}
179-
independent_variable::Symbol
180+
sys::ExampleSystem
180181
end
181-
```
182182

183-
Assume that it implements the mandatory part of the interface as described above, and
184-
the following methods below:
185-
186-
```julia
183+
# define a fallback for the interface methods
184+
SymbolicIndexingInterface.symbolic_container(integ::ExampleIntegrator) = integ.sys
187185
SymbolicIndexingInterface.state_values(sys::ExampleIntegrator) = sys.u
188186
SymbolicIndexingInterface.parameter_values(sys::ExampleIntegrator) = sys.p
189187
SymbolicIndexingInterface.current_time(sys::ExampleIntegrator) = sys.t
190188
```
191189

192190
Then the following example would work:
193191
```julia
194-
integrator = ExampleIntegrator([1.0, 2.0, 3.0], [4.0, 5.0], 6.0, Dict(:x => 1, :y => 2, :z => 3), Dict(:a => 1, :b => 2), :t)
195-
getx = getu(integrator, :x)
192+
sys = ExampleSystem(Dict(:x => 1, :y => 2, :z => 3), Dict(:a => 1, :b => 2), :t, Dict())
193+
integrator = ExampleIntegrator([1.0, 2.0, 3.0], [4.0, 5.0], 6.0, sys)
194+
getx = getu(sys, :x)
196195
getx(integrator) # 1.0
197196

198-
get_expr = getu(integrator, :(x + y + t))
197+
get_expr = getu(sys, :(x + y + t))
199198
get_expr(integrator) # 13.0
200199

201-
setx! = setu(integrator, :y)
200+
setx! = setu(sys, :y)
202201
setx!(integrator, 0.0)
203202
getx(integrator) # 0.0
204203
```
205204

205+
In case a type stores timeseries data (such as solutions), then it must also implement
206+
the [`Timeseries`](@ref) trait. The type would then return a timeseries from
207+
[`state_values`](@ref) and [`current_time`](@ref) and the function returned from
208+
[`getu`](@ref) would then return a timeseries as well. For example, consider the
209+
`ExampleSolution` below:
210+
211+
```julia
212+
struct ExampleSolution
213+
u::Vector{Vector{Float64}}
214+
t::Vector{Float64}
215+
p::Vector{Float64}
216+
sys::ExampleSystem
217+
end
218+
219+
# define a fallback for the interface methods
220+
SymbolicIndexingInterface.symbolic_container(sol::ExampleSolution) = sol.sys
221+
SymbolicIndexingInterface.parameter_values(sol::ExampleSolution) = sol.p
222+
# define the trait
223+
SymbolicIndexingInterface.is_timeseries(::Type{ExampleSolution}) = Timeseries()
224+
# both state_values and current_time return a timeseries, which must be
225+
# the same length
226+
SymbolicIndexingInterface.state_values(sol::ExampleSolution) = sol.u
227+
SymbolicIndexingInterface.current_time(sol::ExampleSolution) = sol.t
228+
```
229+
230+
Then the following example would work:
231+
```julia
232+
# using the same system that the ExampleIntegrator used
233+
sol = ExampleSolution([[1.0, 2.0, 3.0], [1.5, 2.5, 3.5]], [4.0, 5.0], [6.0, 7.0], sys)
234+
getx = getu(sys, :x)
235+
getx(sol) # [1.0, 1.5]
236+
237+
get_expr = getu(sys, :(x + y + t))
238+
get_expr(sol) # [9.0, 11.0]
239+
240+
get_arr = getu(sys, [:y, :(x + a)])
241+
get_arr(sol) # [[2.0, 5.0], [2.5, 5.5]]
242+
243+
get_tuple = getu(sys, (:z, :(z * t)))
244+
get_tuple(sol) # [(3.0, 18.0), (3.5, 24.5)]
245+
```
246+
247+
Note that `setu` is not designed to work for `Timeseries` objects.
248+
249+
If a type needs to perform some additional actions when updating the state/parameters
250+
or if it is not possible to return a mutable reference to the state/parameter vector
251+
which can directly be modified, the functions [`set_state!`](@ref) and/or
252+
[`set_parameter!`](@ref) can be used. For example, suppose our `ExampleIntegrator`
253+
had an additional field `u_modified::Bool` to allow it to keep track of when a
254+
discontinuity occurs and handle it appropriately. This flag needs to be set to `true`
255+
whenever the state is modified. The `set_state!` function can then be implemented as
256+
follows:
257+
258+
```julia
259+
function SymbolicIndexingInterface.set_state!(integrator::ExampleIntegrator, val, idx)
260+
integrator.u[idx] = val
261+
integrator.u_modified = true
262+
end
263+
```
264+
206265
# Implementing the `SymbolicTypeTrait` for a type
207266

208267
The `SymbolicTypeTrait` is used to identify values that can act as symbolic variables. It

src/SymbolicIndexingInterface.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,19 @@ include("trait.jl")
55

66
export is_variable, variable_index, variable_symbols, is_parameter, parameter_index,
77
parameter_symbols, is_independent_variable, independent_variable_symbols, is_observed,
8-
observed, is_time_dependent, constant_structure, symbolic_container, all_variable_symbols,
8+
observed, is_time_dependent, constant_structure, symbolic_container,
9+
all_variable_symbols,
910
all_symbols, solvedvariables, allvariables
1011
include("interface.jl")
1112

1213
export SymbolCache
1314
include("symbol_cache.jl")
1415

15-
export parameter_values, getp, setp
16+
export parameter_values, set_parameter!, getp, setp
1617
include("parameter_indexing.jl")
1718

18-
export state_values, current_time, getu, setu
19+
export Timeseries,
20+
NotTimeseries, is_timeseries, state_values, set_state!, current_time, getu, setu
1921
include("state_indexing.jl")
2022

2123
end

src/interface.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,9 @@ have the signature `(u, p) -> [values...]` where `u` and `p` is the current stat
9191
parameter vector, respectively. If `istimedependent(sys) == true`, the function should accept
9292
the current time `t` as its third parameter. If `constant_structure(sys) == false`,
9393
accept a third parameter, which can either be a vector of symbols indicating the order
94-
of states or a time index, which identifies the order of states.
94+
of states or a time index, which identifies the order of states. This function
95+
does not need to be defined if [`is_observed`](@ref) always returns `false`. Thus,
96+
it is mandatory to always check `is_observed` before using this function.
9597
9698
See also: [`is_time_dependent`](@ref), [`constant_structure`](@ref)
9799
"""

src/parameter_indexing.jl

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,20 @@ Return an indexable collection containing the value of each parameter in `p`.
55
"""
66
function parameter_values end
77

8+
"""
9+
set_parameter!(sys, val, idx)
10+
11+
Set the parameter at index `idx` to `val` for system `sys`. This defaults to modifying
12+
`parameter_values(sys)`. If any additional bookkeeping needs to be performed or the
13+
default implementation does not work for a particular type, this method needs to be
14+
defined to enable the proper functioning of [`setp`](@ref).
15+
16+
See: [`parameter_values`](@ref)
17+
"""
18+
function set_parameter!(sys, val, idx)
19+
parameter_values(sys)[idx] = val
20+
end
21+
822
"""
923
getp(sys, p)
1024
@@ -55,8 +69,9 @@ Return a function that takes an integrator of `sys` and a value, and sets
5569
the parameter `p` to that value. Note that `p` can be a direct numerical index or a
5670
symbolic value. Requires that the integrator implement [`parameter_values`](@ref) and the
5771
returned collection be a mutable reference to the parameter vector in the integrator. In
58-
case `parameter_values` cannot return such a mutable reference, `setp` needs to be
59-
implemented manually.
72+
case `parameter_values` cannot return such a mutable reference, or additional actions
73+
need to be performed when updating parameters, [`set_parameter!`](@ref) must be
74+
implemented.
6075
"""
6176
function setp(sys, p)
6277
symtype = symbolic_type(p)
@@ -70,21 +85,21 @@ end
7085

7186
function _setp(sys, ::NotSymbolic, p)
7287
return function setter!(sol, val)
73-
parameter_values(sol)[p] = val
88+
set_parameter!(sol, val, p)
7489
end
7590
end
7691

7792
function _setp(sys, ::ScalarSymbolic, p)
7893
idx = parameter_index(sys, p)
7994
return function setter!(sol, val)
80-
parameter_values(sol)[idx] = val
95+
set_parameter!(sol, val, idx)
8196
end
8297
end
8398

8499
function _setp(sys, ::ScalarSymbolic, p::Union{Tuple, AbstractArray})
85100
idxs = parameter_index.((sys,), p)
86101
return function setter!(sol, val)
87-
setindex!.((parameter_values(sol),), val, idxs)
102+
set_parameter!.((sol,), val, idxs)
88103
end
89104
end
90105

0 commit comments

Comments
 (0)