Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/state_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -407,19 +407,27 @@ end
FullSetter(ssetter, psetter) = FullSetter(ssetter, psetter, nothing, nothing)

function (fs::FullSetter)(valp, val)
check_both_state_and_parameter_provider(valp)

return fs.state_setter(valp, val[fs.state_split]),
fs.param_setter(valp, val[fs.param_split])
end

function (fs::FullSetter{Nothing})(valp, val)
check_both_state_and_parameter_provider(valp)

return state_values(valp), fs.param_setter(valp, val)
end

function (fs::(FullSetter{S, Nothing} where {S}))(valp, val)
check_both_state_and_parameter_provider(valp)

return fs.state_setter(valp, val), parameter_values(valp)
end

function (fs::(FullSetter{Nothing, Nothing}))(valp, val)
check_both_state_and_parameter_provider(valp)

return state_values(valp), parameter_values(valp)
end

Expand Down
29 changes: 27 additions & 2 deletions src/value_provider_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -238,12 +238,20 @@ struct OOPSetter{I, D}
end

function (os::OOPSetter)(valp, val)
buffer = os.is_state ? state_values(valp) : parameter_values(valp)
buffer = if os.is_state
hasmethod(state_values, Tuple{typeof(valp)}) ? state_values(valp) : valp
else
hasmethod(parameter_values, Tuple{typeof(valp)}) ? parameter_values(valp) : valp
end
return remake_buffer(os.indp, buffer, (os.idxs,), (val,))
end

function (os::OOPSetter)(valp, val::Union{Tuple, AbstractArray})
buffer = os.is_state ? state_values(valp) : parameter_values(valp)
buffer = if os.is_state
hasmethod(state_values, Tuple{typeof(valp)}) ? state_values(valp) : valp
else
hasmethod(parameter_values, Tuple{typeof(valp)}) ? parameter_values(valp) : valp
end
if os.idxs isa Union{Tuple, AbstractArray}
return remake_buffer(os.indp, buffer, os.idxs, val)
else
Expand Down Expand Up @@ -338,3 +346,20 @@ function Base.showerror(io::IO, err::NotVariableOrParameter)
or `is_parameter`. Got `$(err.sym)` which is neither.
""")
end

function MustBeBothStateAndParameterProviderError(missing_state::Bool)
ArgumentError("""
A setter returned from `setsym_oop` must be called with a value provider that \
contains both states and parameters. The given value provided does not \
implement `$(missing_state ? "state_values" : "parameter_values")`.
""")
end

function check_both_state_and_parameter_provider(valp)
if !hasmethod(state_values, Tuple{typeof(valp)}) || state_values(valp) === valp
throw(MustBeBothStateAndParameterProviderError(true))
end
if !hasmethod(parameter_values, Tuple{typeof(valp)}) || parameter_values(valp) === valp
throw(MustBeBothStateAndParameterProviderError(false))
end
end
2 changes: 2 additions & 0 deletions test/parameter_indexing_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,8 @@ for sys in [
newp = setter(fi, val)
getter = getp(sys, sym)
@test getter(newp) == val
newp = setter(parameter_values(fi), val)
@test getter(newp) == val
end
end
end
Expand Down
4 changes: 4 additions & 0 deletions test/state_indexing_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ for (sym, val, newval, check_inference) in [(:x, u[1], 4.0, true)
svals, pvals = setter(fi, newval)
@test svals ≈ new_states
@test pvals == parameter_values(fi)
@test_throws ArgumentError setter(state_values(fi), newval)
@test_throws ArgumentError setter(parameter_values(fi), newval)
end

for (sym, val, check_inference) in [
Expand Down Expand Up @@ -147,6 +149,8 @@ for (sym, oldval, newval, check_inference) in [(:a, p[1], 4.0, true)
uvals, pvals = oop_setter(fi, newval)
@test uvals ≈ newu
@test pvals ≈ newp
@test_throws ArgumentError oop_setter(state_values(fi), newval)
@test_throws ArgumentError oop_setter(parameter_values(fi), newval)
end

for (sym, val, check_inference) in [
Expand Down
Loading