Skip to content

Commit 5e6b3ee

Browse files
feat: support observed function generation for DDEs
1 parent 9aadc71 commit 5e6b3ee

File tree

2 files changed

+34
-2
lines changed

2 files changed

+34
-2
lines changed

src/systems/abstractsystem.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -805,6 +805,10 @@ function SymbolicIndexingInterface.observed(
805805
return let _fn = _fn
806806
fn1(u, p, t) = _fn(u, p, t)
807807
fn1(u, p::MTKParameters, t) = _fn(u, p..., t)
808+
809+
# DDEs
810+
fn1(u, histfn, p, t) = _fn(u, histfn, p, t)
811+
fn1(u, histfn, p::MTKParameters, t) = _fn(u, histfn, p..., t)
808812
fn1
809813
end
810814
else

src/systems/diffeqs/odesystem.jl

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -475,13 +475,22 @@ function build_explicit_observed_function(sys, ts;
475475
end
476476
ts = map(t -> substitute(t, subs), ts)
477477
obsexprs = []
478+
479+
histfn_name = gensym(:hist)
480+
histfn = only(@variables ($histfn_name)(..)::Vector{Real})
481+
482+
delays = Set()
483+
_vars_util = Set()
478484
for i in 1:maxidx
479485
eq = obs[i]
480486
lhs = eq.lhs
481487
rhs = eq.rhs
488+
populate_delays(delays, obsexprs, histfn, sys, rhs)
482489
push!(obsexprs, lhs rhs)
483490
end
484491

492+
populate_delays(delays, obsexprs, histfn, sys, ts)
493+
485494
if inputs !== nothing
486495
ps = setdiff(ps, inputs) # Inputs have been converted to parameters by io_preprocessing, remove those from the parameter list
487496
end
@@ -497,12 +506,17 @@ function build_explicit_observed_function(sys, ts;
497506
ps = (DestructuredArgs(unwrap.(ps), inbounds = !checkbounds),)
498507
end
499508
dvs = DestructuredArgs(unknowns(sys), inbounds = !checkbounds)
509+
if !isempty(delays)
510+
dvs = (dvs, histfn)
511+
else
512+
dvs = (dvs,)
513+
end
500514
if inputs === nothing
501-
args = param_only ? [ps..., ivs...] : [dvs, ps..., ivs...]
515+
args = param_only ? [ps..., ivs...] : [dvs..., ps..., ivs...]
502516
else
503517
inputs = unwrap.(inputs)
504518
ipts = DestructuredArgs(inputs, inbounds = !checkbounds)
505-
args = param_only ? [ipts, ps..., ivs...] : [dvs, ipts, ps..., ivs...]
519+
args = param_only ? [ipts, ps..., ivs...] : [dvs..., ipts, ps..., ivs...]
506520
end
507521
pre = get_postprocess_fbody(sys)
508522

@@ -538,6 +552,20 @@ function build_explicit_observed_function(sys, ts;
538552
end
539553
end
540554

555+
function populate_delays(delays::Set, obsexprs, histfn, sys, sym)
556+
_vars_util = vars(sym)
557+
for v in _vars_util
558+
v in delays && continue
559+
iscall(v) && issym(operation(v)) && (args = arguments(v); length(args) == 1) &&
560+
iscall(only(args)) || continue
561+
562+
idx = variable_index(sys, operation(v)(get_iv(sys)))
563+
idx === nothing && error("Delay term $v is not an unknown in the system")
564+
push!(delays, v)
565+
push!(obsexprs, v histfn(only(args))[idx])
566+
end
567+
end
568+
541569
function _eq_unordered(a, b)
542570
length(a) === length(b) || return false
543571
n = length(a)

0 commit comments

Comments
 (0)