Skip to content

Commit 8fcceb2

Browse files
feat: add getu/setu, update docs, add tests
1 parent 06f2cb2 commit 8fcceb2

File tree

12 files changed

+256
-20
lines changed

12 files changed

+256
-20
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@ version = "0.3.1"
55

66
[compat]
77
Aqua = "0.8"
8+
SafeTestsets = "0.0.1"
89
Test = "1"
910
julia = "1.10"
1011

1112
[extras]
1213
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
14+
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
1315
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1416

1517
[targets]
16-
test = ["Aqua", "Test"]
18+
test = ["Aqua", "Test", "SafeTestsets"]

docs/src/api.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,13 @@ all_variable_symbols
1818
all_symbols
1919
solvedvariables
2020
allvariables
21+
state_values
2122
parameter_values
23+
current_time
2224
getp
2325
setp
26+
getu
27+
setu
2428
```
2529

2630
# Traits

docs/src/complete_sii.md

Lines changed: 73 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -85,22 +85,41 @@ These are for handling symbolic expressions and generating equations which are n
8585
in the solution vector.
8686

8787
```julia
88+
using RuntimeGeneratedFunctions
89+
RuntimeGeneratedFunctions.init(@__MODULE__)
90+
8891
# this type accepts `Expr` for observed expressions involving state/parameter/observed
8992
# variables
9093
SymbolicIndexingInterface.is_observed(sys::ExampleSystem, sym) = sym isa Expr || sym isa Symbol && haskey(sys.observed, sym)
9194

9295
function SymbolicIndexingInterface.observed(sys::ExampleSystem, sym::Expr)
96+
# generate a function with the appropriate signature
9397
if is_time_dependent(sys)
94-
return function (u, p, t)
95-
# compute value from `sym`, leveraging `variable_index` and
96-
# `parameter_index` to turn symbols into indices
97-
end
98+
fn_expr = :(
99+
function gen(u, p, t)
100+
# assign a variable for each state symbol it's value in u
101+
$([:($var = u[$idx]) for (var, idx) in pairs(sys.state_index)]...)
102+
# assign a variable for each parameter symbol it's value in p
103+
$([:($var = p[$idx]) for (var, idx) in pairs(sys.parameter_index)]...)
104+
# assign a variable for the independent variable
105+
$(sys.independent_variable) = t
106+
# return the value of the expression
107+
return $sym
108+
end
109+
)
98110
else
99-
return function (u, p)
100-
# compute value from `sym`, leveraging `variable_index` and
101-
# `parameter_index` to turn symbols into indices
102-
end
111+
fn_expr = :(
112+
function gen(u, p)
113+
# assign a variable for each state symbol it's value in u
114+
$([:($var = u[$idx]) for (var, idx) in pairs(sys.state_index)]...)
115+
# assign a variable for each parameter symbol it's value in p
116+
$([:($var = p[$idx]) for (var, idx) in pairs(sys.parameter_index)]...)
117+
# return the value of the expression
118+
return $sym
119+
end
120+
)
103121
end
122+
return @RuntimeGeneratedFunction(fn_expr)
104123
end
105124
```
106125

@@ -127,7 +146,7 @@ only responsible for identifying observed values and `observed` will always be c
127146
on a type that wraps this type. An example is `ModelingToolkit.AbstractSystem`, which
128147
can identify whether a value is observed, but cannot implement `observed` itself.
129148

130-
Other optional methods relate to parameter indexing. If a type contains the values of
149+
Other optional methods relate to indexing functions. If a type contains the values of
131150
parameter variables, it must implement [`parameter_values`](@ref). This allows the
132151
default definitions of [`getp`](@ref) and [`setp`](@ref) to work. While `setp` is
133152
not typically useful for solution objects, it may be useful for integrators. Typically,
@@ -140,7 +159,51 @@ function SymbolicIndexingInterface.parameter_values(sys::ExampleSystem)
140159
end
141160
```
142161

