Skip to content

Commit 264cc3a

Browse files
Merge pull request #788 from AayushSabharwal/as/ddesol-obs
feat: add specialized observed for solutions of `DDEProblem`s
2 parents f9232dd + a2433a5 commit 264cc3a

File tree

6 files changed

+90
-2
lines changed

6 files changed

+90
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ StableRNGs = "1.0"
8888
StaticArrays = "1.7"
8989
StaticArraysCore = "1.4"
9090
Statistics = "1.10"
91-
SymbolicIndexingInterface = "0.3.30"
91+
SymbolicIndexingInterface = "0.3.31"
9292
Tables = "1.11"
9393
Zygote = "0.6.67"
9494
julia = "1.10"

src/integrator_interface.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,15 @@ function set_ut!(integrator::DEIntegrator, u, t)
383383
set_t!(integrator, t)
384384
end
385385

386+
"""
387+
get_sol(integrator::DEIntegrator)
388+
389+
Get the solution object contained in the integrator.
390+
"""
391+
function get_sol(integrator::DEIntegrator)
392+
return integrator.sol
393+
end
394+
386395
### Addat isn't a real thing. Let's make it a real thing Gretchen
387396

388397
function addat!(a::AbstractArray, idxs, val = nothing)
@@ -901,3 +910,7 @@ Checks if the integrator is adaptive
901910
function isadaptive(integrator::DEIntegrator)
902911
isdefined(integrator.opts, :adaptive) ? integrator.opts.adaptive : false
903912
end
913+
914+
function SymbolicIndexingInterface.get_history_function(integ::AbstractDDEIntegrator)
915+
DDESolutionHistoryWrapper(get_sol(integ))
916+
end

src/problems/dde_problems.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,8 @@ function DDEProblem(f::AbstractDDEFunction, args...; kwargs...)
259259
DDEProblem{isinplace(f)}(f, args...; kwargs...)
260260
end
261261

262+
SymbolicIndexingInterface.get_history_function(prob::AbstractDDEProblem) = prob.h
263+
262264
"""
263265
$(TYPEDEF)
264266
"""

src/solutions/ode_solutions.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,28 @@ function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv},
354354
return DiffEqArray(u, t, p, sol; discretes)
355355
end
356356

357+
struct DDESolutionHistoryWrapper{T}
358+
sol::T
359+
end
360+
361+
function (w::DDESolutionHistoryWrapper)(p, t; idxs = nothing)
362+
w.sol(t; idxs)
363+
end
364+
function (w::DDESolutionHistoryWrapper)(out, p, t; idxs = nothing)
365+
w.sol(out, t; idxs)
366+
end
367+
function (w::DDESolutionHistoryWrapper)(p, t, deriv::Type{Val{i}}; idxs = nothing) where {i}
368+
w.sol(t, deriv; idxs)
369+
end
370+
function (w::DDESolutionHistoryWrapper)(
371+
out, p, t, deriv::Type{Val{i}}; idxs = nothing) where {i}
372+
w.sol(out, t, deriv; idxs)
373+
end
374+
375+
function SymbolicIndexingInterface.get_history_function(sol::ODESolution)
376+
DDESolutionHistoryWrapper(sol)
377+
end
378+
357379
# public API, used by MTK
358380
"""
359381
create_parameter_timeseries_collection(sys, ps, tspan)

test/downstream/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
[deps]
22
BoundaryValueDiffEq = "764a87c0-6b3e-53db-9096-fe964310641d"
3+
DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb"
34
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
45
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
56
JumpProcesses = "ccbc3e58-028d-4f4c-8cd5-9ae44345cda5"
@@ -26,6 +27,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2627

2728
[compat]
2829
BoundaryValueDiffEq = "5"
30+
DelayDiffEq = "5"
2931
DiffEqCallbacks = "3, 4"
3032
ForwardDiff = "0.10"
3133
JumpProcesses = "9.10"

test/downstream/comprehensive_indexing.jl

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using ModelingToolkit, JumpProcesses, LinearAlgebra, NonlinearSolve, Optimization,
22
OptimizationOptimJL, OrdinaryDiffEq, RecursiveArrayTools, SciMLBase,
3-
SteadyStateDiffEq, StochasticDiffEq, SymbolicIndexingInterface,
3+
SteadyStateDiffEq, StochasticDiffEq, DelayDiffEq, SymbolicIndexingInterface,
44
DiffEqCallbacks, Test, Plots
55
using ModelingToolkit: t_nounits as t, D_nounits as D
66

@@ -922,3 +922,52 @@ end
922922
@test_throws ErrorException sol(-0.1; idxs = sys.c)
923923
@test_throws ErrorException sol(-0.1; idxs = [sys.x, sys.x + sys.c])
924924
end
925+
926+
@testset "DDEs" begin
927+
function oscillator(; name, k = 1.0, τ = 0.01)
928+
@parameters k=k τ=τ
929+
@variables x(..)=0.1 y(t)=0.1 jcn(t)=0.0 delx(t)
930+
eqs = [D(x(t)) ~ y,
931+
D(y) ~ -k * x(t - τ) + jcn,
932+
delx ~ x(t - τ)]
933+
return System(eqs, t; name = name)
934+
end
935+
systems = @named begin
936+
osc1 = oscillator(k = 1.0, τ = 0.01)
937+
osc2 = oscillator(k = 2.0, τ = 0.04)
938+
end
939+
eqs = [osc1.jcn ~ osc2.delx,
940+
osc2.jcn ~ osc1.delx]
941+
@named coupledOsc = System(eqs, t)
942+
@named coupledOsc = compose(coupledOsc, systems)
943+
sys = structural_simplify(coupledOsc)
944+
prob = DDEProblem(sys, [], (0.0, 10.0); constant_lags = [sys.osc1.τ, sys.osc2.τ])
945+
# TODO: Remove this hack once MTK can generate appropriate observed functions
946+
fn = prob.f
947+
function fake_observed(_)
948+
return function obsfn(u, h, p, t)
949+
return u + h(p, t - 0.1)
950+
end
951+
end
952+
953+
struct NonMarkovianWrapper{S}
954+
sys::S
955+
end
956+
SymbolicIndexingInterface.symbolic_container(x::NonMarkovianWrapper) = x.sys
957+
SymbolicIndexingInterface.is_markovian(::NonMarkovianWrapper) = false
958+
fn = DDEFunction(fn.f; observed = fake_observed, sys = NonMarkovianWrapper(fn.sys))
959+
function fake_hist(p, t)
960+
return ones(length(prob.u0)) .* t
961+
end
962+
prob = DDEProblem(
963+
fn, prob.u0, fake_hist, prob.tspan, prob.p; constant_lags = prob.constant_lags)
964+
sym = sys.osc1.delx
965+
@test prob[sym] prob.u0 .+ (prob.tspan[1] - 0.1)
966+
integ = init(prob, MethodOfSteps(Tsit5()))
967+
step!(integ, 10.0, true)
968+
# DelayDiffEq wraps `integ.f` and that doesn't contain `.observed`
969+
# so the hack above doesn't work. `@reset` also doesn't work.
970+
@test_broken integ[sym] integ.u + SciMLBase.get_sol(integ)(9.9)
971+
sol = solve(prob, MethodOfSteps(Tsit5()))
972+
@test sol[sym] sol.u .+ sol(sol.t .- 0.1).u
973+
end

0 commit comments

Comments
 (0)