Skip to content

Commit 55eca79

Browse files
feat: improve getu/setu/getp/setp handling of nested variables
- also addresses type-stability of the closures returned from the above functions
1 parent a7e0efb commit 55eca79

File tree

4 files changed

+165
-84
lines changed

4 files changed

+165
-84
lines changed

src/parameter_indexing.jl

Lines changed: 24 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -23,42 +23,43 @@ end
2323
getp(sys, p)
2424
2525
Return a function that takes an integrator or solution of `sys`, and returns the value of
26-
the parameter `p`. Note that `p` can be a direct numerical index or a symbolic value.
26+
the parameter `p`. Note that `p` can be a direct numerical index or a symbolic value, or
27+
an array/tuple of the aforementioned.
28+
2729
Requires that the integrator or solution implement [`parameter_values`](@ref). This function
2830
typically does not need to be implemented, and has a default implementation relying on
2931
[`parameter_values`](@ref).
3032
"""
3133
function getp(sys, p)
3234
symtype = symbolic_type(p)
3335
elsymtype = symbolic_type(eltype(p))
34-
if symtype != NotSymbolic()
35-
return _getp(sys, symtype, p)
36-
else
37-
return _getp(sys, elsymtype, p)
38-
end
36+
_getp(sys, symtype, elsymtype, p)
3937
end
4038

41-
function _getp(sys, ::NotSymbolic, p)
39+
function _getp(sys, ::NotSymbolic, ::NotSymbolic, p)
4240
return function getter(sol)
4341
return parameter_values(sol)[p]
4442
end
4543
end
4644

47-
function _getp(sys, ::ScalarSymbolic, p)
45+
function _getp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p)
4846
idx = parameter_index(sys, p)
4947
return function getter(sol)
5048
return parameter_values(sol)[idx]
5149
end
5250
end
5351

54-
function _getp(sys, ::ScalarSymbolic, p::Union{Tuple, AbstractArray})
55-
idxs = parameter_index.((sys,), p)
56-
return function getter(sol)
57-
return getindex.((parameter_values(sol),), idxs)
52+
for (t1, t2) in [(ArraySymbolic, Any), (ScalarSymbolic, Any), (NotSymbolic, Union{<:Tuple, <:AbstractArray})]
53+
@eval function _getp(sys, ::NotSymbolic, ::$t1, p::$t2)
54+
getters = getp.((sys,), p)
55+
56+
return function getter(sol)
57+
map(g -> g(sol), getters)
58+
end
5859
end
5960
end
6061

61-
function _getp(sys, ::ArraySymbolic, p)
62+
function _getp(sys, ::ArraySymbolic, ::NotSymbolic, p)
6263
return getp(sys, collect(p))
6364
end
6465

@@ -76,33 +77,31 @@ implemented.
7677
function setp(sys, p)
7778
symtype = symbolic_type(p)
7879
elsymtype = symbolic_type(eltype(p))
79-
if symtype != NotSymbolic()
80-
return _setp(sys, symtype, p)
81-
else
82-
return _setp(sys, elsymtype, p)
83-
end
80+
_setp(sys, symtype, elsymtype, p)
8481
end
8582

86-
function _setp(sys, ::NotSymbolic, p)
83+
function _setp(sys, ::NotSymbolic, ::NotSymbolic, p)
8784
return function setter!(sol, val)
8885
set_parameter!(sol, val, p)
8986
end
9087
end
9188

92-
function _setp(sys, ::ScalarSymbolic, p)
89+
function _setp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p)
9390
idx = parameter_index(sys, p)
9491
return function setter!(sol, val)
9592
set_parameter!(sol, val, idx)
9693
end
9794
end
9895

99-
function _setp(sys, ::ScalarSymbolic, p::Union{Tuple, AbstractArray})
100-
idxs = parameter_index.((sys,), p)
101-
return function setter!(sol, val)
102-
set_parameter!.((sol,), val, idxs)
96+
for (t1, t2) in [(ArraySymbolic, Any), (ScalarSymbolic, Any), (NotSymbolic, Union{<:Tuple, <:AbstractArray})]
97+
@eval function _setp(sys, ::NotSymbolic, ::$t1, p::$t2)
98+
setters = setp.((sys,), p)
99+
return function setter!(sol, val)
100+
map((s!, v) -> s!(sol, v), setters, val)
101+
end
103102
end
104103
end
105104

106-
function _setp(sys, ::ArraySymbolic, p)
105+
function _setp(sys, ::ArraySymbolic, ::NotSymbolic, p)
107106
return setp(sys, collect(p))
108107
end

src/state_indexing.jl

Lines changed: 50 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -93,23 +93,20 @@ relying on the above functions.
9393
function getu(sys, sym)
9494
symtype = symbolic_type(sym)
9595
elsymtype = symbolic_type(eltype(sym))
96-
97-
if symtype != NotSymbolic()
98-
_getu(sys, symtype, sym)
99-
else
100-
_getu(sys, elsymtype, sym)
101-
end
96+
_getu(sys, symtype, elsymtype, sym)
10297
end
10398

104-
function _getu(sys, ::NotSymbolic, sym)
99+
function _getu(sys, ::NotSymbolic, ::NotSymbolic, sym)
105100
_getter(::Timeseries, prob) = getindex.(state_values(prob), (sym,))
106101
_getter(::NotTimeseries, prob) = state_values(prob)[sym]
107-
return function getter(prob)
108-
return _getter(is_timeseries(prob), prob)
102+
return let _getter = _getter
103+
function getter(prob)
104+
return _getter(is_timeseries(prob), prob)
105+
end
109106
end
110107
end
111108

112-
function _getu(sys, ::ScalarSymbolic, sym)
109+
function _getu(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym)
113110
if is_variable(sys, sym)
114111
idx = variable_index(sys, sym)
115112
return getu(sys, idx)
@@ -125,8 +122,10 @@ function _getu(sys, ::ScalarSymbolic, sym)
125122
return fn(state_values(prob), parameter_values(prob), current_time(prob))
126123
end
127124

128-
return function getter2(prob)
129-
return _getter2(is_timeseries(prob), prob)
125+
return let _getter2 = _getter2
126+
function getter2(prob)
127+
return _getter2(is_timeseries(prob), prob)
128+
end
130129
end
131130
else
132131
function _getter3(::Timeseries, prob)
@@ -136,8 +135,10 @@ function _getu(sys, ::ScalarSymbolic, sym)
136135
return fn(state_values(prob), parameter_values(prob))
137136
end
138137

139-
return function getter3(prob)
140-
return _getter3(is_timeseries(prob), prob)
138+
return let _getter3 = _getter3
139+
function getter3(prob)
140+
return _getter3(is_timeseries(prob), prob)
141+
end
141142
end
142143
end
143144
end
@@ -153,24 +154,38 @@ state_values(t::TimeseriesIndexWrapper) = state_values(t.timeseries)[t.idx]
153154
parameter_values(t::TimeseriesIndexWrapper) = parameter_values(t.timeseries)
154155
current_time(t::TimeseriesIndexWrapper) = current_time(t.timeseries)[t.idx]
155156

156-
function _getu(sys, ::ScalarSymbolic, sym::Union{<:Tuple, <:AbstractArray})
157-
getters = getu.((sys,), sym)
158-
_call(getter, prob) = getter(prob)
157+
for (t1, t2) in [(ScalarSymbolic, Any), (ArraySymbolic, Any), (NotSymbolic, Union{<:Tuple, <:AbstractArray})]
158+
@eval function _getu(sys, ::NotSymbolic, ::$t1, sym::$t2)
159+
getters = getu.((sys,), sym)
160+
_call(getter, prob) = getter(prob)
159161

160-
function _getter(::Timeseries, prob)
161-
tiws = TimeseriesIndexWrapper.((prob,), eachindex(state_values(prob)))
162-
return [_getter(NotTimeseries(), tiw) for tiw in tiws]
163-
end
164-
_getter(::NotTimeseries, prob) = _call.(getters, (prob,))
165-
return function getter(prob)
166-
return _getter(is_timeseries(prob), prob)
162+
return let getters = getters, _call = _call
163+
_getter(::NotTimeseries, prob) = map(g -> g(prob), getters)
164+
function _getter(::Timeseries, prob)
165+
tiws = TimeseriesIndexWrapper.((prob,), eachindex(state_values(prob)))
166+
# Ideally this should recursively call `_getter` but that leads to type-instability
167+
# since the reference to itself is boxed
168+
# Turning this broadcasted `_call` into a map also makes this type-unstable
169+
170+
return map(tiw -> _call.(getters, (tiw,)), tiws)
171+
end
172+
173+
# Need another scope for this to not box `_getter`
174+
let _getter = _getter
175+
function getter(prob)
176+
return _getter(is_timeseries(prob), prob)
177+
end
178+
end
179+
end
167180
end
168181
end
169182

170-
function _getu(sys, ::ArraySymbolic, sym)
183+
function _getu(sys, ::ArraySymbolic, ::NotSymbolic, sym)
171184
return getu(sys, collect(sym))
172185
end
173186

187+
# setu doesn't need the same `let` blocks to be inferred for some reason
188+
174189
"""
175190
setu(sys, sym)
176191
@@ -186,36 +201,32 @@ This function does not work on types for which [`is_timeseries`](@ref) is
186201
function setu(sys, sym)
187202
symtype = symbolic_type(sym)
188203
elsymtype = symbolic_type(eltype(sym))
189-
190-
if symtype != NotSymbolic()
191-
_setu(sys, symtype, sym)
192-
else
193-
_setu(sys, elsymtype, sym)
194-
end
204+
_setu(sys, symtype, elsymtype, sym)
195205
end
196206

