Skip to content

Commit a190a4c

Browse files
Merge pull request #116 from SciML/as/setp-oop-pobj
fix: handle cases for calling oop setter directly on state/parameter object
2 parents ec17278 + 28c6771 commit a190a4c

File tree

4 files changed

+41
-2
lines changed

4 files changed

+41
-2
lines changed

src/state_indexing.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,19 +407,27 @@ end
407407
FullSetter(ssetter, psetter) = FullSetter(ssetter, psetter, nothing, nothing)
408408

409409
function (fs::FullSetter)(valp, val)
410+
check_both_state_and_parameter_provider(valp)
411+
410412
return fs.state_setter(valp, val[fs.state_split]),
411413
fs.param_setter(valp, val[fs.param_split])
412414
end
413415

414416
function (fs::FullSetter{Nothing})(valp, val)
417+
check_both_state_and_parameter_provider(valp)
418+
415419
return state_values(valp), fs.param_setter(valp, val)
416420
end
417421

418422
function (fs::(FullSetter{S, Nothing} where {S}))(valp, val)
423+
check_both_state_and_parameter_provider(valp)
424+
419425
return fs.state_setter(valp, val), parameter_values(valp)
420426
end
421427

422428
function (fs::(FullSetter{Nothing, Nothing}))(valp, val)
429+
check_both_state_and_parameter_provider(valp)
430+
423431
return state_values(valp), parameter_values(valp)
424432
end
425433

src/value_provider_interface.jl

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -238,12 +238,20 @@ struct OOPSetter{I, D}
238238
end
239239

240240
function (os::OOPSetter)(valp, val)
241-
buffer = os.is_state ? state_values(valp) : parameter_values(valp)
241+
buffer = if os.is_state
242+
hasmethod(state_values, Tuple{typeof(valp)}) ? state_values(valp) : valp
243+
else
244+
hasmethod(parameter_values, Tuple{typeof(valp)}) ? parameter_values(valp) : valp
245+
end
242246
return remake_buffer(os.indp, buffer, (os.idxs,), (val,))
243247
end
244248

245249
function (os::OOPSetter)(valp, val::Union{Tuple, AbstractArray})
246-
buffer = os.is_state ? state_values(valp) : parameter_values(valp)
250+
buffer = if os.is_state
251+
hasmethod(state_values, Tuple{typeof(valp)}) ? state_values(valp) : valp
252+
else
253+
hasmethod(parameter_values, Tuple{typeof(valp)}) ? parameter_values(valp) : valp
254+
end
247255
if os.idxs isa Union{Tuple, AbstractArray}
248256
return remake_buffer(os.indp, buffer, os.idxs, val)
249257
else
@@ -338,3 +346,20 @@ function Base.showerror(io::IO, err::NotVariableOrParameter)
338346
or `is_parameter`. Got `$(err.sym)` which is neither.
339347
""")
340348
end
349+
350+
function MustBeBothStateAndParameterProviderError(missing_state::Bool)
351+
ArgumentError("""
352+
A setter returned from `setsym_oop` must be called with a value provider that \
353+
contains both states and parameters. The given value provided does not \
354+
implement `$(missing_state ? "state_values" : "parameter_values")`.
355+
""")
356+
end
357+
358+
function check_both_state_and_parameter_provider(valp)
359+
if !hasmethod(state_values, Tuple{typeof(valp)}) || state_values(valp) === valp
360+
throw(MustBeBothStateAndParameterProviderError(true))
361+
end
362+
if !hasmethod(parameter_values, Tuple{typeof(valp)}) || parameter_values(valp) === valp
363+
throw(MustBeBothStateAndParameterProviderError(false))
364+
end
365+
end

test/parameter_indexing_test.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,8 @@ for sys in [
179179
newp = setter(fi, val)
180180
getter = getp(sys, sym)
181181
@test getter(newp) == val
182+
newp = setter(parameter_values(fi), val)
183+
@test getter(newp) == val
182184
end
183185
end
184186
end

test/state_indexing_test.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,8 @@ for (sym, val, newval, check_inference) in [(:x, u[1], 4.0, true)
9090
svals, pvals = setter(fi, newval)
9191
@test svals new_states
9292
@test pvals == parameter_values(fi)
93+
@test_throws ArgumentError setter(state_values(fi), newval)
94+
@test_throws ArgumentError setter(parameter_values(fi), newval)
9395
end
9496

9597
for (sym, val, check_inference) in [
@@ -147,6 +149,8 @@ for (sym, oldval, newval, check_inference) in [(:a, p[1], 4.0, true)
147149
uvals, pvals = oop_setter(fi, newval)
148150
@test uvals newu
149151
@test pvals newp
152+
@test_throws ArgumentError oop_setter(state_values(fi), newval)
153+
@test_throws ArgumentError oop_setter(parameter_values(fi), newval)
150154
end
151155

152156
for (sym, val, check_inference) in [

0 commit comments

Comments
 (0)