From 610e13e485ebc400b7999d64791a2eb1ac5177a0 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Wed, 6 Aug 2025 15:23:38 -0400 Subject: [PATCH] Add nlstep_compile as a debug tool for turning off simplification This can be used in order to test correctness and differences with the nlstep interface as it should exactly match the NonlinearSolveAlg --- src/problems/odeproblem.jl | 4 ++-- src/systems/solver_nlprob.jl | 14 +++++++++----- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/src/problems/odeproblem.jl b/src/problems/odeproblem.jl index 68a38c95cf..da7fd10e31 100644 --- a/src/problems/odeproblem.jl +++ b/src/problems/odeproblem.jl @@ -3,7 +3,7 @@ t = nothing, eval_expression = false, eval_module = @__MODULE__, sparse = false, steady_state = false, checkbounds = false, sparsity = false, analytic = nothing, simplify = false, cse = true, initialization_data = nothing, expression = Val{false}, - check_compatibility = true, nlstep = false, kwargs...) where {iip, spec} + check_compatibility = true, nlstep = false, nlstep_compile = true, kwargs...) where {iip, spec} check_complete(sys, ODEFunction) check_compatibility && check_compatible_system(ODEFunction, sys) @@ -42,7 +42,7 @@ _M = concrete_massmatrix(M; sparse, u0) if nlstep - ode_nlstep = generate_ODENLStepData(sys, u0, p, M) + ode_nlstep = generate_ODENLStepData(sys, u0, p, M, nlstep_compile) else ode_nlstep = nothing end diff --git a/src/systems/solver_nlprob.jl b/src/systems/solver_nlprob.jl index 9fdc85eacd..04e04ed4dd 100644 --- a/src/systems/solver_nlprob.jl +++ b/src/systems/solver_nlprob.jl @@ -1,5 +1,5 @@ -function generate_ODENLStepData(sys::System, u0, p, mm = calculate_massmatrix(sys)) - nlsys, outer_tmp, inner_tmp = inner_nlsystem(sys, mm) +function generate_ODENLStepData(sys::System, u0, p, mm = calculate_massmatrix(sys), nlstep_compile::Bool = true) + nlsys, outer_tmp, inner_tmp = inner_nlsystem(sys, mm, nlstep_compile) state = ProblemState(; u = u0, p) op = Dict() op[ODE_GAMMA[1]] = one(eltype(u0)) @@ -35,7 +35,7 @@ function get_inner_tmp(n::Int) only(@parameters inner_tmpₘₜₖ[1:n]) end -function inner_nlsystem(sys::System, mm) +function inner_nlsystem(sys::System, mm, nlstep_compile::Bool) dvs = unknowns(sys) eqs = full_equations(sys) t = get_iv(sys) @@ -56,8 +56,12 @@ function inner_nlsystem(sys::System, mm) new_dvs = unknowns(sys) new_ps = [parameters(sys); [gamma1, gamma2, gamma3, c, inner_tmp, outer_tmp]] - nlsys = mtkcompile( - System(new_eqs, new_dvs, new_ps; name = :nlsys); split = is_split(sys)) + nlsys = System(new_eqs, new_dvs, new_ps; name = :nlsys) + nlsys = if nlstep_compile + mtkcompile(nlsys; split = is_split(sys)) + else + complete(nlsys; split = is_split(sys)) + end return nlsys, outer_tmp, inner_tmp end