197-
function _setu(sys, ::NotSymbolic, sym)
207+
function _setu(sys, ::NotSymbolic, ::NotSymbolic, sym)
198208
return function setter!(prob, val)
199209
set_state!(prob, val, sym)
200210
end
201211
end
202212

203-
function _setu(sys, ::ScalarSymbolic, sym)
213+
function _setu(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym)
204214
is_variable(sys, sym) || error("Invalid symbol $sym for `setu`")
205215
idx = variable_index(sys, sym)
206216
return function setter!(prob, val)
207217
set_state!(prob, val, idx)
208218
end
209219
end
210220

211-
function _setu(sys, ::ScalarSymbolic, sym::Union{<:Tuple, <:AbstractArray})
212-
setters = setu.((sys,), sym)
213-
_call!(setter!, prob, val) = setter!(prob, val)
214-
return function setter!(prob, val)
215-
_call!.(setters, (prob,), val)
221+
for (t1, t2) in [(ScalarSymbolic, Any), (ArraySymbolic, Any), (NotSymbolic, Union{<:Tuple, <:AbstractArray})]
222+
@eval function _setu(sys, ::NotSymbolic, ::$t1, sym::$t2)
223+
setters = setu.((sys,), sym)
224+
return function setter!(prob, val)
225+
map((s!, v) -> s!(prob, v), setters, val)
226+
end
216227
end
217228
end
218229

