Skip to content

Commit 8179344

Browse files
Merge pull request #92 from SciML/as/oop-setp
feat: add out-of-place `setp`, refactor `remake_buffer`
2 parents 82f1464 + 5a3ac25 commit 8179344

File tree

9 files changed

+161
-40
lines changed

9 files changed

+161
-40
lines changed

docs/src/api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ set_parameter!
7575
finalize_parameters_hook!
7676
getp
7777
setp
78+
setp_oop
7879
ParameterIndexingProxy
7980
```
8081

src/SymbolicIndexingInterface.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ include("value_provider_interface.jl")
3232
export ParameterTimeseriesCollection
3333
include("parameter_timeseries_collection.jl")
3434

35-
export getp, setp
35+
export getp, setp, setp_oop
3636
include("parameter_indexing.jl")
3737

3838
export getu, setu

src/parameter_indexing.jl

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -647,7 +647,7 @@ end
647647
"""
648648
setp(indp, sym)
649649
650-
Return a function that takes an index provider and a value, and sets the parameter `sym`
650+
Return a function that takes a value provider and a value, and sets the parameter `sym`
651651
to that value. Note that `sym` can be an index, a symbolic variable, or an array/tuple of
652652
the aforementioned.
653653
@@ -709,3 +709,69 @@ function _setp(sys, ::ArraySymbolic, ::SymbolicTypeTrait, p)
709709
end
710710
return setp(sys, collect(p); run_hook = false)
711711
end
712+
713+
"""
714+
setp_oop(indp, sym)
715+
716+
Return a function which takes a value provider `valp` and a value `val`, and returns
717+
`parameter_values(valp)` with the parameters at `sym` set to `val`. This allows changing
718+
the types of values stored, and leverages [`remake_buffer`](@ref). Note that `sym` can be
719+
an index, a symbolic variable, or an array/tuple of the aforementioned.
720+
721+
Requires that the value provider implement `parameter_values` and `remake_buffer`.
722+
"""
723+
function setp_oop(indp, sym)
724+
symtype = symbolic_type(sym)
725+
elsymtype = symbolic_type(eltype(sym))
726+
return _setp_oop(indp, symtype, elsymtype, sym)
727+
end
728+
729+
struct OOPSetter{I, D}
730+
indp::I
731+
idxs::D
732+
end
733+
734+
function (os::OOPSetter)(valp, val)
735+
return remake_buffer(os.indp, parameter_values(valp), (os.idxs,), (val,))
736+
end
737+
738+
function (os::OOPSetter)(valp, val::Union{Tuple, AbstractArray})
739+
if os.idxs isa Union{Tuple, AbstractArray}
740+
return remake_buffer(os.indp, parameter_values(valp), os.idxs, val)
741+
else
742+
return remake_buffer(os.indp, parameter_values(valp), (os.idxs,), (val,))
743+
end
744+
end
745+
746+
function _root_indp(indp)
747+
if hasmethod(symbolic_container, Tuple{typeof(indp)}) &&
748+
(sc = symbolic_container(indp)) != indp
749+
return _root_indp(sc)
750+
else
751+
return indp
752+
end
753+
end
754+
755+
function _setp_oop(indp, ::NotSymbolic, ::NotSymbolic, sym)
756+
return OOPSetter(_root_indp(indp), sym)
757+
end
758+
759+
function _setp_oop(indp, ::ScalarSymbolic, ::SymbolicTypeTrait, sym)
760+
return OOPSetter(_root_indp(indp), parameter_index(indp, sym))
761+
end
762+
763+
for (t1, t2) in [
764+
(ScalarSymbolic, Any),
765+
(NotSymbolic, Union{<:Tuple, <:AbstractArray})
766+
]
767+
@eval function _setp_oop(indp, ::NotSymbolic, ::$t1, sym::$t2)
768+
return OOPSetter(_root_indp(indp), parameter_index.((indp,), sym))
769+
end
770+
end
771+
772+
function _setp_oop(indp, ::ArraySymbolic, ::SymbolicTypeTrait, sym)
773+
if is_parameter(indp, sym)
774+
return OOPSetter(_root_indp(indp), parameter_index(indp, sym))
775+
end
776+
error("$sym is not a valid parameter")
777+
end

src/remake.jl

Lines changed: 45 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,34 @@
11
"""
2-
remake_buffer(indp, oldbuffer, vals::Dict)
3-
4-
Return a copy of the buffer `oldbuffer` with values from `vals`. The keys of `vals`
5-
are symbolic variables whose index in the buffer is determined using `indp`. The types of
6-
values in `vals` may not match the types of values stored at the corresponding indexes in
7-
the buffer, in which case the type of the buffer should be promoted accordingly. In
8-
general, this method should attempt to preserve the types of values stored in `vals` as
9-
much as possible. Types can be promoted for type-stability, to maintain performance. The
10-
returned buffer should be of the same type (ignoring type-parameters) as `oldbuffer`.
11-
12-
This method is already implemented for
13-
`remake_buffer(indp, oldbuffer::AbstractArray, vals::Dict)` and supports static arrays
14-
as well. It is also implemented for `oldbuffer::Tuple`.
2+
remake_buffer(indp, oldbuffer, idxs, vals)
3+
4+
Return a copy of the buffer `oldbuffer` with at (optionally symbolic) indexes `idxs`
5+
replaced by corresponding values from `vals`. Both `idxs` and `vals` must be iterables of
6+
the same length. `idxs` may contain symbolic variables whose index in the buffer is
7+
determined using `indp`. The types of values in `vals` may not match the types of values
8+
stored at the corresponding indexes in the buffer, in which case the type of the buffer
9+
should be promoted accordingly. In general, this method should attempt to preserve the
10+
types of values stored in `vals` as much as possible. Types can be promoted for
11+
type-stability, to maintain performance. The returned buffer should be of the same type
12+
(ignoring type-parameters) as `oldbuffer`.
13+
14+
This method is already implemented for `oldbuffer::AbstractArray` and `oldbuffer::Tuple`,
15+
and supports static arrays as well.
16+
17+
The deprecated version of this method which takes a `Dict` mapping symbols to values
18+
instead of `idxs` and `vals` will dispatch to the new method. In addition if
19+
no `remake_buffer` method exists with the new signature, it will call
20+
`remake_buffer(sys, oldbuffer, Dict(idxs .=> vals))`.
21+
22+
Note that the new method signature allows `idxs` to be indexes, instead of requiring
23+
that they be symbolic variables. Thus, any type which implements the new method must
24+
also support indexes in `idxs`.
1525
"""
16-
function remake_buffer(sys, oldbuffer::AbstractArray, vals::Dict)
26+
function remake_buffer(sys, oldbuffer::AbstractArray, idxs, vals)
1727
# similar when used with an `MArray` and nonconcrete eltype returns a
1828
# SizedArray. `similar_type` still returns an `MArray`
1929
if ArrayInterface.ismutable(oldbuffer) && !isa(oldbuffer, MArray)
2030
elT = Union{}
21-
for val in values(vals)
31+
for val in vals
2232
if val isa AbstractArray
2333
valT = eltype(val)
2434
else
@@ -29,7 +39,8 @@ function remake_buffer(sys, oldbuffer::AbstractArray, vals::Dict)
2939

3040
newbuffer = similar(oldbuffer, elT)
3141
copyto!(newbuffer, oldbuffer)
32-
for (k, v) in vals
42+
for (k, v) in zip(idxs, vals)
43+
is_variable(sys, k) || is_parameter(sys, k) || continue
3344
if v isa AbstractArray
3445
v = elT.(v)
3546
else
@@ -38,12 +49,16 @@ function remake_buffer(sys, oldbuffer::AbstractArray, vals::Dict)
3849
setu(sys, k)(newbuffer, v)
3950
end
4051
else
41-
mutbuffer = remake_buffer(sys, collect(oldbuffer), vals)
52+
mutbuffer = remake_buffer(sys, collect(oldbuffer), idxs, vals)
4253
newbuffer = similar_type(oldbuffer, eltype(mutbuffer))(mutbuffer)
4354
end
4455
return newbuffer
4556
end
4657

58+
function remake_buffer(sys, oldbuffer, idxs, vals)
59+
remake_buffer(sys, oldbuffer, Dict(idxs .=> vals))
60+
end
61+
4762
mutable struct TupleRemakeWrapper
4863
t::Tuple
4964
end
@@ -54,8 +69,19 @@ function set_parameter!(sys::TupleRemakeWrapper, val, idx)
5469
sys.t = tp
5570
end
5671

57-
function remake_buffer(sys, oldbuffer::Tuple, vals::Dict)
72+
function set_state!(sys::TupleRemakeWrapper, val, idx)
73+
tp = sys.t
74+
@reset tp[idx] = val
75+
sys.t = tp
76+
end
77+
78+
function remake_buffer(sys, oldbuffer::Tuple, idxs, vals)
5879
wrap = TupleRemakeWrapper(oldbuffer)
59-
setu(sys, collect(keys(vals)))(wrap, values(vals))
80+
setu(sys, idxs)(wrap, vals)
6081
return wrap.t
6182
end
83+
84+
@deprecate remake_buffer(sys, oldbuffer, vals::Dict) remake_buffer(
85+
sys, oldbuffer, keys(vals), values(vals))
86+
@deprecate remake_buffer(sys, oldbuffer::Tuple, vals::Dict) remake_buffer(
87+
sys, oldbuffer, collect(keys(vals)), collect(values(vals)))

test/downstream/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@ SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
55
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
66

77
[compat]
8-
SymbolicUtils = "<1.6"
8+
SymbolicUtils = "3.2"

test/downstream/remake_arrayvars.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,5 @@ using SymbolicIndexingInterface
77
sys = complete(sys)
88

99
u0 = [1.0, 2.0, 3.0]
10-
newu0 = remake_buffer(sys, u0, Dict(x => [5.0, 6.0], y => 7.0))
10+
newu0 = remake_buffer(sys, u0, [x, y], ([5.0, 6.0], 7.0))
1111
@test newu0 == [5.0, 6.0, 7.0]

test/parameter_indexing_test.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,18 @@ for sys in [
168168
@test getter(fi) == []
169169
getter = getp(sys, ())
170170
@test getter(fi) == ()
171+
172+
for (sym, val) in [
173+
(:a, 1.0f1),
174+
(1, 1.0f1),
175+
([:a, :b], [1.0f1, 2.0f1]),
176+
((:b, :c), (2.0f1, 3.0f1))
177+
]
178+
setter = setp_oop(fi, sym)
179+
newp = setter(fi, val)
180+
getter = getp(sys, sym)
181+
@test getter(newp) == val
182+
end
171183
end
172184
end
173185

test/remake_test.jl

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,47 +3,60 @@ using StaticArrays
33

44
sys = SymbolCache([:x, :y, :z], [:a, :b, :c], :t)
55

6-
for (buf, newbuf, newvals) in [
6+
for (buf, newbuf, idxs, vals) in [
77
# standard operation
8-
([1.0, 2.0, 3.0], [2.0, 3.0, 4.0], Dict(:x => 2.0, :y => 3.0, :z => 4.0)),
8+
([1.0, 2.0, 3.0], [2.0, 3.0, 4.0], [:x, :y, :z], [2.0, 3.0, 4.0]),
99
# buffer type "demotion"
10-
([1.0, 2.0, 3.0], [2, 2, 3], Dict(:x => 2)),
10+
([1.0, 2.0, 3.0], [2, 2, 3], [:x], [2]),
1111
# buffer type promotion
12-
([1, 2, 3], [2.0, 2.0, 3.0], Dict(:x => 2.0)),
12+
([1, 2, 3], [2.0, 2.0, 3.0], [:x], [2.0]),
1313
# value type promotion
14-
([1, 2, 3], [2.0, 3.0, 4.0], Dict(:x => 2, :y => 3.0, :z => 4.0)),
14+
([1, 2, 3], [2.0, 3.0, 4.0], [:x, :y, :z], Real[2, 3.0, 4.0]),
1515
# standard operation
16-
([1.0, 2.0, 3.0], [2.0, 3.0, 4.0], Dict(:a => 2.0, :b => 3.0, :c => 4.0)),
16+
([1.0, 2.0, 3.0], [2.0, 3.0, 4.0], [:a, :b, :c], [2.0, 3.0, 4.0]),
1717
# buffer type "demotion"
18-
([1.0, 2.0, 3.0], [2, 2, 3], Dict(:a => 2)),
18+
([1.0, 2.0, 3.0], [2, 2, 3], [:a], [2]),
1919
# buffer type promotion
20-
([1, 2, 3], [2.0, 2.0, 3.0], Dict(:a => 2.0)),
20+
([1, 2, 3], [2.0, 2.0, 3.0], [:a], [2.0]),
2121
# value type promotion
22-
([1, 2, 3], [2, 3.0, 4.0], Dict(:a => 2, :b => 3.0, :c => 4.0))
22+
([1, 2, 3], [2, 3.0, 4.0], [:a, :b, :c], Real[2, 3.0, 4.0]),
23+
# skip non-parameters
24+
([1, 2, 3], [2.0, 3.0, 3.0], [:a, :b, :(a + b)], [2.0, 3.0, 5.0])
2325
]
2426
for arrType in [Vector, SVector{3}, MVector{3}, SizedVector{3}]
2527
buf = arrType(buf)
2628
newbuf = arrType(newbuf)
27-
_newbuf = remake_buffer(sys, buf, newvals)
29+
_newbuf = remake_buffer(sys, buf, idxs, vals)
2830

2931
@test _newbuf != buf # should not alias
3032
@test newbuf == _newbuf # test values
3133
@test typeof(newbuf) == typeof(_newbuf) # ensure appropriate type
34+
@test_deprecated remake_buffer(sys, buf, Dict(idxs .=> vals))
3235
end
3336
end
3437

35-
# Tuples not allowed for state
36-
for (buf, newbuf, newvals) in [
38+
for (buf, newbuf, idxs, vals) in [
3739
# standard operation
38-
((1.0, 2.0, 3.0), (2.0, 3.0, 4.0), Dict(:a => 2.0, :b => 3.0, :c => 4.0)),
40+
((1.0, 2.0, 3.0), (2.0, 3.0, 4.0), [:a, :b, :c], [2.0, 3.0, 4.0]),
3941
# buffer type "demotion"
40-
((1.0, 2.0, 3.0), (2, 3, 4), Dict(:a => 2, :b => 3, :c => 4)),
42+
((1.0, 2.0, 3.0), (2, 3, 4), [:a, :b, :c], [2, 3, 4]),
4143
# buffer type promotion
42-
((1, 2, 3), (2.0, 3.0, 4.0), Dict(:a => 2.0, :b => 3.0, :c => 4.0)),
44+
((1, 2, 3), (2.0, 3.0, 4.0), [:a, :b, :c], [2.0, 3.0, 4.0]),
4345
# value type promotion
44-
((1, 2, 3), (2, 3.0, 4.0), Dict(:a => 2, :b => 3.0, :c => 4.0))
46+
((1, 2, 3), (2, 3.0, 4.0), [:a, :b, :c], Real[2, 3.0, 4.0]),
47+
# standard operation
48+
((1.0, 2.0, 3.0), (2.0, 3.0, 4.0), [:x, :y, :z], [2.0, 3.0, 4.0]),
49+
# buffer type "demotion"
50+
((1.0, 2.0, 3.0), (2, 3, 4), [:x, :y, :z], [2, 3, 4]),
51+
# buffer type promotion
52+
((1, 2, 3), (2.0, 3.0, 4.0), [:x, :y, :z], [2.0, 3.0, 4.0]),
53+
# value type promotion
54+
((1, 2, 3), (2, 3.0, 4.0), [:x, :y, :z], Real[2, 3.0, 4.0]),
55+
# skip non-variables
56+
([1, 2, 3], [2.0, 3.0, 3.0], [:x, :y, :(x + y)], [2.0, 3.0, 5.0])
4557
]
46-
_newbuf = remake_buffer(sys, buf, newvals)
58+
_newbuf = remake_buffer(sys, buf, idxs, vals)
4759
@test newbuf == _newbuf # test values
4860
@test typeof(newbuf) == typeof(_newbuf) # ensure appropriate type
61+
@test_deprecated remake_buffer(sys, buf, Dict(idxs .=> vals))
4962
end

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,4 +55,7 @@ if GROUP == "All" || GROUP == "Downstream"
5555
@safetestset "BatchedInterface with array symbolics test" begin
5656
@time include("downstream/batchedinterface_arrayvars.jl")
5757
end
58+
@safetestset "remake_buffer with array symbolics test" begin
59+
@time include("downstream/remake_arrayvars.jl")
60+
end
5861
end

0 commit comments

Comments
 (0)