Skip to content

Commit 01da0dc

Browse files
feat: add support for parameter and indepvar symbols in getu
1 parent 94b5a85 commit 01da0dc

File tree

2 files changed

+95
-13
lines changed

2 files changed

+95
-13
lines changed

src/state_indexing.jl

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,10 @@ Return a function that takes an integrator, problem or solution of `sys`, and re
9494
the value of the symbolic `sym`. If `sym` is not an observed quantity, the returned
9595
function can also directly be called with an array of values representing the state
9696
vector. `sym` can be a direct index into the state vector, a symbolic state, a symbolic
97-
expression involving symbolic quantities in the system `sys`, or an array/tuple of the
98-
aforementioned. If the returned function is called with a timeseries object, it can also
99-
be given a second argument representing the index at which to find the value of `sym`.
97+
expression involving symbolic quantities in the system `sys`, a parameter symbol, or the
98+
independent variable symbol, or an array/tuple of the aforementioned. If the returned
99+
function is called with a timeseries object, it can also be given a second argument
100+
representing the index at which to find the value of `sym`.
100101
101102
At minimum, this requires that the integrator, problem or solution implement
102103
[`state_values`](@ref). To support symbolic expressions, the integrator or problem
@@ -131,6 +132,19 @@ function _getu(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym)
131132
if is_variable(sys, sym)
132133
idx = variable_index(sys, sym)
133134
return getu(sys, idx)
135+
elseif is_parameter(sys, sym)
136+
return let fn = getp(sys, sym)
137+
getter(prob, args...) = fn(prob)
138+
getter
139+
end
140+
elseif is_independent_variable(sys, sym)
141+
_getter(::IsTimeseriesTrait, prob) = current_time(prob)
142+
_getter(::Timeseries, prob, i) = current_time(prob, i)
143+
return let _getter = _getter
144+
getter(prob) = _getter(is_timeseries(prob), prob)
145+
getter(prob, i) = _getter(is_timeseries(prob), prob, i)
146+
getter
147+
end
134148
elseif is_observed(sys, sym)
135149
fn = observed(sys, sym)
136150
if is_time_dependent(sys)
@@ -227,11 +241,15 @@ function _setu(sys, ::NotSymbolic, ::NotSymbolic, sym)
227241
end
228242

229243
function _setu(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym)
230-
is_variable(sys, sym) || error("Invalid symbol $sym for `setu`")
231-
idx = variable_index(sys, sym)
232-
return function setter!(prob, val)
233-
set_state!(prob, val, idx)
244+
if is_variable(sys, sym)
245+
idx = variable_index(sys, sym)
246+
return function setter!(prob, val)
247+
set_state!(prob, val, idx)
248+
end
249+
elseif is_parameter(sys, sym)
250+
return setp(sys, sym)
234251
end
252+
error("Invalid symbol $sym for `setu`")
235253
end
236254

237255
for (t1, t2) in [(ScalarSymbolic, Any), (ArraySymbolic, Any), (NotSymbolic, Union{<:Tuple, <:AbstractArray})]

test/state_indexing_test.jl

Lines changed: 70 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,22 @@
11
using SymbolicIndexingInterface
22

3-
struct FakeIntegrator{S, U}
3+
struct FakeIntegrator{S, U, P, T}
44
sys::S
55
u::U
6+
p::P
7+
t::T
68
end
79

810
SymbolicIndexingInterface.symbolic_container(fp::FakeIntegrator) = fp.sys
911
SymbolicIndexingInterface.state_values(fp::FakeIntegrator) = fp.u
12+
SymbolicIndexingInterface.parameter_values(fp::FakeIntegrator) = fp.p
13+
SymbolicIndexingInterface.current_time(fp::FakeIntegrator) = fp.t
1014

