Skip to content

Commit cf2aae7

Browse files
Merge pull request #33 from SciML/as/getu-improved
feat: improve getu/setu/getp/setp handling of nested variables
2 parents a7e0efb + 07d1e73 commit cf2aae7

File tree

4 files changed

+355
-122
lines changed

4 files changed

+355
-122
lines changed

src/parameter_indexing.jl

Lines changed: 43 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,13 @@
22
parameter_values(p)
33
44
Return an indexable collection containing the value of each parameter in `p`.
5+
6+
If this function is called with an `AbstractArray`, it will return the same array.
57
"""
68
function parameter_values end
79

10+
parameter_values(arr::AbstractArray) = arr
11+
812
"""
913
set_parameter!(sys, val, idx)
1014
@@ -22,87 +26,96 @@ end
2226
"""
2327
getp(sys, p)
2428
25-
Return a function that takes an integrator or solution of `sys`, and returns the value of
26-
the parameter `p`. Note that `p` can be a direct numerical index or a symbolic value.
29+
Return a function that takes an array representing the parameter vector or an integrator
30+
or solution of `sys`, and returns the value of the parameter `p`. Note that `p` can be a
31+
direct numerical index or a symbolic value, or an array/tuple of the aforementioned.
32+
2733
Requires that the integrator or solution implement [`parameter_values`](@ref). This function
2834
typically does not need to be implemented, and has a default implementation relying on
2935
[`parameter_values`](@ref).
3036
"""
3137
function getp(sys, p)
3238
symtype = symbolic_type(p)
3339
elsymtype = symbolic_type(eltype(p))
34-
if symtype != NotSymbolic()
35-
return _getp(sys, symtype, p)
36-
else
37-
return _getp(sys, elsymtype, p)
38-
end
40+
_getp(sys, symtype, elsymtype, p)
3941
end
4042

41-
function _getp(sys, ::NotSymbolic, p)
43+
function _getp(sys, ::NotSymbolic, ::NotSymbolic, p)
4244
return function getter(sol)
4345
return parameter_values(sol)[p]
4446
end
4547
end
4648

47-
function _getp(sys, ::ScalarSymbolic, p)
49+
function _getp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p)
4850
idx = parameter_index(sys, p)
4951
return function getter(sol)
5052
return parameter_values(sol)[idx]
5153
end
5254
end
5355

54-
function _getp(sys, ::ScalarSymbolic, p::Union{Tuple, AbstractArray})
55-
idxs = parameter_index.((sys,), p)
56-
return function getter(sol)
57-
return getindex.((parameter_values(sol),), idxs)
56+
for (t1, t2) in [
57+
(ArraySymbolic, Any),
58+
(ScalarSymbolic, Any),
59+
(NotSymbolic, Union{<:Tuple, <:AbstractArray}),
60+
]
61+
@eval function _getp(sys, ::NotSymbolic, ::$t1, p::$t2)
62+
getters = getp.((sys,), p)
63+
64+
return function getter(sol)
65+
map(g -> g(sol), getters)
66+
end
5867
end
5968
end
6069

61-
function _getp(sys, ::ArraySymbolic, p)
70+
function _getp(sys, ::ArraySymbolic, ::NotSymbolic, p)
6271
return getp(sys, collect(p))
6372
end
6473

6574
"""
6675
setp(sys, p)
6776
68-
Return a function that takes an integrator of `sys` and a value, and sets
69-
the parameter `p` to that value. Note that `p` can be a direct numerical index or a
70-
symbolic value. Requires that the integrator implement [`parameter_values`](@ref) and the
71-
returned collection be a mutable reference to the parameter vector in the integrator. In
77+
Return a function that takes an array representing the parameter vector or an integrator
78+
or problem of `sys`, and a value, and sets the parameter `p` to that value. Note that `p`
79+
can be a direct numerical index or a symbolic value.
80+
81+
Requires that the integrator implement [`parameter_values`](@ref) and the returned
82+
collection be a mutable reference to the parameter vector in the integrator. In
7283
case `parameter_values` cannot return such a mutable reference, or additional actions
7384
need to be performed when updating parameters, [`set_parameter!`](@ref) must be
7485
implemented.
7586
"""
7687
function setp(sys, p)
7788
symtype = symbolic_type(p)
7889
elsymtype = symbolic_type(eltype(p))
79-
if symtype != NotSymbolic()
80-
return _setp(sys, symtype, p)
81-
else
82-
return _setp(sys, elsymtype, p)
83-
end
90+
_setp(sys, symtype, elsymtype, p)
8491
end
8592

86-
function _setp(sys, ::NotSymbolic, p)
93+
function _setp(sys, ::NotSymbolic, ::NotSymbolic, p)
8794
return function setter!(sol, val)
8895
set_parameter!(sol, val, p)
8996
end
9097
end
9198

92-
function _setp(sys, ::ScalarSymbolic, p)
99+
function _setp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p)
93100
idx = parameter_index(sys, p)
94101
return function setter!(sol, val)
95102
set_parameter!(sol, val, idx)
96103
end
97104
end
98105

99-
function _setp(sys, ::ScalarSymbolic, p::Union{Tuple, AbstractArray})
100-
idxs = parameter_index.((sys,), p)
101-
return function setter!(sol, val)
102-
set_parameter!.((sol,), val, idxs)
106+
for (t1, t2) in [
107+
(ArraySymbolic, Any),
108+
(ScalarSymbolic, Any),
109+
(NotSymbolic, Union{<:Tuple, <:AbstractArray}),
110+
]
111+
@eval function _setp(sys, ::NotSymbolic, ::$t1, p::$t2)
112+
setters = setp.((sys,), p)
113+
return function setter!(sol, val)
114+
map((s!, v) -> s!(sol, v), setters, val)
115+
end
103116
end
104117
end
105118

106-
function _setp(sys, ::ArraySymbolic, p)
119+
function _setp(sys, ::ArraySymbolic, ::NotSymbolic, p)
107120
return setp(sys, collect(p))
108121
end

0 commit comments

Comments
 (0)