diff --git a/Project.toml b/Project.toml index b84a1259e3..fd27524de3 100644 --- a/Project.toml +++ b/Project.toml @@ -42,6 +42,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804" StaticArrayInterface = "0d7ed370-da01-4f52-bd93-41d350b8b718" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77" [compat] @@ -84,6 +85,7 @@ SparseArrays = "1.9" SparseDiffTools = "2.3" StaticArrayInterface = "1.2" StaticArrays = "1.0" +SymbolicIndexingInterface = "0.3.16" TruncatedStacktraces = "1.2" julia = "1.10" diff --git a/src/OrdinaryDiffEq.jl b/src/OrdinaryDiffEq.jl index 3a0506f30d..025ec50c46 100644 --- a/src/OrdinaryDiffEq.jl +++ b/src/OrdinaryDiffEq.jl @@ -64,6 +64,8 @@ using ExponentialUtilities using NonlinearSolve +using SymbolicIndexingInterface + # Required by temporary fix in not in-place methods with 12+ broadcasts # `MVector` is used by Nordsieck forms import StaticArrays: SArray, MVector, SVector, @SVector, StaticArray, MMatrix, SA diff --git a/src/initialize_dae.jl b/src/initialize_dae.jl index 9404d9531f..23f743d16f 100644 --- a/src/initialize_dae.jl +++ b/src/initialize_dae.jl @@ -71,6 +71,12 @@ function DiffEqBase.initialize_dae!(integrator::ODEIntegrator, Val(DiffEqBase.isinplace(integrator.sol.prob))) end +function DiffEqBase.initialize_dae!(integrator::DiffEqBase.NullODEIntegrator) + _initialize_dae!(integrator, integrator.sol.prob, + OverrideInit(), + Val(DiffEqBase.isinplace(integrator.sol.prob))) +end + ## Default algorithms function _initialize_dae!(integrator, prob::ODEProblem, @@ -134,19 +140,24 @@ end function _initialize_dae!(integrator, prob::Union{ODEProblem, DAEProblem}, alg::OverrideInit, isinplace::Union{Val{true}, Val{false}}) initializeprob = prob.f.initializeprob - + prob.f.initializeprob_init!(initializeprob, integrator) # If it doesn't have autodiff, assume it comes from symbolic system like ModelingToolkit # Since then it's the case of not a DAE but has initializeprob # In which case, it should be differentiable - isAD = has_autodiff(integrator.alg) ? alg_autodiff(integrator.alg) isa AutoForwardDiff : - true + isAD = if !isdefined(integrator, :alg) + false + elseif has_autodiff(integrator.alg) + alg_autodiff(integrator.alg) isa AutoForwardDiff + else + true + end alg = default_nlsolve(alg.nlsolve, isinplace, initializeprob.u0, initializeprob, isAD) nlsol = solve(initializeprob, alg) if isinplace === Val{true}() - integrator.u .= prob.f.initializeprobmap(nlsol) + prob.f.initializeprob_update!(integrator, nlsol) elseif isinplace === Val{false}() - integrator.u = prob.f.initializeprobmap(nlsol) + prob.f.initializeprob_update!(integrator, nlsol) else error("Unreachable reached. Report this error.") end