Skip to content

Commit 8211ddc

Browse files
Merge pull request #3223 from AayushSabharwal/as/sde-mtkize
fix: set defaults in `modelingtoolkitize(::SDEProblem)`
2 parents f284c3f + a401772 commit 8211ddc

File tree

2 files changed

+43
-1
lines changed

2 files changed

+43
-1
lines changed

src/systems/diffeqs/modelingtoolkitize.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,10 +277,24 @@ function modelingtoolkitize(prob::DiffEqBase.SDEProblem; kwargs...)
277277
else
278278
Vector(vec(params))
279279
end
280+
sts = Vector(vec(vars))
281+
default_u0 = Dict(sts .=> vec(collect(prob.u0)))
282+
default_p = if has_p
283+
if prob.p isa AbstractDict
284+
Dict(v => prob.p[k] for (k, v) in pairs(_params))
285+
elseif prob.p isa MTKParameters
286+
Dict(params .=> reduce(vcat, prob.p))
287+
else
288+
Dict(params .=> vec(collect(prob.p)))
289+
end
290+
else
291+
Dict()
292+
end
280293

281-
de = SDESystem(deqs, neqs, t, Vector(vec(vars)), params;
294+
de = SDESystem(deqs, neqs, t, sts, params;
282295
name = gensym(:MTKizedSDE),
283296
tspan = prob.tspan,
297+
defaults = merge(default_u0, default_p),
284298
kwargs...)
285299

286300
de

test/modelingtoolkitize.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,3 +441,31 @@ prob = NonlinearLeastSquaresProblem(
441441
sys = modelingtoolkitize(prob)
442442
@test length(equations(sys)) == 3
443443
@test length(equations(structural_simplify(sys; fully_determined = false))) == 0
444+
445+
@testset "`modelingtoolkitize(::SDEProblem)` sets defaults" begin
446+
function sdeg!(du, u, p, t)
447+
du[1] = 0.3 * u[1]
448+
du[2] = 0.3 * u[2]
449+
du[3] = 0.3 * u[3]
450+
end
451+
function sdef!(du, u, p, t)
452+
x, y, z = u
453+
sigma, rho, beta = p
454+
du[1] = sigma * (y - x)
455+
du[2] = x * (rho - z) - y
456+
du[3] = x * y - beta * z
457+
end
458+
u0 = [1.0, 0.0, 0.0]
459+
tspan = (0.0, 100.0)
460+
p = [10.0, 28.0, 2.66]
461+
sprob = SDEProblem(sdef!, sdeg!, u0, tspan, p)
462+
sys = complete(modelingtoolkitize(sprob))
463+
@test length(ModelingToolkit.defaults(sys)) == 6
464+
sprob2 = SDEProblem(sys, [], tspan)
465+
466+
truevals = similar(u0)
467+
sprob.f(truevals, u0, p, tspan[1])
468+
mtkvals = similar(u0)
469+
sprob2.f(mtkvals, sprob2.u0, sprob2.p, tspan[1])
470+
@test mtkvals truevals
471+
end

0 commit comments

Comments
 (0)