Skip to content

Commit ccbfdc5

Browse files
Merge pull request #42 from SciML/as/getp-generic
refactor: make getp generic of the parameter container
2 parents b6bbb66 + 599575d commit ccbfdc5

File tree

7 files changed

+106
-77
lines changed

7 files changed

+106
-77
lines changed

docs/pages.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ pages = [
55
"Tutorials" => [
66
"Using the SciML Symbolic Indexing Interface" => "usage.md",
77
"Simple Demonstration of a Symbolic System Structure" => "simple_sii_sys.md",
8-
"Implementing the Complete Symbolic Indexing Interface" => "complete_sii.md",
8+
"Implementing the Complete Symbolic Indexing Interface" => "complete_sii.md"
99
],
1010
"Defining Solution Wrapper Fallbacks" => "solution_wrappers.md",
11-
"API" => "api.md",
11+
"API" => "api.md"
1212
]

src/SymbolicIndexingInterface.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@ export ScalarSymbolic, ArraySymbolic, NotSymbolic, symbolic_type, hasname, getna
44
include("trait.jl")
55

66
export is_variable, variable_index, variable_symbols, is_parameter, parameter_index,
7-
parameter_symbols, is_independent_variable, independent_variable_symbols, is_observed,
8-
observed, is_time_dependent, constant_structure, symbolic_container,
9-
all_variable_symbols,
10-
all_symbols, solvedvariables, allvariables
7+
parameter_symbols, is_independent_variable, independent_variable_symbols,
8+
is_observed,
9+
observed, is_time_dependent, constant_structure, symbolic_container,
10+
all_variable_symbols,
11+
all_symbols, solvedvariables, allvariables
1112
include("interface.jl")
1213

1314
export SymbolCache
@@ -17,7 +18,7 @@ export parameter_values, set_parameter!, getp, setp
1718
include("parameter_indexing.jl")
1819

1920
export Timeseries,
20-
NotTimeseries, is_timeseries, state_values, set_state!, current_time, getu, setu
21+
NotTimeseries, is_timeseries, state_values, set_state!, current_time, getu, setu
2122
include("state_indexing.jl")
2223

2324
export ParameterIndexingProxy

src/parameter_indexing.jl

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,19 @@
11
"""
22
parameter_values(p)
3+
parameter_values(p, i)
34
4-
Return an indexable collection containing the value of each parameter in `p`.
5+
Return an indexable collection containing the value of each parameter in `p`. The two-
6+
argument version of this function returns the parameter value at index `i`. The
7+
two-argument version of this function will default to returning
8+
`parameter_values(p)[i]`.
59
610
If this function is called with an `AbstractArray`, it will return the same array.
711
"""
812
function parameter_values end
913

1014
parameter_values(arr::AbstractArray) = arr
15+
parameter_values(arr::AbstractArray, i) = arr[i]
16+
parameter_values(prob, i) = parameter_values(parameter_values(prob), i)
1117

1218
"""
1319
set_parameter!(sys, val, idx)
@@ -19,16 +25,19 @@ defined to enable the proper functioning of [`setp`](@ref).
1925
2026
See: [`parameter_values`](@ref)
2127
"""
22-
function set_parameter!(sys, val, idx)
23-
parameter_values(sys)[idx] = val
28+
function set_parameter! end
29+
30+
function set_parameter!(sys::AbstractArray, val, idx)
31+
sys[idx] = val
2432
end
33+
set_parameter!(sys, val, idx) = set_parameter!(parameter_values(sys), val, idx)
2534