143-
## Implementing the `SymbolicTypeTrait` for a type
162+
If a type contains the value of state variables, it can define [`state_values`](@ref) to
163+
enable the usage of [`getu`](@ref) and [`setu`](@ref). These methods retturn getter/
164+
setter functions to access or update the value of a state variable (or a collection of
165+
them). If the type also supports generating [`observed`](@ref) functions, `getu` also
166+
enables returning functions to access the value of arbitrary expressions involving
167+
the system's symbols. This also requires that the type implement
168+
[`parameter_values`](@ref) and [`current_time`](@ref) (if the system is time-dependent).
169+
170+
Consider the following `ExampleIntegrator`
171+
172+
```julia
173+
mutable struct ExampleIntegrator
174+
u::Vector{Float64}
175+
p::Vector{Float64}
176+
t::Float64
177+
state_index::Dict{Symbol,Int}
178+
parameter_index::Dict{Symbol,Int}
179+
independent_variable::Symbol
180+
end
181+
```
182+
183+
Assume that it implements the mandatory part of the interface as described above, and
184+
the following methods below:
185+
186+
```julia
187+
SymbolicIndexingInterface.state_values(sys::ExampleIntegrator) = sys.u
188+
SymbolicIndexingInterface.parameter_values(sys::ExampleIntegrator) = sys.p
189+
SymbolicIndexingInterface.current_time(sys::ExampleIntegrator) = sys.t
190+
```
191+
192+
Then the following example would work:
193+
```julia
194+
integrator = ExampleIntegrator([1.0, 2.0, 3.0], [4.0, 5.0], 6.0, Dict(:x => 1, :y => 2, :z => 3), Dict(:a => 1, :b => 2), :t)
195+
getx = getu(integrator, :x)
196+
getx(integrator) # 1.0
197+
198+
get_expr = getu(integrator, :(x + y + t))
199+
get_expr(integrator) # 13.0
200+
201+
setx! = setu(integrator, :y)
202+
setx!(integrator, 0.0)
203+
getx(integrator) # 0.0
204+
```
205+
206+
# Implementing the `SymbolicTypeTrait` for a type
144207

145208
The `SymbolicTypeTrait` is used to identify values that can act as symbolic variables. It
146209
has three variants:

src/SymbolicIndexingInterface.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,7 @@ include("symbol_cache.jl")
1515
export parameter_values, getp, setp
1616
include("parameter_indexing.jl")
1717

18+
export state_values, current_time, getu, setu
19+
include("state_indexing.jl")
20+
1821
end

src/state_indexing.jl

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
"""
2+
state_values(p)
3+
4+
Return an indexable collection containing the values of all states in the integrator or
5+
problem `p`.
6+
"""
7+
function state_values end
8+
9+
"""
10+
current_time(p)
11+
12+
Return the current time in the integrator or problem `p`.
13+
"""
14+
function current_time end
15+
16+
"""
17+
getu(sys, sym)
18+
19+
Return a function that takes an integrator or problem of `sys`, and returns the value of
20+
the symbolic `sym`. `sym` can be a direct index into the state vector, a symbolic state,
21+
a symbolic expression involving symbolic quantities in the system `sys`, or an
22+
array/tuple of the aforementioned.
23+
24+
At minimum, this requires that the integrator or problem implement [`state_values`](@ref).
25+
To support symbolic expressions, the integrator or problem must implement
26+
[`observed`](@ref), [`parameter_values`](@ref) and [`current_time`](@ref).
27+
28+
This function typically does not need to be implemented, and has a default implementation
29+
relying on the above functions.
30+
"""
31+
function getu(sys, sym)
32+
symtype = symbolic_type(sym)
33+
elsymtype = symbolic_type(eltype(sym))
34+
35+
if symtype != NotSymbolic()
36+
_getu(sys, symtype, sym)
37+
else
38+
_getu(sys, elsymtype, sym)
39+
end
40+
end
41+
42+
function _getu(sys, ::NotSymbolic, sym)
43+
return function getter(prob)
44+
return state_values(prob)[sym]
45+
end
46+
end
47+
48+
function _getu(sys, ::ScalarSymbolic, sym)
49+
if is_variable(sys, sym)
50+
idx = variable_index(sys, sym)
51+
return function getter1(prob)
52+
return state_values(prob)[idx]
53+
end
54+
elseif is_observed(sys, sym)
55+
fn = observed(sys, sym)
56+
if is_time_dependent(sys)
57+
function getter2(prob)
58+
return fn(state_values(prob), parameter_values(prob), current_time(prob))
59+
end
60+
else
61+
function getter3(prob)
62+
return fn(state_values(prob), parameter_values(prob))
63+
end
64+
end
65+
end
66+
error("Invalid symbol $sym for `getu`")
67+
end
68+
69+
function _getu(sys, ::ScalarSymbolic, sym::Union{<:Tuple,<:AbstractArray})
70+
getters = getu.((sys,), sym)
71+
_call(getter, prob) = getter(prob)
72+
return function getter(prob)
73+
return _call.(getters, (prob,))
74+
end
75+
end
76+
77+
function _getu(sys, ::ArraySymbolic, sym)
78+
return getu(sys, collect(sym))
79+
end
80+
81+
"""
82+
setu(sys, sym)
83+
84+
Return a function that takes an integrator or problem of `sys` and a value, and sets the
85+
the state `sym` to that value. Note that `sym` can be a direct numerical index, a symbolic state, or an array/tuple of the aforementioned.
86+
87+
Requires that the integrator implement [`state_values`](@ref) and the
88+
returned collection be a mutable reference to the state vector in the integrator/problem.
89+
In case `state_values` cannot return such a mutable reference, `setu` needs to be
90+
implemented manually.
91+
"""
92+
function setu(sys, sym)
93+
symtype = symbolic_type(sym)
94+
elsymtype = symbolic_type(eltype(sym))
95+
96+
if symtype != NotSymbolic()
97+
_setu(sys, symtype, sym)
98+
else
99+
_setu(sys, elsymtype, sym)
100+
end
101+
end
102+
103+
function _setu(sys, ::NotSymbolic, sym)
104+
return function setter!(prob, val)
105+
state_values(prob)[sym] = val
106+
end
107+
end
108+
109+
function _setu(sys, ::ScalarSymbolic, sym)
110+
is_variable(sys, sym) || error("Invalid symbol $sym for `setu`")
111+
idx = variable_index(sys, sym)
112+
return function setter!(prob, val)
113+
state_values(prob)[idx] = val
114+
end
115+
end
116+
117+
function _setu(sys, ::ScalarSymbolic, sym::Union{<:Tuple,<:AbstractArray})
118+
setters = setu.((sys,), sym)
119+
_call!(setter!, prob, val) = setter!(prob, val)
120+
return function setter!(prob, val)
121+
_call!.(setters, (prob,), val)
122+
end
123+
end
124+
125+
function _setu(sys, ::ArraySymbolic, sym)
126+
return setu(sys, collect(sym))
127+
end

