Skip to content

Commit 28c6771

Browse files
fix: validate FullSetter is called on a state and parameter provider
1 parent 9fab503 commit 28c6771

File tree

3 files changed

+29
-0
lines changed

3 files changed

+29
-0
lines changed

src/state_indexing.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,19 +407,27 @@ end
407407
FullSetter(ssetter, psetter) = FullSetter(ssetter, psetter, nothing, nothing)
408408

409409
function (fs::FullSetter)(valp, val)
410+
check_both_state_and_parameter_provider(valp)
411+
410412
return fs.state_setter(valp, val[fs.state_split]),
411413
fs.param_setter(valp, val[fs.param_split])
412414
end
413415

414416
function (fs::FullSetter{Nothing})(valp, val)
417+
check_both_state_and_parameter_provider(valp)
418+
415419
return state_values(valp), fs.param_setter(valp, val)
416420
end
417421

418422
function (fs::(FullSetter{S, Nothing} where {S}))(valp, val)
423+
check_both_state_and_parameter_provider(valp)
424+
419425
return fs.state_setter(valp, val), parameter_values(valp)
420426
end
421427

422428
function (fs::(FullSetter{Nothing, Nothing}))(valp, val)
429+
check_both_state_and_parameter_provider(valp)
430+
423431
return state_values(valp), parameter_values(valp)
424432
end
425433

src/value_provider_interface.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,3 +346,20 @@ function Base.showerror(io::IO, err::NotVariableOrParameter)
346346
or `is_parameter`. Got `$(err.sym)` which is neither.
347347
""")
348348
end
349+
350+
function MustBeBothStateAndParameterProviderError(missing_state::Bool)
351+
ArgumentError("""
352+
A setter returned from `setsym_oop` must be called with a value provider that \
353+
contains both states and parameters. The given value provided does not \
354+
implement `$(missing_state ? "state_values" : "parameter_values")`.
355+
""")
356+
end
357+
358+
function check_both_state_and_parameter_provider(valp)
359+
if !hasmethod(state_values, Tuple{typeof(valp)}) || state_values(valp) === valp
360+
throw(MustBeBothStateAndParameterProviderError(true))
361+
end
362+
if !hasmethod(parameter_values, Tuple{typeof(valp)}) || parameter_values(valp) === valp
363+
throw(MustBeBothStateAndParameterProviderError(false))
364+
end
365+
end

test/state_indexing_test.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,8 @@ for (sym, val, newval, check_inference) in [(:x, u[1], 4.0, true)
9090
svals, pvals = setter(fi, newval)
9191
@test svals new_states
9292
@test pvals == parameter_values(fi)
93+
@test_throws ArgumentError setter(state_values(fi), newval)
94+
@test_throws ArgumentError setter(parameter_values(fi), newval)
9395
end
9496

9597
for (sym, val, check_inference) in [
@@ -147,6 +149,8 @@ for (sym, oldval, newval, check_inference) in [(:a, p[1], 4.0, true)
147149
uvals, pvals = oop_setter(fi, newval)
148150
@test uvals newu
149151
@test pvals newp
152+
@test_throws ArgumentError oop_setter(state_values(fi), newval)
153+
@test_throws ArgumentError oop_setter(parameter_values(fi), newval)
150154
end
151155

152156
for (sym, val, check_inference) in [

0 commit comments

Comments
 (0)