Skip to content

Commit e8934c4

Browse files
feat: add setsym_oop
1 parent 3344f9f commit e8934c4

File tree

4 files changed

+164
-30
lines changed

4 files changed

+164
-30
lines changed

src/parameter_indexing.jl

Lines changed: 4 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -726,52 +726,26 @@ function setp_oop(indp, sym)
726726
return _setp_oop(indp, symtype, elsymtype, sym)
727727
end
728728

729-
struct OOPSetter{I, D}
730-
indp::I
731-
idxs::D
732-
end
733-
734-
function (os::OOPSetter)(valp, val)
735-
return remake_buffer(os.indp, parameter_values(valp), (os.idxs,), (val,))
736-
end
737-
738-
function (os::OOPSetter)(valp, val::Union{Tuple, AbstractArray})
739-
if os.idxs isa Union{Tuple, AbstractArray}
740-
return remake_buffer(os.indp, parameter_values(valp), os.idxs, val)
741-
else
742-
return remake_buffer(os.indp, parameter_values(valp), (os.idxs,), (val,))
743-
end
744-
end
745-
746-
function _root_indp(indp)
747-
if hasmethod(symbolic_container, Tuple{typeof(indp)}) &&
748-
(sc = symbolic_container(indp)) != indp
749-
return _root_indp(sc)
750-
else
751-
return indp
752-
end
753-
end
754-
755729
function _setp_oop(indp, ::NotSymbolic, ::NotSymbolic, sym)
756-
return OOPSetter(_root_indp(indp), sym)
730+
return OOPSetter(_root_indp(indp), sym, false)
757731
end
758732

759733
function _setp_oop(indp, ::ScalarSymbolic, ::SymbolicTypeTrait, sym)
760-
return OOPSetter(_root_indp(indp), parameter_index(indp, sym))
734+
return OOPSetter(_root_indp(indp), parameter_index(indp, sym), false)
761735
end
762736

763737
for (t1, t2) in [
764738
(ScalarSymbolic, Any),
765739
(NotSymbolic, Union{<:Tuple, <:AbstractArray})
766740
]
767741
@eval function _setp_oop(indp, ::NotSymbolic, ::$t1, sym::$t2)
768-
return OOPSetter(_root_indp(indp), parameter_index.((indp,), sym))
742+
return OOPSetter(_root_indp(indp), parameter_index.((indp,), sym), false)
769743
end
770744
end
771745

772746
function _setp_oop(indp, ::ArraySymbolic, ::SymbolicTypeTrait, sym)
773747
if is_parameter(indp, sym)
774-
return OOPSetter(_root_indp(indp), parameter_index(indp, sym))
748+
return OOPSetter(_root_indp(indp), parameter_index(indp, sym), false)
775749
end
776750
error("$sym is not a valid parameter")
777751
end

src/state_indexing.jl

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,3 +377,97 @@ end
377377

378378
const getu = getsym
379379
const setu = setsym
380+
381+
"""
382+
setsym_oop(indp, sym)
383+
384+
Return a function which takes a value provider `valp` and a value `val`, and returns
385+
`state_values(valp), parameter_values(valp)` with the states/parameters in `sym` set to the
386+
corresponding values in `val`. This allows changing the types of values stored, and leverages
387+
[`remake_buffer`](@ref). Note that `sym` can be an index, a symbolic variable, or an
388+
array/tuple of the aforementioned. All entries `s` in `sym` must satisfy `is_variable(indp, s)`
389+
or `is_parameter(indp, s)`.
390+
391+
Requires that the value provider implement `state_values`, `parameter_values` and `remake_buffer`.
392+
"""
393+
function setsym_oop(indp, sym)
394+
symtype = symbolic_type(sym)
395+
elsymtype = symbolic_type(eltype(sym))
396+
return _setsym_oop(indp, symtype, elsymtype, sym)
397+
end
398+
399+
struct FullSetter{S, P, I, J}
400+
state_setter::S
401+
param_setter::P
402+
state_split::I
403+
param_split::J
404+
end
405+
406+
FullSetter(ssetter, psetter) = FullSetter(ssetter, psetter, nothing, nothing)
407+
408+
function (fs::FullSetter)(valp, val)
409+
return fs.state_setter(valp, val[fs.state_split]),
410+
fs.param_setter(valp, val[fs.param_split])
411+
end
412+
413+
function (fs::FullSetter{Nothing})(valp, val)
414+
return state_values(valp), fs.param_setter(valp, val)
415+
end
416+
417+
function (fs::(FullSetter{S, Nothing} where {S}))(valp, val)
418+
return fs.state_setter(valp, val), parameter_values(valp)
419+
end
420+
421+
function (fs::(FullSetter{Nothing, Nothing}))(valp, val)
422+
return state_values(valp), parameter_values(valp)
423+
end
424+
425+
function _setsym_oop(indp, ::NotSymbolic, ::NotSymbolic, sym)
426+
return FullSetter(OOPSetter(_root_indp(indp), sym, true), nothing)
427+
end
428+
429+
function _setsym_oop(indp, ::ScalarSymbolic, ::SymbolicTypeTrait, sym)
430+
if (idx = variable_index(indp, sym)) !== nothing
431+
return FullSetter(OOPSetter(_root_indp(indp), idx, true), nothing)
432+
elseif (idx = parameter_index(indp, sym)) !== nothing
433+
return FullSetter(nothing, OOPSetter(_root_indp(indp), idx, false))
434+
end
435+
throw(NotVariableOrParameter("setsym_oop", sym))
436+
end
437+
438+
for (t1, t2) in [
439+
(ScalarSymbolic, Any),
440+
(NotSymbolic, Union{<:Tuple, <:AbstractArray})
441+
]
442+
@eval function _setsym_oop(indp, ::NotSymbolic, ::$t1, sym::$t2)
443+
vars = []
444+
state_split = eltype(eachindex(sym))[]
445+
pars = []
446+
param_split = eltype(eachindex(sym))[]
447+
for (i, s) in enumerate(sym)
448+
if (idx = variable_index(indp, s)) !== nothing
449+
push!(vars, idx)
450+
push!(state_split, i)
451+
elseif (idx = parameter_index(indp, s)) !== nothing
452+
push!(pars, idx)
453+
push!(param_split, i)
454+
else
455+
throw(NotVariableOrParameter("setsym_oop", s))
456+
end
457+
end
458+
indp = _root_indp(indp)
459+
return FullSetter(isempty(vars) ? nothing : OOPSetter(indp, identity.(vars), true),
460+
isempty(pars) ? nothing : OOPSetter(indp, identity.(pars), false),
461+
state_split, param_split)
462+
end
463+
end
464+
465+
function _setsym_oop(indp, ::ArraySymbolic, ::SymbolicTypeTrait, sym)
466+
if (idx = variable_index(indp, sym)) !== nothing
467+
return setsym_oop(indp, idx)
468+
elseif (idx = parameter_index(indp, sym)) !== nothing
469+
return FullSetter(
470+
nothing, OOPSetter(indp, idx isa AbstractArray ? idx : (idx,), false))
471+
end
472+
return setsym_oop(indp, collect(sym))
473+
end

