Skip to content

Commit da7aa70

Browse files
Merge pull request #64 from SciML/as/problem-state
feat: add `ProblemState`
2 parents 494548d + c1b60a6 commit da7aa70

File tree

7 files changed

+72
-2
lines changed

7 files changed

+72
-2
lines changed

.github/workflows/Downstream.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ jobs:
2020
- {user: SciML, repo: RecursiveArrayTools.jl, group: SymbolicIndexingInterface}
2121
- {user: JuliaSymbolics, repo: Symbolics.jl, group: SymbolicIndexingInterface}
2222
- {user: SciML, repo: SciMLBase.jl, group: SymbolicIndexingInterface}
23+
- {user: SciML, repo: ModelingToolkit.jl, group: SymbolicIndexingInterface}
2324
steps:
2425
- uses: actions/checkout@v4
2526
- uses: julia-actions/setup-julia@v1

docs/src/api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,4 +88,5 @@ symbolic_evaluate
8888

8989
```@docs
9090
SymbolCache
91+
ProblemState
9192
```

docs/src/usage.md

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@ Consider the following example:
2222

2323
```@example Usage
2424
using ModelingToolkit, OrdinaryDiffEq, SymbolicIndexingInterface, Plots
25+
using ModelingToolkit: t_nounits as t, D_nounits as D
2526
2627
@parameters σ ρ β
27-
@variables t x(t) y(t) z(t) w(t)
28-
D = Differential(t)
28+
@variables x(t) y(t) z(t) w(t)
2929
3030
eqs = [D(D(x)) ~ σ * (y - x),
3131
D(y) ~ x * (ρ - z) - y,
@@ -121,6 +121,30 @@ output, the following shorthand is used:
121121
sol[allvariables] # equivalent to sol[all_variable_symbols(sol)]
122122
```
123123

124+
### Evaluating expressions
125+
126+
`getu` also generates functions for expressions if the object passed to it supports
127+
[`observed`](@ref). For example:
128+
129+
```@example Usage
130+
getu(prob, x + y + z)(prob)
131+
```
132+
133+
To evaluate this function using values other than the ones contained in `prob`, we need
134+
an object that supports [`state_values`](@ref), [`parameter_values`](@ref),
135+
[`current_time`](@ref). SymbolicIndexingInterface provides the [`ProblemState`](@ref) type,
136+
which has trivial implementations of the above functions. We can thus do:
137+
138+
```@example Usage
139+
temp_state = ProblemState(; u = [0.1, 0.2, 0.3, 0.4], p = parameter_values(prob))
140+
getu(prob, x + y + z)(temp_state)
141+
```
142+
143+
Note that providing all of the state vector, parameter object and time may not be
144+
necessary if the function generated by `observed` does not access them. ModelingToolkit.jl
145+
generates functions that access the parameters regardless of whether they are used in the
146+
expression, and thus it needs to be provided to the `ProblemState`.
147+
124148
## Parameter Indexing: Getting and Setting Parameter Values
125149

126150
Parameters cannot be obtained using this syntax, and instead require using [`getp`](@ref) and [`setp`](@ref).

src/SymbolicIndexingInterface.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ include("parameter_indexing.jl")
3131
export state_values, set_state!, current_time, getu, setu
3232
include("state_indexing.jl")
3333

34+
export ProblemState
35+
include("problem_state.jl")
36+
3437
export ParameterIndexingProxy
3538
include("parameter_indexing_proxy.jl")
3639

src/problem_state.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
"""
2+
struct ProblemState
3+
function ProblemState(; u = nothing, p = nothing, t = nothing)
4+
5+
A struct which can be used as an argument to the function returned by [`getu`](@ref) or
6+
[`setu`](@ref). It stores the state vector, parameter object and current time, and
7+
forwards calls to [`state_values`](@ref), [`parameter_values`](@ref),
8+
[`current_time`](@ref), [`set_state!`](@ref), [`set_parameter!`](@ref) to the contained
9+
objects.
10+
"""
11+
struct ProblemState{U, P, T}
12+
u::U
13+
p::P
14+
t::T
15+
end
16+
17+
ProblemState(; u = nothing, p = nothing, t = nothing) = ProblemState(u, p, t)
18+
19+
state_values(ps::ProblemState) = ps.u
20+
parameter_values(ps::ProblemState) = ps.p
21+
current_time(ps::ProblemState) = ps.t
22+
set_state!(ps::ProblemState, val, idx) = set_state!(ps.u, val, idx)
23+
set_parameter!(ps::ProblemState, val, idx) = set_parameter!(ps.p, val, idx)

test/problem_state_test.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
using SymbolicIndexingInterface
2+
using Test
3+
4+
sys = SymbolCache([:x, :y, :z], [:a, :b, :c], :t)
5+
prob = ProblemState(; u = [1.0, 2.0, 3.0], p = [0.1, 0.2, 0.3], t = 0.5)
6+
7+
for (i, sym) in enumerate(variable_symbols(sys))
8+
@test getu(sys, sym)(prob) == prob.u[i]
9+
end
10+
for (i, sym) in enumerate(parameter_symbols(sys))
11+
@test getp(sys, sym)(prob) == prob.p[i]
12+
end
13+
@test getu(sys, :t)(prob) == prob.t
14+
15+
@test getu(sys, :(x + a + t))(prob) == 1.6

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,6 @@ end
2626
@safetestset "Remake test" begin
2727
@time include("remake_test.jl")
2828
end
29+
@safetestset "ProblemState test" begin
30+
@time include("problem_state_test.jl")
31+
end

0 commit comments

Comments
 (0)