2635
"""
2736
getp(sys, p)
2837
2938
Return a function that takes an array representing the parameter vector or an integrator
3039
or solution of `sys`, and returns the value of the parameter `p`. Note that `p` can be a
31-
direct numerical index or a symbolic value, or an array/tuple of the aforementioned.
40+
direct index or a symbolic value, or an array/tuple of the aforementioned.
3241
3342
Requires that the integrator or solution implement [`parameter_values`](@ref). This function
3443
typically does not need to be implemented, and has a default implementation relying on
@@ -42,21 +51,21 @@ end
4251

4352
function _getp(sys, ::NotSymbolic, ::NotSymbolic, p)
4453
return function getter(sol)
45-
return parameter_values(sol)[p]
54+
return parameter_values(sol, p)
4655
end
4756
end
4857

4958
function _getp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p)
5059
idx = parameter_index(sys, p)
5160
return function getter(sol)
52-
return parameter_values(sol)[idx]
61+
return parameter_values(sol, idx)
5362
end
5463
end
5564

5665
for (t1, t2) in [
5766
(ArraySymbolic, Any),
5867
(ScalarSymbolic, Any),
59-
(NotSymbolic, Union{<:Tuple, <:AbstractArray}),
68+
(NotSymbolic, Union{<:Tuple, <:AbstractArray})
6069
]
6170
@eval function _getp(sys, ::NotSymbolic, ::$t1, p::$t2)
6271
getters = getp.((sys,), p)
@@ -76,7 +85,7 @@ end
7685
7786
Return a function that takes an array representing the parameter vector or an integrator
7887
or problem of `sys`, and a value, and sets the parameter `p` to that value. Note that `p`
79-
can be a direct numerical index or a symbolic value.
88+
can be a direct index or a symbolic value.
8089
8190
Requires that the integrator implement [`parameter_values`](@ref) and the returned
8291
collection be a mutable reference to the parameter vector in the integrator. In
@@ -106,7 +115,7 @@ end
106115
for (t1, t2) in [
107116
(ArraySymbolic, Any),
108117
(ScalarSymbolic, Any),
109-
(NotSymbolic, Union{<:Tuple, <:AbstractArray}),
118+
(NotSymbolic, Union{<:Tuple, <:AbstractArray})
110119
]
111120
@eval function _setp(sys, ::NotSymbolic, ::$t1, p::$t2)
112121
setters = setp.((sys,), p)

src/state_indexing.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ end
116116
function _getu(sys, ::NotSymbolic, ::NotSymbolic, sym)
117117
_getter(::Timeseries, prob) = getindex.(state_values(prob), (sym,))
118118
_getter(::Timeseries, prob, i) = getindex(state_values(prob, i), sym)
119-
_getter(::NotTimeseries, prob) = state_values(prob)[sym]
119+
_getter(::NotTimeseries, prob) = state_values(prob, sym)
120120
return let _getter = _getter
121121
function getter(prob)
122122
return _getter(is_timeseries(prob), prob)
@@ -186,7 +186,7 @@ end
186186
for (t1, t2) in [
187187
(ScalarSymbolic, Any),
188188
(ArraySymbolic, Any),
189-
(NotSymbolic, Union{<:Tuple, <:AbstractArray}),
189+
(NotSymbolic, Union{<:Tuple, <:AbstractArray})
190190
]
191191
@eval function _getu(sys, ::NotSymbolic, ::$t1, sym::$t2)
192192
num_observed = count(x -> is_observed(sys, x), sym)
@@ -266,7 +266,7 @@ end
266266
267267
Return a function that takes an array representing the state vector or an integrator or
268268
problem of `sys`, and a value, and sets the the state `sym` to that value. Note that `sym`
269-
can be a direct numerical index, a symbolic state, or an array/tuple of the aforementioned.
269+
can be a direct index, a symbolic state, or an array/tuple of the aforementioned.
270270
271271
Requires that the integrator implement [`state_values`](@ref) and the
272272
returned collection be a mutable reference to the state vector in the integrator/problem. Alternatively, if this is not possible or additional actions need to
@@ -301,7 +301,7 @@ end
301301
for (t1, t2) in [
302302
(ScalarSymbolic, Any),
303303
(ArraySymbolic, Any),
304-
(NotSymbolic, Union{<:Tuple, <:AbstractArray}),
304+
(NotSymbolic, Union{<:Tuple, <:AbstractArray})
305305
]
306306
@eval function _setu(sys, ::NotSymbolic, ::$t1, sym::$t2)
307307
setters = setu.((sys,), sym)

src/symbol_cache.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ array containing a single variable if the system has only one independent variab
1414
struct SymbolCache{
1515
V <: Union{Nothing, AbstractVector},
1616
P <: Union{Nothing, AbstractVector},
17-
I,
17+
I
1818
}
1919
variables::V
2020
parameters::P

test/parameter_indexing_test.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ for (sym, oldval, newval, check_inference) in [
2626
([1, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false),
2727
([1, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false),
2828
((1, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true),
29-
((1, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true),
29+
((1, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true)
3030
]
3131
get = getp(sys, sym)
3232
set! = setp(sys, sym)

test/state_indexing_test.jl

Lines changed: 74 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -19,23 +19,33 @@ t = 0.5
1919
fi = FakeIntegrator(sys, copy(u), copy(p), t)
2020
# checking inference for non-concretely typed arrays will always fail
2121
for (sym, val, newval, check_inference) in [(:x, u[1], 4.0, true)
22-
(:y, u[2], 4.0, true)
23-
(:z, u[3], 4.0, true)
24-
(1, u[1], 4.0, true)
25-
([:x, :y], u[1:2], 4ones(2), true)
26-
([1, 2], u[1:2], 4ones(2), true)
27-
((:z, :y), (u[3], u[2]), (4.0, 5.0), true)
28-
((3, 2), (u[3], u[2]), (4.0, 5.0), true)
29-
([:x, [:y, :z]], [u[1], u[2:3]], [4.0, [5.0, 6.0]], false)
30-
([:x, 2:3], [u[1], u[2:3]], [4.0, [5.0, 6.0]], false)
31-
([:x, (:y, :z)], [u[1], (u[2], u[3])], [4.0, (5.0, 6.0)], false)
32-
([:x, Tuple(2:3)], [u[1], (u[2], u[3])], [4.0, (5.0, 6.0)], false)
33-
([:x, [:y], (:z,)], [u[1], [u[2]], (u[3],)], [4.0, [5.0], (6.0,)], false)
34-
([:x, [:y], (3,)], [u[1], [u[2]], (u[3],)], [4.0, [5.0], (6.0,)], false)
35-
((:x, [:y, :z]), (u[1], u[2:3]), (4.0, [5.0, 6.0]), true)
36-
((:x, (:y, :z)), (u[1], (u[2], u[3])), (4.0, (5.0, 6.0)), true)
37-
((1, (:y, :z)), (u[1], (u[2], u[3])), (4.0, (5.0, 6.0)), true)
38-
((:x, [:y], (:z,)), (u[1], [u[2]], (u[3],)), (4.0, [5.0], (6.0,)), true)]
22+
(:y, u[2], 4.0, true)
23+
(:z, u[3], 4.0, true)
24+
(1, u[1], 4.0, true)
25+
([:x, :y], u[1:2], 4ones(2), true)
26+
([1, 2], u[1:2], 4ones(2), true)
27+
((:z, :y), (u[3], u[2]), (4.0, 5.0), true)
28+
((3, 2), (u[3], u[2]), (4.0, 5.0), true)
29+
([:x, [:y, :z]], [u[1], u[2:3]],
30+
[4.0, [5.0, 6.0]], false)
31+
([:x, 2:3], [u[1], u[2:3]],
32+
[4.0, [5.0, 6.0]], false)
33+
([:x, (:y, :z)], [u[1], (u[2], u[3])],
34+
[4.0, (5.0, 6.0)], false)
35+
([:x, Tuple(2:3)], [u[1], (u[2], u[3])],
36+
[4.0, (5.0, 6.0)], false)
37+
([:x, [:y], (:z,)], [u[1], [u[2]], (u[3],)],
38+
[4.0, [5.0], (6.0,)], false)
39+
([:x, [:y], (3,)], [u[1], [u[2]], (u[3],)],
40+
[4.0, [5.0], (6.0,)], false)
41+
((:x, [:y, :z]), (u[1], u[2:3]),
42+
(4.0, [5.0, 6.0]), true)
43+
((:x, (:y, :z)), (u[1], (u[2], u[3])),
44+
(4.0, (5.0, 6.0)), true)
45+
((1, (:y, :z)), (u[1], (u[2], u[3])),
46+
(4.0, (5.0, 6.0)), true)
47+
((:x, [:y], (:z,)), (u[1], [u[2]], (u[3],)),
48+
(4.0, [5.0], (6.0,)), true)]
3949
get = getu(sys, sym)
4050
set! = setu(sys, sym)
4151
if check_inference
@@ -66,12 +76,12 @@ for (sym, val, newval, check_inference) in [(:x, u[1], 4.0, true)
6676
end
6777

6878
for (sym, oldval, newval, check_inference) in [(:a, p[1], 4.0, true)
69-
(:b, p[2], 5.0, true)
70-
(:c, p[3], 6.0, true)
71-
([:a, :b], p[1:2], [4.0, 5.0], true)
72-
((:c, :b), (p[3], p[2]), (6.0, 5.0), true)
73-
([:x, :a], [u[1], p[1]], [4.0, 5.0], false)
74-
((:y, :b), (u[2], p[2]), (5.0, 6.0), true)]
79+
(:b, p[2], 5.0, true)
80+
(:c, p[3], 6.0, true)
81+
([:a, :b], p[1:2], [4.0, 5.0], true)
82+
((:c, :b), (p[3], p[2]), (6.0, 5.0), true)
83+
([:x, :a], [u[1], p[1]], [4.0, 5.0], false)
84+
((:y, :b), (u[2], p[2]), (5.0, 6.0), true)]
7585
get = getu(fi, sym)
7686
set! = setu(fi, sym)
7787
if check_inference
@@ -91,7 +101,7 @@ end
91101
for (sym, val, check_inference) in [
92102
(:t, t, true),
93103
([:x, :a, :t], [u[1], p[1], t], false),
94-
((:x, :a, :t), (u[1], p[1], t), true),
104+
((:x, :a, :t), (u[1], p[1], t), true)
95105
]
96106
get = getu(fi, sym)
97107
if check_inference
@@ -123,33 +133,42 @@ yvals = getindex.(sol.u, 2)
123133
zvals = getindex.(sol.u, 3)
124134

125135
for (sym, ans, check_inference) in [(:x, xvals, true)
126-
(:y, yvals, true)
127-
(:z, zvals, true)
128-
(1, xvals, true)
129-
([:x, :y], vcat.(xvals, yvals), true)
130-
(1:2, vcat.(xvals, yvals), true)
131-
([:x, 2], vcat.(xvals, yvals), false)
132-
((:z, :y), tuple.(zvals, yvals), true)
133-
((3, 2), tuple.(zvals, yvals), true)
134-
([:x, [:y, :z]], vcat.(xvals, [[x] for x in vcat.(yvals, zvals)]), false)
135-
([:x, (:y, :z)], vcat.(xvals, tuple.(yvals, zvals)), false)
136-
([1, (:y, :z)], vcat.(xvals, tuple.(yvals, zvals)), false)
137-
([:x, [:y, :z], (:x, :z)],
138-
vcat.(xvals, [[x] for x in vcat.(yvals, zvals)], tuple.(xvals, zvals)),
139-
false)
140-
([:x, [:y, 3], (1, :z)],
141-
vcat.(xvals, [[x] for x in vcat.(yvals, zvals)], tuple.(xvals, zvals)),
142-
false)
143-
((:x, [:y, :z]), tuple.(xvals, vcat.(yvals, zvals)), true)
144-
((:x, (:y, :z)), tuple.(xvals, tuple.(yvals, zvals)), true)
145-
((:x, [:y, :z], (:z, :y)),
146-
tuple.(xvals, vcat.(yvals, zvals), tuple.(zvals, yvals)),
147-
true)
148-
([:x, :a], vcat.(xvals, p[1]), false)
149-
((:y, :b), tuple.(yvals, p[2]), true)
150-
(:t, t, true)
151-
([:x, :a, :t], vcat.(xvals, p[1], t), false)
152-
((:x, :a, :t), tuple.(xvals, p[1], t), true)]
136+
(:y, yvals, true)
137+
(:z, zvals, true)
138+
(1, xvals, true)
139+
([:x, :y], vcat.(xvals, yvals), true)
140+
(1:2, vcat.(xvals, yvals), true)
141+
([:x, 2], vcat.(xvals, yvals), false)
142+
((:z, :y), tuple.(zvals, yvals), true)
143+
((3, 2), tuple.(zvals, yvals), true)
144+
([:x, [:y, :z]],
145+
vcat.(xvals, [[x] for x in vcat.(yvals, zvals)]),
146+
false)
147+
([:x, (:y, :z)],
148+
vcat.(xvals, tuple.(yvals, zvals)), false)
149+
([1, (:y, :z)],
150+
vcat.(xvals, tuple.(yvals, zvals)), false)
151+
([:x, [:y, :z], (:x, :z)],
152+
vcat.(xvals, [[x] for x in vcat.(yvals, zvals)],
153+
tuple.(xvals, zvals)),
154+
false)
155+
([:x, [:y, 3], (1, :z)],
156+
vcat.(xvals, [[x] for x in vcat.(yvals, zvals)],
157+
tuple.(xvals, zvals)),
158+
false)
159+
((:x, [:y, :z]),
160+
tuple.(xvals, vcat.(yvals, zvals)), true)
161+
((:x, (:y, :z)),
162+
tuple.(xvals, tuple.(yvals, zvals)), true)
163+
((:x, [:y, :z], (:z, :y)),
164+
tuple.(xvals, vcat.(yvals, zvals),
165+
tuple.(zvals, yvals)),
166+
true)
167+
([:x, :a], vcat.(xvals, p[1]), false)
168+
((:y, :b), tuple.(yvals, p[2]), true)
169+
(:t, t, true)
170+
([:x, :a, :t], vcat.(xvals, p[1], t), false)
171+
((:x, :a, :t), tuple.(xvals, p[1], t), true)]
153172
get = getu(sys, sym)
154173
if check_inference
155174
@inferred get(sol)
@@ -164,10 +183,10 @@ for (sym, ans, check_inference) in [(:x, xvals, true)
164183
end
165184

166185
for (sym, val) in [(:a, p[1])
167-
(:b, p[2])
168-
(:c, p[3])
169-
([:a, :b], p[1:2])
170-
((:c, :b), (p[3], p[2]))]
186+
(:b, p[2])
187+
(:c, p[3])
188+
([:a, :b], p[1:2])
189+
((:c, :b), (p[3], p[2]))]
171190
get = getu(fi, sym)
172191
@inferred get(fi)
173192
@test get(fi) == val

0 commit comments

Comments
 (0)