Skip to content

Commit ae95c7d

Browse files
feat: allow passing wrap_delays to build_explicit_observed_function
1 parent 8cc641f commit ae95c7d

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

src/systems/diffeqs/odesystem.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -476,6 +476,8 @@ Generates a function that computes the observed value(s) `ts` in the system `sys
476476
- `throw = true` if true, throw an error when generating a function for `ts` that reference variables that do not exist.
477477
- `mkarray`: only used if the output is an array (that is, `!isscalar(ts)` and `ts` is not a tuple, in which case the result will always be a tuple). Called as `mkarray(ts, output_type)` where `ts` are the expressions to put in the array and `output_type` is the argument of the same name passed to build_explicit_observed_function.
478478
- `cse = true`: Whether to use Common Subexpression Elimination (CSE) to generate a more efficient function.
479+
- `wrap_delays = is_dde(sys)`: Whether to add an argument for the history function and use
480+
it to calculate all delayed variables.
479481
480482
## Returns
481483
@@ -514,7 +516,8 @@ function build_explicit_observed_function(sys, ts;
514516
op = Operator,
515517
throw = true,
516518
cse = true,
517-
mkarray = nothing)
519+
mkarray = nothing,
520+
wrap_delays = is_dde(sys))
518521
is_tuple = ts isa Tuple
519522
if is_tuple
520523
ts = collect(ts)
@@ -600,14 +603,15 @@ function build_explicit_observed_function(sys, ts;
600603
p_end = length(dvs) + length(inputs) + length(ps)
601604
fns = build_function_wrapper(
602605
sys, ts, args...; p_start, p_end, filter_observed = obsfilter,
603-
output_type, mkarray, try_namespaced = true, expression = Val{true}, cse)
606+
output_type, mkarray, try_namespaced = true, expression = Val{true}, cse,
607+
wrap_delays)
604608
if fns isa Tuple
605609
if expression
606610
return return_inplace ? fns : fns[1]
607611
end
608612
oop, iip = eval_or_rgf.(fns; eval_expression, eval_module)
609613
f = GeneratedFunctionWrapper{(
610-
p_start + is_dde(sys), length(args) - length(ps) + 1 + is_dde(sys), is_split(sys))}(
614+
p_start + wrap_delays, length(args) - length(ps) + 1 + wrap_delays, is_split(sys))}(
611615
oop, iip)
612616
return return_inplace ? (f, f) : f
613617
else
@@ -616,7 +620,7 @@ function build_explicit_observed_function(sys, ts;
616620
end
617621
f = eval_or_rgf(fns; eval_expression, eval_module)
618622
f = GeneratedFunctionWrapper{(
619-
p_start + is_dde(sys), length(args) - length(ps) + 1 + is_dde(sys), is_split(sys))}(
623+
p_start + wrap_delays, length(args) - length(ps) + 1 + wrap_delays, is_split(sys))}(
620624
f, nothing)
621625
return f
622626
end

0 commit comments

Comments
 (0)