Skip to content

Commit 6f59a71

Browse files
feat: support observed function generation for DDEs
1 parent 31e78ad commit 6f59a71

File tree

2 files changed

+32
-2
lines changed

2 files changed

+32
-2
lines changed

src/systems/abstractsystem.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -805,6 +805,10 @@ function SymbolicIndexingInterface.observed(
805805
return let _fn = _fn
806806
fn1(u, p, t) = _fn(u, p, t)
807807
fn1(u, p::MTKParameters, t) = _fn(u, p..., t)
808+
809+
# DDEs
810+
fn1(u, histfn, p, t) = _fn(u, histfn, p, t)
811+
fn1(u, histfn, p::MTKParameters, t) = _fn(u, histfn, p..., t)
808812
fn1
809813
end
810814
else

src/systems/diffeqs/odesystem.jl

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -494,9 +494,16 @@ function build_explicit_observed_function(sys, ts;
494494
end
495495
end
496496
ts = map(t -> substitute(t, subs), ts)
497+
if is_dde(sys)
498+
ts = map(x -> delay_to_function(sys, x), ts)
499+
end
497500
obsexprs = []
501+
498502
for i in 1:maxidx
499503
eq = obs[i]
504+
if is_dde(sys)
505+
eq = delay_to_function(sys, eq)
506+
end
500507
lhs = eq.lhs
501508
rhs = eq.rhs
502509
push!(obsexprs, lhs rhs)
@@ -517,12 +524,17 @@ function build_explicit_observed_function(sys, ts;
517524
ps = (DestructuredArgs(unwrap.(ps), inbounds = !checkbounds),)
518525
end
519526
dvs = DestructuredArgs(unknowns(sys), inbounds = !checkbounds)
527+
if !is_dde(sys)
528+
dvs = (dvs, DDE_HISTORY_FUN)
529+
else
530+
dvs = (dvs,)
531+
end
520532
if inputs === nothing
521-
args = param_only ? [ps..., ivs...] : [dvs, ps..., ivs...]
533+
args = param_only ? [ps..., ivs...] : [dvs..., ps..., ivs...]
522534
else
523535
inputs = unwrap.(inputs)
524536
ipts = DestructuredArgs(inputs, inbounds = !checkbounds)
525-
args = param_only ? [ipts, ps..., ivs...] : [dvs, ipts, ps..., ivs...]
537+
args = param_only ? [ipts, ps..., ivs...] : [dvs..., ipts, ps..., ivs...]
526538
end
527539
pre = get_postprocess_fbody(sys)
528540

@@ -558,6 +570,20 @@ function build_explicit_observed_function(sys, ts;
558570
end
559571
end
560572

573+
function populate_delays(delays::Set, obsexprs, histfn, sys, sym)
574+
_vars_util = vars(sym)
575+
for v in _vars_util
576+
v in delays && continue
577+
iscall(v) && issym(operation(v)) && (args = arguments(v); length(args) == 1) &&
578+
iscall(only(args)) || continue
579+
580+
idx = variable_index(sys, operation(v)(get_iv(sys)))
581+
idx === nothing && error("Delay term $v is not an unknown in the system")
582+
push!(delays, v)
583+
push!(obsexprs, v histfn(only(args))[idx])
584+
end
585+
end
586+
561587
function _eq_unordered(a, b)
562588
length(a) === length(b) || return false
563589
n = length(a)

0 commit comments

Comments
 (0)