Skip to content

Commit 9173f80

Browse files
fix: CheckInit bug fix
1 parent e0fade7 commit 9173f80

File tree

2 files changed

+44
-4
lines changed

2 files changed

+44
-4
lines changed

src/initialization.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,8 @@ Check if the algebraic constraints are satisfied, and error if they aren't. Retu
100100
the `u0` and `p` as-is, and is always successful if it returns. Valid only for
101101
`ODEProblem` and `DAEProblem`. Requires a `DEIntegrator` as its second argument.
102102
"""
103-
function get_initial_values(prob::AbstractODEProblem, integrator, f, alg::CheckInit,
103+
function get_initial_values(
104+
prob::AbstractDEProblem, integrator::DEIntegrator, f, alg::CheckInit,
104105
isinplace::Union{Val{true}, Val{false}}; kwargs...)
105106
u0 = state_values(integrator)
106107
p = parameter_values(integrator)
@@ -109,7 +110,7 @@ function get_initial_values(prob::AbstractODEProblem, integrator, f, alg::CheckI
109110

110111
algebraic_vars = [all(iszero, x) for x in eachcol(M)]
111112
algebraic_eqs = [all(iszero, x) for x in eachrow(M)]
112-
(iszero(algebraic_vars) || iszero(algebraic_eqs)) && return
113+
(iszero(algebraic_vars) || iszero(algebraic_eqs)) && return u0, p, true
113114
update_coefficients!(M, u0, p, t)
114115
tmp = _evaluate_f_ode(integrator, f, isinplace, u0, p, t)
115116
tmp .= ArrayInterface.restructure(tmp, algebraic_eqs .* _vec(tmp))
@@ -135,7 +136,8 @@ function _evaluate_f_dae(integrator, f, isinplace::Val{false}, args...)
135136
return f(args...)
136137
end
137138

138-
function get_initial_values(prob::AbstractDAEProblem, integrator, f, alg::CheckInit,
139+
function get_initial_values(
140+
prob::AbstractDAEProblem, integrator::DEIntegrator, f, alg::CheckInit,
139141
isinplace::Union{Val{true}, Val{false}}; kwargs...)
140142
u0 = state_values(integrator)
141143
p = parameter_values(integrator)

test/initialization.jl

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using OrdinaryDiffEq, NonlinearSolve, SymbolicIndexingInterface, Test
1+
using StochasticDiffEq, OrdinaryDiffEq, NonlinearSolve, SymbolicIndexingInterface, Test
22

33
@testset "CheckInit" begin
44
@testset "ODEProblem" begin
@@ -57,6 +57,44 @@ using OrdinaryDiffEq, NonlinearSolve, SymbolicIndexingInterface, Test
5757
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)))
5858
end
5959
end
60+
61+
@testset "SDEProblem" begin
62+
mm_A = [1 0 0; 0 1 0; 0 0 0]
63+
function sdef!(du, u, p, t)
64+
du[1] = u[1]
65+
du[2] = u[2]
66+
du[3] = u[1] + u[2] + u[3] - 1
67+
end
68+
function sdef(u, p, t)
69+
du = similar(u)
70+
sdef!(du, u, p, t)
71+
du
72+
end
73+
74+
function g!(du, u, p, t)
75+
@. du = 0.1
76+
end
77+
function g(u, p, t)
78+
du = similar(u)
79+
g!(du, u, p, t)
80+
du
81+
end
82+
iipfn = SDEFunction{true}(sdef!, g!; mass_matrix = mm_A)
83+
oopfn = SDEFunction{false}(sdef, g; mass_matrix = mm_A)
84+
85+
@testset "Inplace = $(SciMLBase.isinplace(f))" for f in [oopfn, iipfn]
86+
prob = SDEProblem(f, [1.0, 1.0, -1.0], (0.0, 1.0))
87+
integ = init(prob, ImplicitEM())
88+
u0, _, success = SciMLBase.get_initial_values(
89+
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)))
90+
@test success
91+
@test u0 == prob.u0
92+
93+
integ.u[2] = 2.0
94+
@test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values(
95+
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)))
96+
end
97+
end
6098
end
6199

62100
@testset "OverrideInit" begin

0 commit comments

Comments
 (0)