Skip to content

Commit b60c152

Browse files
Merge pull request #3148 from AayushSabharwal/as/dde-u0-ctor
fix: fix `u0_constructor` for `DDEProblem`/`SDDEProblem`
2 parents c8ac522 + 93c1e8f commit b60c152

File tree

2 files changed

+17
-1
lines changed

2 files changed

+17
-1
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -879,6 +879,7 @@ function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [],
879879
check_length = true,
880880
eval_expression = false,
881881
eval_module = @__MODULE__,
882+
u0_constructor = identity,
882883
kwargs...) where {iip}
883884
if !iscomplete(sys)
884885
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `DDEProblem`")
@@ -892,6 +893,9 @@ function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [],
892893
h(p, t) = h_oop(p, t)
893894
h(p::MTKParameters, t) = h_oop(p..., t)
894895
u0 = h(p, tspan[1])
896+
if u0 !== nothing
897+
u0 = u0_constructor(u0)
898+
end
895899

896900
cbs = process_events(sys; callback, eval_expression, eval_module, kwargs...)
897901
kwargs = filter_kwargs(kwargs)
@@ -914,6 +918,7 @@ function DiffEqBase.SDDEProblem{iip}(sys::AbstractODESystem, u0map = [],
914918
sparsenoise = nothing,
915919
eval_expression = false,
916920
eval_module = @__MODULE__,
921+
u0_constructor = identity,
917922
kwargs...) where {iip}
918923
if !iscomplete(sys)
919924
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `SDDEProblem`")
@@ -929,6 +934,9 @@ function DiffEqBase.SDDEProblem{iip}(sys::AbstractODESystem, u0map = [],
929934
h(p::MTKParameters, t) = h_oop(p..., t)
930935
h(out, p::MTKParameters, t) = h_iip(out, p..., t)
931936
u0 = h(p, tspan[1])
937+
if u0 !== nothing
938+
u0 = u0_constructor(u0)
939+
end
932940

933941
cbs = process_events(sys; callback, eval_expression, eval_module, kwargs...)
934942
kwargs = filter_kwargs(kwargs)

test/dde.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using ModelingToolkit, DelayDiffEq, Test
1+
using ModelingToolkit, DelayDiffEq, StaticArrays, Test
22
using SymbolicIndexingInterface: is_markovian
33
using ModelingToolkit: t_nounits as t, D_nounits as D
44

@@ -89,6 +89,10 @@ eqs = [D(x(t)) ~ a * x(t) + b * x(t - τ) + c + (α * x(t) + γ) * η]
8989
prob_mtk = SDDEProblem(sys, [x(t) => 1.0 + t], tspan; constant_lags = (τ,));
9090
@test_nowarn sol_mtk = solve(prob_mtk, RKMil())
9191

92+
prob_sa = SDDEProblem(
93+
sys, [x(t) => 1.0 + t], tspan; constant_lags = (τ,), u0_constructor = SVector{1})
94+
@test prob_sa.u0 isa SVector{4, Float64}
95+
9296
@parameters x(..) a
9397

9498
function oscillator(; name, k = 1.0, τ = 0.01)
@@ -126,6 +130,10 @@ obsfn = ModelingToolkit.build_explicit_observed_function(
126130
@test_nowarn sol[[sys.osc1.delx, sys.osc2.delx]]
127131
@test sol[sys.osc1.delx] sol(sol.t .- 0.01; idxs = sys.osc1.x).u
128132

133+
prob_sa = DDEProblem(sys, [], (0.0, 10.0); constant_lags = [sys.osc1.τ, sys.osc2.τ],
134+
u0_constructor = SVector{4})
135+
@test prob_sa.u0 isa SVector{4, Float64}
136+
129137
@testset "DDE observed with array variables" begin
130138
@component function valve(; name)
131139
@parameters begin

0 commit comments

Comments
 (0)