Skip to content

Commit 7f37ea3

Browse files
committed
fix: SDE has no tearing state
1 parent d2e952f commit 7f37ea3

File tree

3 files changed

+12
-18
lines changed

3 files changed

+12
-18
lines changed

src/systems/diffeqs/sdesystem.jl

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -642,20 +642,16 @@ function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns(
642642
_Wfact, _Wfact_t = nothing, nothing
643643
end
644644

645+
M = calculate_massmatrix(sys)
645646
if sparse
646647
uElType = u0 === nothing ? Float64 : eltype(u0)
647-
if jac
648-
jac_prototype = similar(calculate_jacobian(sys; sparse), uElType)
649-
else
650-
jac_prototype = similar(jacobian_sparsity(sys), uElType)
651-
end
652-
W_prototype = similar(W_sparsity(sys), uElType)
648+
jac_prototype = similar(calculate_jacobian(sys; sparse), uElType)
649+
W_prototype = similar(jac_prototype .+ M, uElType)
653650
else
654651
jac_prototype = nothing
655652
W_prototype = nothing
656653
end
657654

658-
M = calculate_massmatrix(sys)
659655
_M = (u0 === nothing || M == I) ? M : ArrayInterface.restructure(u0 .* u0', M)
660656

661657
observedfun = ObservedFunctionCache(
@@ -742,15 +738,14 @@ function SDEFunctionExpr{iip}(sys::SDESystem, dvs = unknowns(sys),
742738
_jac = :nothing
743739
end
744740

745-
jac_prototype = if sparse
741+
M = calculate_massmatrix(sys)
742+
if sparse
746743
uElType = u0 === nothing ? Float64 : eltype(u0)
747-
if jac
748-
similar(calculate_jacobian(sys, sparse = sparse), uElType)
749-
else
750-
similar(jacobian_sparsity(sys), uElType)
751-
end
744+
jac_prototype = similar(calculate_jacobian(sys; sparse), uElType)
745+
W_prototype = similar(jac_prototype .+ M, uElType)
752746
else
753-
nothing
747+
jac_prototype = nothing
748+
W_prototype = nothing
754749
end
755750

756751
if Wfact
@@ -763,8 +758,6 @@ function SDEFunctionExpr{iip}(sys::SDESystem, dvs = unknowns(sys),
763758
_Wfact, _Wfact_t = :nothing, :nothing
764759
end
765760

766-
M = calculate_massmatrix(sys)
767-
768761
_M = (u0 === nothing || M == I) ? M : ArrayInterface.restructure(u0 .* u0', M)
769762

770763
ex = quote

test/jacobiansparsity.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using ModelingToolkit, SparseArrays#, OrdinaryDiffEq
1+
using ModelingToolkit, SparseArrays, OrdinaryDiffEq
22

33
N = 3
44
xyd_brusselator = range(0, stop = 1, length = N)
@@ -100,6 +100,7 @@ prob = ODEProblem(sys, u0, (0, 11.5), sparse = true, jac = true)
100100
W_prototype = ModelingToolkit.W_sparsity(pend)
101101
@test nnz(W_prototype) == nnz(jac_prototype) + 2
102102

103+
# jac_prototype should be the same as W_prototype
103104
@test findnz(prob.f.jac_prototype)[1:2] == findnz(W_prototype)[1:2]
104105

105106
u = zeros(5)

test/nonlinearsystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ eqs = [0 ~ σ * (y - x) * h,
3030
@test eval(toexpr(ns)) == ns
3131
test_nlsys_inference("standard", ns, (x, y, z), (σ, ρ, β))
3232
@test begin
33-
f = eval(generate_function(ns, [x, y, z], [σ, ρ, β])[2])
33+
f = generate_function(ns, [x, y, z], [σ, ρ, β], expression = Val{false})[2]
3434
du = [0.0, 0.0, 0.0]
3535
f(du, [1, 2, 3], [1, 2, 3])
3636
du [1, -3, -7]

0 commit comments

Comments
 (0)