src/trait.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,14 @@ struct NotSymbolic <: SymbolicTypeTrait end
3535
symbolic_type(::Type)
3636
3737
Get the symbolic type trait of a type. Default to [`NotSymbolic`](@ref) for all types
38-
except `Symbol`.
38+
except `Symbol` and `Expr`, both of which are [`ScalarSymbolic`](@ref).
3939
4040
See also: [`ScalarSymbolic`](@ref), [`ArraySymbolic`](@ref), [`NotSymbolic`](@ref)
4141
"""
4242
symbolic_type(x) = symbolic_type(typeof(x))
4343
symbolic_type(::Type) = NotSymbolic()
4444
symbolic_type(::Type{Symbol}) = ScalarSymbolic()
45+
symbolic_type(::Type{Expr}) = ScalarSymbolic()
4546

4647
"""
4748
hasname(x)

test/example_test.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
using SymbolicIndexingInterface
2+
using Test
3+
14
struct SystemMockup
25
static::Bool
36
vars::Vector{Symbol}

test/fallback_test.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using SymbolicIndexingInterface
2+
using Test
23

34
struct Wrapper{W}
45
wrapped::W

test/parameter_indexing_test.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
using SymbolicIndexingInterface
2+
using Test
23

3-
struct FakeIntegrator{P}
4+
struct FakeIntegrator{S,P}
5+
sys::S
46
p::P
57
end
68

@@ -9,7 +11,7 @@ SymbolicIndexingInterface.parameter_values(fp::FakeIntegrator) = fp.p
911

1012
sys = SymbolCache([:x, :y, :z], [:a, :b], [:t])
1113
p = [1.0, 2.0]
12-
fi = FakeIntegrator(copy(p))
14+
fi = FakeIntegrator(sys, copy(p))
1315
for (i, sym) in [(1, :a), (2, :b), ([1, 2], [:a, :b]), ((1, 2), (:a, :b))]
1416
get = getp(sys, sym)
1517
set! = setp(sys, sym)

test/runtests.jl

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,25 @@
11
using SymbolicIndexingInterface
2+
using SafeTestsets
23
using Test
3-
@testset "Quality Assurance" begin
4+
5+
@safetestset "Quality Assurance" begin
46
@time include("qa.jl")
57
end
6-
@testset "Interface test" begin
8+
@safetestset "Interface test" begin
79
@time include("example_test.jl")
810
end
9-
@testset "Trait test" begin
11+
@safetestset "Trait test" begin
1012
@time include("trait_test.jl")
1113
end
12-
@testset "SymbolCache test" begin
14+
@safetestset "SymbolCache test" begin
1315
@time include("symbol_cache_test.jl")
1416
end
15-
@testset "Fallback test" begin
17+
@safetestset "Fallback test" begin
1618
@time include("fallback_test.jl")
1719
end
18-
@testset "Parameter indexing test" begin
20+
@safetestset "Parameter indexing test" begin
1921
@time include("parameter_indexing_test.jl")
2022
end
23+
@safetestset "State indexing test" begin
24+
@time include("state_indexing_test.jl")
25+
end

0 commit comments

Comments
 (0)