Skip to content

Commit 610e13e

Browse files
Add nlstep_compile as a debug tool for turning off simplification
This can be used in order to test correctness and differences with the nlstep interface as it should exactly match the NonlinearSolveAlg
1 parent 646bace commit 610e13e

File tree

2 files changed

+11
-7
lines changed

2 files changed

+11
-7
lines changed

src/problems/odeproblem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
t = nothing, eval_expression = false, eval_module = @__MODULE__, sparse = false,
44
steady_state = false, checkbounds = false, sparsity = false, analytic = nothing,
55
simplify = false, cse = true, initialization_data = nothing, expression = Val{false},
6-
check_compatibility = true, nlstep = false, kwargs...) where {iip, spec}
6+
check_compatibility = true, nlstep = false, nlstep_compile = true, kwargs...) where {iip, spec}
77
check_complete(sys, ODEFunction)
88
check_compatibility && check_compatible_system(ODEFunction, sys)
99

@@ -42,7 +42,7 @@
4242
_M = concrete_massmatrix(M; sparse, u0)
4343

4444
if nlstep
45-
ode_nlstep = generate_ODENLStepData(sys, u0, p, M)
45+
ode_nlstep = generate_ODENLStepData(sys, u0, p, M, nlstep_compile)
4646
else
4747
ode_nlstep = nothing
4848
end

src/systems/solver_nlprob.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
function generate_ODENLStepData(sys::System, u0, p, mm = calculate_massmatrix(sys))
2-
nlsys, outer_tmp, inner_tmp = inner_nlsystem(sys, mm)
1+
function generate_ODENLStepData(sys::System, u0, p, mm = calculate_massmatrix(sys), nlstep_compile::Bool = true)
2+
nlsys, outer_tmp, inner_tmp = inner_nlsystem(sys, mm, nlstep_compile)
33
state = ProblemState(; u = u0, p)
44
op = Dict()
55
op[ODE_GAMMA[1]] = one(eltype(u0))
@@ -35,7 +35,7 @@ function get_inner_tmp(n::Int)
3535
only(@parameters inner_tmpₘₜₖ[1:n])
3636
end
3737

38-
function inner_nlsystem(sys::System, mm)
38+
function inner_nlsystem(sys::System, mm, nlstep_compile::Bool)
3939
dvs = unknowns(sys)
4040
eqs = full_equations(sys)
4141
t = get_iv(sys)
@@ -56,8 +56,12 @@ function inner_nlsystem(sys::System, mm)
5656

5757
new_dvs = unknowns(sys)
5858
new_ps = [parameters(sys); [gamma1, gamma2, gamma3, c, inner_tmp, outer_tmp]]
59-
nlsys = mtkcompile(
60-
System(new_eqs, new_dvs, new_ps; name = :nlsys); split = is_split(sys))
59+
nlsys = System(new_eqs, new_dvs, new_ps; name = :nlsys)
60+
nlsys = if nlstep_compile
61+
mtkcompile(nlsys; split = is_split(sys))
62+
else
63+
complete(nlsys; split = is_split(sys))
64+
end
6165
return nlsys, outer_tmp, inner_tmp
6266
end
6367

0 commit comments

Comments
 (0)