Skip to content

Commit aa6acce

Browse files
feat: use wrap_mtkparameters in build_explicit_observed_function
1 parent 9328836 commit aa6acce

File tree

1 file changed

+18
-4
lines changed

1 file changed

+18
-4
lines changed

src/systems/diffeqs/odesystem.jl

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -412,8 +412,13 @@ function build_explicit_observed_function(sys, ts;
412412
ts = [ts]
413413
end
414414
ts = unwrap.(ts)
415+
issplit = has_index_cache(sys) && get_index_cache(sys) !== nothing
415416
if is_dde(sys)
416-
ts = map(x -> delay_to_function(sys, x), ts)
417+
if issplit
418+
ts = map(x -> delay_to_function(sys, x; history_arg = issplit ? MTKPARAMETERS_ARG : DEFAULT_PARAMS_ARG), ts)
419+
else
420+
ts = map(x -> delay_to_function(sys, x), ts)
421+
end
417422
end
418423

419424
vars = Set()
@@ -491,7 +496,7 @@ function build_explicit_observed_function(sys, ts;
491496
for i in 1:maxidx
492497
eq = obs[i]
493498
if is_dde(sys)
494-
eq = delay_to_function(sys, eq)
499+
eq = delay_to_function(sys, eq; history_arg = issplit ? MTKPARAMETERS_ARG : DEFAULT_PARAMS_ARG)
495500
end
496501
lhs = eq.lhs
497502
rhs = eq.rhs
@@ -518,12 +523,14 @@ function build_explicit_observed_function(sys, ts;
518523
else
519524
dvs = (dvs,)
520525
end
526+
p_start = param_only ? 1 : (length(dvs) + 1)
521527
if inputs === nothing
522528
args = param_only ? [ps..., ivs...] : [dvs..., ps..., ivs...]
523529
else
524530
inputs = unwrap.(inputs)
525531
ipts = DestructuredArgs(inputs, inbounds = !checkbounds)
526532
args = param_only ? [ipts, ps..., ivs...] : [dvs..., ipts, ps..., ivs...]
533+
p_start += 1
527534
end
528535
pre = get_postprocess_fbody(sys)
529536

@@ -534,19 +541,26 @@ function build_explicit_observed_function(sys, ts;
534541
wrap_array_vars(sys, ts; ps = _ps, inputs) .∘
535542
wrap_parameter_dependencies(sys, isscalar)
536543
end
544+
mtkparams_wrapper = wrap_mtkparameters(sys, isscalar, p_start)
545+
if mtkparams_wrapper isa Tuple
546+
oop_mtkp_wrapper = mtkparams_wrapper[1]
547+
else
548+
oop_mtkp_wrapper = mtkparams_wrapper
549+
end
550+
537551
# Need to keep old method of building the function since it uses `output_type`,
538552
# which can't be provided to `build_function`
539553
oop_fn = Func(args, [],
540554
pre(Let(obsexprs,
541555
isscalar ? ts[1] : MakeArray(ts, output_type),
542-
false))) |> array_wrapper[1] |> toexpr
556+
false))) |> array_wrapper[1] |> oop_mtkp_wrapper |> toexpr
543557
oop_fn = expression ? oop_fn : eval_or_rgf(oop_fn; eval_expression, eval_module)
544558

545559
if !isscalar
546560
iip_fn = build_function(ts,
547561
args...;
548562
postprocess_fbody = pre,
549-
wrap_code = array_wrapper .∘ wrap_assignments(isscalar, obsexprs),
563+
wrap_code = array_wrapper .∘ wrap_assignments(isscalar, obsexprs) .∘ mtkparams_wrapper,
550564
expression = Val{true})[2]
551565
if !expression
552566
iip_fn = eval_or_rgf(iip_fn; eval_expression, eval_module)

0 commit comments

Comments
 (0)