diff --git a/src/state_indexing.jl b/src/state_indexing.jl index 6e1d3d0..fef8f50 100644 --- a/src/state_indexing.jl +++ b/src/state_indexing.jl @@ -407,19 +407,27 @@ end FullSetter(ssetter, psetter) = FullSetter(ssetter, psetter, nothing, nothing) function (fs::FullSetter)(valp, val) + check_both_state_and_parameter_provider(valp) + return fs.state_setter(valp, val[fs.state_split]), fs.param_setter(valp, val[fs.param_split]) end function (fs::FullSetter{Nothing})(valp, val) + check_both_state_and_parameter_provider(valp) + return state_values(valp), fs.param_setter(valp, val) end function (fs::(FullSetter{S, Nothing} where {S}))(valp, val) + check_both_state_and_parameter_provider(valp) + return fs.state_setter(valp, val), parameter_values(valp) end function (fs::(FullSetter{Nothing, Nothing}))(valp, val) + check_both_state_and_parameter_provider(valp) + return state_values(valp), parameter_values(valp) end diff --git a/src/value_provider_interface.jl b/src/value_provider_interface.jl index cd52ae1..e082d54 100644 --- a/src/value_provider_interface.jl +++ b/src/value_provider_interface.jl @@ -238,12 +238,20 @@ struct OOPSetter{I, D} end function (os::OOPSetter)(valp, val) - buffer = os.is_state ? state_values(valp) : parameter_values(valp) + buffer = if os.is_state + hasmethod(state_values, Tuple{typeof(valp)}) ? state_values(valp) : valp + else + hasmethod(parameter_values, Tuple{typeof(valp)}) ? parameter_values(valp) : valp + end return remake_buffer(os.indp, buffer, (os.idxs,), (val,)) end function (os::OOPSetter)(valp, val::Union{Tuple, AbstractArray}) - buffer = os.is_state ? state_values(valp) : parameter_values(valp) + buffer = if os.is_state + hasmethod(state_values, Tuple{typeof(valp)}) ? state_values(valp) : valp + else + hasmethod(parameter_values, Tuple{typeof(valp)}) ? parameter_values(valp) : valp + end if os.idxs isa Union{Tuple, AbstractArray} return remake_buffer(os.indp, buffer, os.idxs, val) else @@ -338,3 +346,20 @@ function Base.showerror(io::IO, err::NotVariableOrParameter) or `is_parameter`. Got `$(err.sym)` which is neither. """) end + +function MustBeBothStateAndParameterProviderError(missing_state::Bool) + ArgumentError(""" + A setter returned from `setsym_oop` must be called with a value provider that \ + contains both states and parameters. The given value provided does not \ + implement `$(missing_state ? "state_values" : "parameter_values")`. + """) +end + +function check_both_state_and_parameter_provider(valp) + if !hasmethod(state_values, Tuple{typeof(valp)}) || state_values(valp) === valp + throw(MustBeBothStateAndParameterProviderError(true)) + end + if !hasmethod(parameter_values, Tuple{typeof(valp)}) || parameter_values(valp) === valp + throw(MustBeBothStateAndParameterProviderError(false)) + end +end diff --git a/test/parameter_indexing_test.jl b/test/parameter_indexing_test.jl index 26eaaf5..7017b74 100644 --- a/test/parameter_indexing_test.jl +++ b/test/parameter_indexing_test.jl @@ -179,6 +179,8 @@ for sys in [ newp = setter(fi, val) getter = getp(sys, sym) @test getter(newp) == val + newp = setter(parameter_values(fi), val) + @test getter(newp) == val end end end diff --git a/test/state_indexing_test.jl b/test/state_indexing_test.jl index 4537276..e52276f 100644 --- a/test/state_indexing_test.jl +++ b/test/state_indexing_test.jl @@ -90,6 +90,8 @@ for (sym, val, newval, check_inference) in [(:x, u[1], 4.0, true) svals, pvals = setter(fi, newval) @test svals ≈ new_states @test pvals == parameter_values(fi) + @test_throws ArgumentError setter(state_values(fi), newval) + @test_throws ArgumentError setter(parameter_values(fi), newval) end for (sym, val, check_inference) in [ @@ -147,6 +149,8 @@ for (sym, oldval, newval, check_inference) in [(:a, p[1], 4.0, true) uvals, pvals = oop_setter(fi, newval) @test uvals ≈ newu @test pvals ≈ newp + @test_throws ArgumentError oop_setter(state_values(fi), newval) + @test_throws ArgumentError oop_setter(parameter_values(fi), newval) end for (sym, val, check_inference) in [