Skip to content

Commit fc54236

Browse files
Merge pull request #93 from SciML/as/param-obs-no-t
feat: allow calling parameter observed functions with parameter object
2 parents 58e0df3 + 680276e commit fc54236

File tree

2 files changed

+61
-4
lines changed

2 files changed

+61
-4
lines changed

src/parameter_indexing.jl

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -208,11 +208,24 @@ for argType in [Union{Int, CartesianIndex}, Colon, AbstractArray{Bool}, Any]
208208
end
209209

210210
function (gpo::GetParameterObserved{<:Vector})(::NotTimeseries, prob)
211-
gpo.obsfn(parameter_values(prob), current_time(prob))
211+
# if the method doesn't exist or is an identity function, then `prob` itself
212+
# is the parameter object, so use that and pass `nothing` for the time expecting
213+
# it to not be used
214+
if hasmethod(parameter_values, Tuple{typeof(prob)}) &&
215+
(ps = parameter_values(prob)) != prob
216+
gpo.obsfn(ps, current_time(prob))
217+
else
218+
gpo.obsfn(prob, nothing)
219+
end
212220
end
213221
function (gpo::GetParameterObserved{<:Vector, true})(
214222
buffer::AbstractArray, ::NotTimeseries, prob)
215-
gpo.obsfn(buffer, parameter_values(prob), current_time(prob))
223+
if hasmethod(parameter_values, Tuple{typeof(prob)}) &&
224+
(ps = parameter_values(prob)) != prob
225+
gpo.obsfn(buffer, ps, current_time(prob))
226+
else
227+
gpo.obsfn(buffer, prob, nothing)
228+
end
216229
end
217230
function (gpo::GetParameterObserved{<:Vector})(::Timeseries, prob)
218231
throw(MixedParameterTimeseriesIndexError(prob, indexer_timeseries_index(gpo)))
@@ -224,10 +237,20 @@ function (gpo::GetParameterObserved{<:Vector, false})(::AbstractArray, ::Timeser
224237
throw(MixedParameterTimeseriesIndexError(prob, indexer_timeseries_index(gpo)))
225238
end
226239
function (gpo::GetParameterObserved)(::NotTimeseries, prob)
227-
gpo.obsfn(parameter_values(prob), current_time(prob))
240+
if hasmethod(parameter_values, Tuple{typeof(prob)}) &&
241+
(ps = parameter_values(prob)) != prob
242+
gpo.obsfn(ps, current_time(prob))
243+
else
244+
gpo.obsfn(prob, nothing)
245+
end
228246
end
229247
function (gpo::GetParameterObserved)(buffer::AbstractArray, ::NotTimeseries, prob)
230-
gpo.obsfn(buffer, parameter_values(prob), current_time(prob))
248+
if hasmethod(parameter_values, Tuple{typeof(prob)}) &&
249+
(ps = parameter_values(prob)) != prob
250+
gpo.obsfn(buffer, ps, current_time(prob))
251+
else
252+
gpo.obsfn(buffer, prob, nothing)
253+
end
231254
return buffer
232255
end
233256
function (gpo::GetParameterObserved)(::Timeseries, prob)

test/parameter_indexing_test.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,40 @@ for sys in [
130130
end
131131
end
132132

133+
for (sym, val, check_inference) in [
134+
(:(a + b), p[1] + p[2], true),
135+
([:(a + b), :(a * b)], [p[1] + p[2], p[1] * p[2]], true),
136+
((:(a + b), :(a * b)), (p[1] + p[2], p[1] * p[2]), true),
137+
([:(a + c), :(a + b)], [p[1] + p[3], p[1] + p[2]], true)
138+
]
139+
get = getp(sys, sym)
140+
if check_inference
141+
@inferred get(parameter_values(fi))
142+
end
143+
@test get(parameter_values(fi)) == val
144+
if sym isa Union{Array, Tuple}
145+
buffer = zeros(length(sym))
146+
if check_inference
147+
@inferred get(buffer, parameter_values(fi))
148+
else
149+
get(buffer, parameter_values(fi))
150+
end
151+
@test buffer == collect(val)
152+
end
153+
end
154+
155+
for sym in [
156+
:(a + t),
157+
[:(a + t), :(a * b)],
158+
(:(a + t), :(a * b))
159+
]
160+
get = getp(sys, sym)
161+
@test_throws MethodError get(parameter_values(fi))
162+
if sym isa Union{Array, Tuple}
163+
@test_throws MethodError get(zeros(length(sym)), parameter_values(fi))
164+
end
165+
end
166+
133167
getter = getp(sys, [])
134168
@test getter(fi) == []
135169
getter = getp(sys, ())

0 commit comments

Comments
 (0)