diff --git a/Project.toml b/Project.toml index 002dbb023..ef2cdff37 100644 --- a/Project.toml +++ b/Project.toml @@ -88,7 +88,7 @@ StableRNGs = "1.0" StaticArrays = "1.7" StaticArraysCore = "1.4" Statistics = "1.10" -SymbolicIndexingInterface = "0.3.30" +SymbolicIndexingInterface = "0.3.31" Tables = "1.11" Zygote = "0.6.67" julia = "1.10" diff --git a/src/integrator_interface.jl b/src/integrator_interface.jl index 43d473924..f40026f50 100644 --- a/src/integrator_interface.jl +++ b/src/integrator_interface.jl @@ -383,6 +383,15 @@ function set_ut!(integrator::DEIntegrator, u, t) set_t!(integrator, t) end +""" + get_sol(integrator::DEIntegrator) + +Get the solution object contained in the integrator. +""" +function get_sol(integrator::DEIntegrator) + return integrator.sol +end + ### Addat isn't a real thing. Let's make it a real thing Gretchen function addat!(a::AbstractArray, idxs, val = nothing) @@ -901,3 +910,7 @@ Checks if the integrator is adaptive function isadaptive(integrator::DEIntegrator) isdefined(integrator.opts, :adaptive) ? integrator.opts.adaptive : false end + +function SymbolicIndexingInterface.get_history_function(integ::AbstractDDEIntegrator) + DDESolutionHistoryWrapper(get_sol(integ)) +end diff --git a/src/problems/dde_problems.jl b/src/problems/dde_problems.jl index 750083975..a93c1d320 100644 --- a/src/problems/dde_problems.jl +++ b/src/problems/dde_problems.jl @@ -259,6 +259,8 @@ function DDEProblem(f::AbstractDDEFunction, args...; kwargs...) DDEProblem{isinplace(f)}(f, args...; kwargs...) end +SymbolicIndexingInterface.get_history_function(prob::AbstractDDEProblem) = prob.h + """ $(TYPEDEF) """ diff --git a/src/solutions/ode_solutions.jl b/src/solutions/ode_solutions.jl index 94e76f0e8..9454902fe 100644 --- a/src/solutions/ode_solutions.jl +++ b/src/solutions/ode_solutions.jl @@ -354,6 +354,28 @@ function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, return DiffEqArray(u, t, p, sol; discretes) end +struct DDESolutionHistoryWrapper{T} + sol::T +end + +function (w::DDESolutionHistoryWrapper)(p, t; idxs = nothing) + w.sol(t; idxs) +end +function (w::DDESolutionHistoryWrapper)(out, p, t; idxs = nothing) + w.sol(out, t; idxs) +end +function (w::DDESolutionHistoryWrapper)(p, t, deriv::Type{Val{i}}; idxs = nothing) where {i} + w.sol(t, deriv; idxs) +end +function (w::DDESolutionHistoryWrapper)( + out, p, t, deriv::Type{Val{i}}; idxs = nothing) where {i} + w.sol(out, t, deriv; idxs) +end + +function SymbolicIndexingInterface.get_history_function(sol::ODESolution) + DDESolutionHistoryWrapper(sol) +end + # public API, used by MTK """ create_parameter_timeseries_collection(sys, ps, tspan) diff --git a/test/downstream/Project.toml b/test/downstream/Project.toml index 86cf7681f..4a883e985 100644 --- a/test/downstream/Project.toml +++ b/test/downstream/Project.toml @@ -1,5 +1,6 @@ [deps] BoundaryValueDiffEq = "764a87c0-6b3e-53db-9096-fe964310641d" +DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb" DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JumpProcesses = "ccbc3e58-028d-4f4c-8cd5-9ae44345cda5" @@ -26,6 +27,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] BoundaryValueDiffEq = "5" +DelayDiffEq = "5" DiffEqCallbacks = "3, 4" ForwardDiff = "0.10" JumpProcesses = "9.10" diff --git a/test/downstream/comprehensive_indexing.jl b/test/downstream/comprehensive_indexing.jl index 0f174b03f..69db72d15 100644 --- a/test/downstream/comprehensive_indexing.jl +++ b/test/downstream/comprehensive_indexing.jl @@ -1,6 +1,6 @@ using ModelingToolkit, JumpProcesses, LinearAlgebra, NonlinearSolve, Optimization, OptimizationOptimJL, OrdinaryDiffEq, RecursiveArrayTools, SciMLBase, - SteadyStateDiffEq, StochasticDiffEq, SymbolicIndexingInterface, + SteadyStateDiffEq, StochasticDiffEq, DelayDiffEq, SymbolicIndexingInterface, DiffEqCallbacks, Test, Plots using ModelingToolkit: t_nounits as t, D_nounits as D @@ -922,3 +922,52 @@ end @test_throws ErrorException sol(-0.1; idxs = sys.c) @test_throws ErrorException sol(-0.1; idxs = [sys.x, sys.x + sys.c]) end + +@testset "DDEs" begin + function oscillator(; name, k = 1.0, τ = 0.01) + @parameters k=k τ=τ + @variables x(..)=0.1 y(t)=0.1 jcn(t)=0.0 delx(t) + eqs = [D(x(t)) ~ y, + D(y) ~ -k * x(t - τ) + jcn, + delx ~ x(t - τ)] + return System(eqs, t; name = name) + end + systems = @named begin + osc1 = oscillator(k = 1.0, τ = 0.01) + osc2 = oscillator(k = 2.0, τ = 0.04) + end + eqs = [osc1.jcn ~ osc2.delx, + osc2.jcn ~ osc1.delx] + @named coupledOsc = System(eqs, t) + @named coupledOsc = compose(coupledOsc, systems) + sys = structural_simplify(coupledOsc) + prob = DDEProblem(sys, [], (0.0, 10.0); constant_lags = [sys.osc1.τ, sys.osc2.τ]) + # TODO: Remove this hack once MTK can generate appropriate observed functions + fn = prob.f + function fake_observed(_) + return function obsfn(u, h, p, t) + return u + h(p, t - 0.1) + end + end + + struct NonMarkovianWrapper{S} + sys::S + end + SymbolicIndexingInterface.symbolic_container(x::NonMarkovianWrapper) = x.sys + SymbolicIndexingInterface.is_markovian(::NonMarkovianWrapper) = false + fn = DDEFunction(fn.f; observed = fake_observed, sys = NonMarkovianWrapper(fn.sys)) + function fake_hist(p, t) + return ones(length(prob.u0)) .* t + end + prob = DDEProblem( + fn, prob.u0, fake_hist, prob.tspan, prob.p; constant_lags = prob.constant_lags) + sym = sys.osc1.delx + @test prob[sym] ≈ prob.u0 .+ (prob.tspan[1] - 0.1) + integ = init(prob, MethodOfSteps(Tsit5())) + step!(integ, 10.0, true) + # DelayDiffEq wraps `integ.f` and that doesn't contain `.observed` + # so the hack above doesn't work. `@reset` also doesn't work. + @test_broken integ[sym] ≈ integ.u + SciMLBase.get_sol(integ)(9.9) + sol = solve(prob, MethodOfSteps(Tsit5())) + @test sol[sym] ≈ sol.u .+ sol(sol.t .- 0.1).u +end