Skip to content

Commit ab639eb

Browse files
feat: use wrap_mtkparameters in build_explicit_observed_function
1 parent 0e50dcf commit ab639eb

File tree

1 file changed

+23
-4
lines changed

1 file changed

+23
-4
lines changed

src/systems/diffeqs/odesystem.jl

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -412,8 +412,16 @@ function build_explicit_observed_function(sys, ts;
412412
ts = [ts]
413413
end
414414
ts = unwrap.(ts)
415+
issplit = has_index_cache(sys) && get_index_cache(sys) !== nothing
415416
if is_dde(sys)
416-
ts = map(x -> delay_to_function(sys, x), ts)
417+
if issplit
418+
ts = map(
419+
x -> delay_to_function(
420+
sys, x; history_arg = issplit ? MTKPARAMETERS_ARG : DEFAULT_PARAMS_ARG),
421+
ts)
422+
else
423+
ts = map(x -> delay_to_function(sys, x), ts)
424+
end
417425
end
418426

419427
vars = Set()
@@ -491,7 +499,8 @@ function build_explicit_observed_function(sys, ts;
491499
for i in 1:maxidx
492500
eq = obs[i]
493501
if is_dde(sys)
494-
eq = delay_to_function(sys, eq)
502+
eq = delay_to_function(
503+
sys, eq; history_arg = issplit ? MTKPARAMETERS_ARG : DEFAULT_PARAMS_ARG)
495504
end
496505
lhs = eq.lhs
497506
rhs = eq.rhs
@@ -518,12 +527,14 @@ function build_explicit_observed_function(sys, ts;
518527
else
519528
dvs = (dvs,)
520529
end
530+
p_start = param_only ? 1 : (length(dvs) + 1)
521531
if inputs === nothing
522532
args = param_only ? [ps..., ivs...] : [dvs..., ps..., ivs...]
523533
else
524534
inputs = unwrap.(inputs)
525535
ipts = DestructuredArgs(inputs, inbounds = !checkbounds)
526536
args = param_only ? [ipts, ps..., ivs...] : [dvs..., ipts, ps..., ivs...]
537+
p_start += 1
527538
end
528539
pre = get_postprocess_fbody(sys)
529540

@@ -534,19 +545,27 @@ function build_explicit_observed_function(sys, ts;
534545
wrap_array_vars(sys, ts; ps = _ps, inputs) .∘
535546
wrap_parameter_dependencies(sys, isscalar)
536547
end
548+
mtkparams_wrapper = wrap_mtkparameters(sys, isscalar, p_start)
549+
if mtkparams_wrapper isa Tuple
550+
oop_mtkp_wrapper = mtkparams_wrapper[1]
551+
else
552+
oop_mtkp_wrapper = mtkparams_wrapper
553+
end
554+
537555
# Need to keep old method of building the function since it uses `output_type`,
538556
# which can't be provided to `build_function`
539557
oop_fn = Func(args, [],
540558
pre(Let(obsexprs,
541559
isscalar ? ts[1] : MakeArray(ts, output_type),
542-
false))) |> array_wrapper[1] |> toexpr
560+
false))) |> array_wrapper[1] |> oop_mtkp_wrapper |> toexpr
543561
oop_fn = expression ? oop_fn : eval_or_rgf(oop_fn; eval_expression, eval_module)
544562

545563
if !isscalar
546564
iip_fn = build_function(ts,
547565
args...;
548566
postprocess_fbody = pre,
549-
wrap_code = array_wrapper .∘ wrap_assignments(isscalar, obsexprs),
567+
wrap_code = array_wrapper .∘ wrap_assignments(isscalar, obsexprs) .∘
568+
mtkparams_wrapper,
550569
expression = Val{true})[2]
551570
if !expression
552571
iip_fn = eval_or_rgf(iip_fn; eval_expression, eval_module)

0 commit comments

Comments
 (0)