219-
function _setu(sys, ::ArraySymbolic, sym)
230+
function _setu(sys, ::ArraySymbolic, ::NotSymbolic, sym)
220231
return setu(sys, collect(sym))
221232
end

test/parameter_indexing_test.jl

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,37 @@ end
99
SymbolicIndexingInterface.symbolic_container(fp::FakeIntegrator) = fp.sys
1010
SymbolicIndexingInterface.parameter_values(fp::FakeIntegrator) = fp.p
1111

12-
sys = SymbolCache([:x, :y, :z], [:a, :b], [:t])
13-
p = [1.0, 2.0]
12+
sys = SymbolCache([:x, :y, :z], [:a, :b, :c], [:t])
13+
p = [1.0, 2.0, 3.0]
1414
fi = FakeIntegrator(sys, copy(p))
15-
for (i, sym) in [(1, :a), (2, :b), ([1, 2], [:a, :b]), ((1, 2), (:a, :b))]
15+
new_p = [4.0, 5.0, 6.0]
16+
for (sym, oldval, newval, check_inference) in [
17+
(:a, p[1], new_p[1], true),
18+
(1, p[1], new_p[1], true),
19+
([:a, :b], p[1:2], new_p[1:2], true),
20+
(1:2, p[1:2], new_p[1:2], true),
21+
((1, 2), Tuple(p[1:2]), Tuple(new_p[1:2]), true),
22+
([:a, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false),
23+
([:a, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false),
24+
((:a, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true),
25+
((:a, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true),
26+
([1, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false),
27+
([1, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false),
28+
((1, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true),
29+
((1, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true),
30+
]
1631
get = getp(sys, sym)
1732
set! = setp(sys, sym)
18-
true_value = i isa Tuple ? getindex.((p,), i) : p[i]
19-
@test get(fi) == ParameterIndexingProxy(fi)[sym] == true_value
20-
set!(fi, 0.5 .* i)
21-
@test get(fi) == ParameterIndexingProxy(fi)[sym] == 0.5 .* i
22-
set!(fi, true_value)
33+
if check_inference
34+
@inferred get(fi)
35+
end
36+
@test get(fi) == oldval
37+
if check_inference
38+
@inferred set!(fi, newval)
39+
else
40+
set!(fi, newval)
41+
end
42+
@test get(fi) == newval
43+
set!(fi, oldval)
44+
@test get(fi) == oldval
2345
end

test/state_indexing_test.jl

Lines changed: 61 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,44 @@ SymbolicIndexingInterface.state_values(fp::FakeIntegrator) = fp.u
1111
sys = SymbolCache([:x, :y, :z], [:a, :b], [:t])
1212
u = [1.0, 2.0, 3.0]
1313
fi = FakeIntegrator(sys, copy(u))
14-
for (i, sym) in [(1, :x), (2, :y), (3, :z), ([1, 2], [:x, :y]), ((3, 2), (:z, :y))]
14+
# checking inference for non-concretely typed arrays will always fail
15+
for (sym, val, newval, check_inference) in [
16+
(:x, u[1], 4.0, true)
17+
(:y, u[2], 4.0, true)
18+
(:z, u[3], 4.0, true)
19+
(1, u[1], 4.0, true)
20+
([:x, :y], u[1:2], 4ones(2), true)
21+
([1, 2], u[1:2], 4ones(2), true)
22+
((:z, :y), (u[3], u[2]), (4.0, 5.0), true)
23+
((3, 2), (u[3], u[2]), (4.0, 5.0), true)
24+
([:x, [:y, :z]], [u[1], u[2:3]], [4.0, [5.0, 6.0]], false)
25+
([:x, 2:3], [u[1], u[2:3]], [4.0, [5.0, 6.0]], false)
26+
([:x, (:y, :z)], [u[1], (u[2], u[3])], [4.0, (5.0, 6.0)], false)
27+
([:x, Tuple(2:3)], [u[1], (u[2], u[3])], [4.0, (5.0, 6.0)], false)
28+
([:x, [:y], (:z,)], [u[1], [u[2]], (u[3],)], [4.0, [5.0], (6.0,)], false)
29+
([:x, [:y], (3,)], [u[1], [u[2]], (u[3],)], [4.0, [5.0], (6.0,)], false)
30+
((:x, [:y, :z]), (u[1], u[2:3]), (4.0, [5.0, 6.0]), true)
31+
((:x, (:y, :z)), (u[1], (u[2], u[3])), (4.0, (5.0, 6.0)), true)
32+
((1, (:y, :z)), (u[1], (u[2], u[3])), (4.0, (5.0, 6.0)), true)
33+
((:x, [:y], (:z,)), (u[1], [u[2]], (u[3],)), (4.0, [5.0], (6.0,)), true)
34+
]
1535
get = getu(sys, sym)
1636
set! = setu(sys, sym)
17-
true_value = i isa Tuple ? getindex.((u,), i) : u[i]
18-
@test get(fi) == true_value
19-
set!(fi, 0.5 .* i)
20-
@test get(fi) == 0.5 .* i
21-
set!(fi, true_value)
37+
if check_inference
38+
@inferred get(fi)
39+
end
40+
@test get(fi) == val
41+
if check_inference
42+
@inferred set!(fi, newval)
43+
else
44+
set!(fi, newval)
45+
end
46+
@test get(fi) == newval
47+
set!(fi, val)
48+
@test get(fi) == val
2249
end
2350

51+
2452
struct FakeSolution{S, U}
2553
sys::S
2654
u::U
@@ -33,12 +61,33 @@ SymbolicIndexingInterface.state_values(fp::FakeSolution) = fp.u
3361
sys = SymbolCache([:x, :y, :z], [:a, :b], [:t])
3462
u = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
3563
sol = FakeSolution(sys, u)
36-
for (i, sym) in [(1, :x), (2, :y), (3, :z), ([1, 2], [:x, :y]), ((3, 2), (:z, :y))]
64+
65+
xvals = getindex.(sol.u, 1)
66+
yvals = getindex.(sol.u, 2)
67+
zvals = getindex.(sol.u, 3)
68+
69+
for (sym, ans, check_inference) in [
70+
(:x, xvals, true)
71+
(:y, yvals, true)
72+
(:z, zvals, true)
73+
(1, xvals, true)
74+
([:x, :y], vcat.(xvals, yvals), true)
75+
(1:2, vcat.(xvals, yvals), true)
76+
([:x, 2], vcat.(xvals, yvals), false)
77+
((:z, :y), tuple.(zvals, yvals), true)
78+
((3, 2), tuple.(zvals, yvals), true)
79+
([:x, [:y, :z]], vcat.(xvals, [[x] for x in vcat.(yvals, zvals)]), false)
80+
([:x, (:y, :z)], vcat.(xvals, tuple.(yvals, zvals)), false)
81+
([1, (:y, :z)], vcat.(xvals, tuple.(yvals, zvals)), false)
82+
([:x, [:y, :z], (:x, :z)], vcat.(xvals, [[x] for x in vcat.(yvals, zvals)], tuple.(xvals, zvals)), false)
83+
([:x, [:y, 3], (1, :z)], vcat.(xvals, [[x] for x in vcat.(yvals, zvals)], tuple.(xvals, zvals)), false)
84+
((:x, [:y, :z]), tuple.(xvals, vcat.(yvals, zvals)), true)
85+
((:x, (:y, :z)), tuple.(xvals, tuple.(yvals, zvals)), true)
86+
((:x, [:y, :z], (:z, :y)), tuple.(xvals, vcat.(yvals, zvals), tuple.(zvals, yvals)), true)
87+
]
3788
get = getu(sys, sym)
38-
true_value = if i isa Tuple
39-
[getindex.((v,), i) for v in u]
40-
else
41-
getindex.(u, (i,))
89+
if check_inference
90+
@inferred get(sol)
4291
end
43-
@test get(sol) == true_value
92+
@test get(sol) == ans
4493
end

0 commit comments

Comments
 (0)