src/value_provider_interface.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,35 @@ function (fn::Fix1Multiple)(args...)
231231
fn.f(fn.arg, args...)
232232
end
233233

234+
struct OOPSetter{I, D}
235+
indp::I
236+
idxs::D
237+
is_state::Bool
238+
end
239+
240+
function (os::OOPSetter)(valp, val)
241+
buffer = os.is_state ? state_values(valp) : parameter_values(valp)
242+
return remake_buffer(os.indp, buffer, (os.idxs,), (val,))
243+
end
244+
245+
function (os::OOPSetter)(valp, val::Union{Tuple, AbstractArray})
246+
buffer = os.is_state ? state_values(valp) : parameter_values(valp)
247+
if os.idxs isa Union{Tuple, AbstractArray}
248+
return remake_buffer(os.indp, buffer, os.idxs, val)
249+
else
250+
return remake_buffer(os.indp, buffer, (os.idxs,), (val,))
251+
end
252+
end
253+
254+
function _root_indp(indp)
255+
if hasmethod(symbolic_container, Tuple{typeof(indp)}) &&
256+
(sc = symbolic_container(indp)) != indp
257+
return _root_indp(sc)
258+
else
259+
return indp
260+
end
261+
end
262+
234263
###########
235264
# Errors
236265
###########
@@ -296,3 +325,16 @@ function Base.showerror(io::IO, err::MixedParameterTimeseriesIndexError)
296325
indexes $(err.ts_idxs).
297326
""")
298327
end
328+
329+
struct NotVariableOrParameter <: Exception
330+
fn::Any
331+
sym::Any
332+
end
333+
334+
function Base.showerror(io::IO, err::NotVariableOrParameter)
335+
print(
336+
io, """
337+
`$(err.fn)` requires that the symbolic variable(s) passed to it satisfy `is_variable`
338+
or `is_parameter`. Got `$(err.sym)` which is neither.
339+
""")
340+
end

test/state_indexing_test.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using SymbolicIndexingInterface
2+
using SymbolicIndexingInterface: NotVariableOrParameter
23

34
struct FakeIntegrator{S, U, P, T}
45
sys::S
@@ -62,6 +63,9 @@ for (sym, val, newval, check_inference) in [(:x, u[1], 4.0, true)
6263
set!(fi, newval)
6364
end
6465
@test get(fi) == newval
66+
67+
new_states = copy(state_values(fi))
68+
6569
set!(fi, val)
6670
@test get(fi) == val
6771

@@ -77,6 +81,15 @@ for (sym, val, newval, check_inference) in [(:x, u[1], 4.0, true)
7781
@test get(u) == newval
7882
set!(u, val)
7983
@test get(u) == val
84+
85+
if sym isa Union{Vector, Tuple} && any(x -> x isa Union{AbstractArray, Tuple}, sym)
86+
continue
87+
end
88+
89+
setter = setsym_oop(sys, sym)
90+
svals, pvals = setter(fi, newval)
91+
@test svals new_states
92+
@test pvals == parameter_values(fi)
8093
end
8194

8295
for (sym, val, check_inference) in [
@@ -123,8 +136,17 @@ for (sym, oldval, newval, check_inference) in [(:a, p[1], 4.0, true)
123136
set!(fi, newval)
124137
end
125138
@test get(fi) == newval
139+
140+
newu = copy(state_values(fi))
141+
newp = copy(parameter_values(fi))
142+
126143
set!(fi, oldval)
127144
@test get(fi) == oldval
145+
146+
oop_setter = setsym_oop(sys, sym)
147+
uvals, pvals = oop_setter(fi, newval)
148+
@test uvals newu
149+
@test pvals newp
128150
end
129151

130152
for (sym, val, check_inference) in [
@@ -137,6 +159,8 @@ for (sym, val, check_inference) in [
137159
@inferred get(fi)
138160
end
139161
@test get(fi) == val
162+
163+
@test_throws NotVariableOrParameter setsym_oop(fi, sym)
140164
end
141165

142166
struct FakeSolution{S, U, P, T}

0 commit comments

Comments
 (0)