11-
sys = SymbolCache([:x, :y, :z], [:a, :b], [:t])
15+
sys = SymbolCache([:x, :y, :z], [:a, :b, :c], [:t])
1216
u = [1.0, 2.0, 3.0]
13-
fi = FakeIntegrator(sys, copy(u))
17+
p = [11.0, 12.0, 13.0]
18+
t = 0.5
19+
fi = FakeIntegrator(sys, copy(u), copy(p), t)
1420
# checking inference for non-concretely typed arrays will always fail
1521
for (sym, val, newval, check_inference) in [
1622
(:x, u[1], 4.0, true)
@@ -61,19 +67,60 @@ for (sym, val, newval, check_inference) in [
6167
@test get(u) == val
6268
end
6369

70+
for (sym, oldval, newval, check_inference) in [
71+
(:a, p[1], 4.0, true)
72+
(:b, p[2], 5.0, true)
73+
(:c, p[3], 6.0, true)
74+
([:a, :b], p[1:2], [4.0, 5.0], true)
75+
((:c, :b), (p[3], p[2]), (6.0, 5.0), true)
76+
([:x, :a], [u[1], p[1]], [4.0, 5.0], false)
77+
((:y, :b), (u[2], p[2]), (5.0, 6.0), true)
78+
]
79+
get = getu(fi, sym)
80+
set! = setu(fi, sym)
81+
if check_inference
82+
@inferred get(fi)
83+
end
84+
@test get(fi) == oldval
85+
if check_inference
86+
@inferred set!(fi, newval)
87+
else
88+
set!(fi, newval)
89+
end
90+
@test get(fi) == newval
91+
set!(fi, oldval)
92+
@test get(fi) == oldval
93+
end
6494

65-
struct FakeSolution{S, U}
95+
for (sym, val, check_inference) in [
96+
(:t, t, true),
97+
([:x, :a, :t], [u[1], p[1], t], false),
98+
((:x, :a, :t), (u[1], p[1], t), true),
99+
]
100+
get = getu(fi, sym)
101+
if check_inference
102+
@inferred get(fi)
103+
end
104+
@test get(fi) == val
105+
end
106+
107+
struct FakeSolution{S, U, P, T}
66108
sys::S
67109
u::U
110+
p::P
111+
t::T
68112
end
69113

70114
SymbolicIndexingInterface.is_timeseries(::Type{<:FakeSolution}) = Timeseries()
71115
SymbolicIndexingInterface.symbolic_container(fp::FakeSolution) = fp.sys
72116
SymbolicIndexingInterface.state_values(fp::FakeSolution) = fp.u
117+
SymbolicIndexingInterface.parameter_values(fp::FakeSolution) = fp.p
118+
SymbolicIndexingInterface.current_time(fp::FakeSolution) = fp.t
73119

74-
sys = SymbolCache([:x, :y, :z], [:a, :b], [:t])
120+
sys = SymbolCache([:x, :y, :z], [:a, :b, :c], [:t])
75121
u = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
76-
sol = FakeSolution(sys, u)
122+
t = [1.5, 2.0]
123+
sol = FakeSolution(sys, u, p, t)
77124

78125
xvals = getindex.(sol.u, 1)
79126
yvals = getindex.(sol.u, 2)
@@ -97,6 +144,11 @@ for (sym, ans, check_inference) in [
97144
((:x, [:y, :z]), tuple.(xvals, vcat.(yvals, zvals)), true)
98145
((:x, (:y, :z)), tuple.(xvals, tuple.(yvals, zvals)), true)
99146
((:x, [:y, :z], (:z, :y)), tuple.(xvals, vcat.(yvals, zvals), tuple.(zvals, yvals)), true)
147+
([:x, :a], vcat.(xvals, p[1]), false)
148+
((:y, :b), tuple.(yvals, p[2]), true)
149+
(:t, t, true)
150+
([:x, :a, :t], vcat.(xvals, p[1], t), false)
151+
((:x, :a, :t), tuple.(xvals, p[1], t), true)
100152
]
101153
get = getu(sys, sym)
102154
if check_inference
@@ -110,3 +162,15 @@ for (sym, ans, check_inference) in [
110162
@test get(sol, i) == ans[i]
111163
end
112164
end
165+
166+
for (sym, val) in [
167+
(:a, p[1])
168+
(:b, p[2])
169+
(:c, p[3])
170+
([:a, :b], p[1:2])
171+
((:c, :b), (p[3], p[2]))
172+
]
173+
get = getu(fi, sym)
174+
@inferred get(fi)
175+
@test get(fi) == val
176+
end

0 commit comments

Comments
 (0)