From f6b6c00da4622bcee907a9ead370feaa33ac105c Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Wed, 28 May 2025 00:27:22 +0000 Subject: [PATCH 1/2] Setup NonlinearSolveAlg with jacobian reuse --- lib/OrdinaryDiffEqCore/src/misc_utils.jl | 1 + .../src/OrdinaryDiffEqDifferentiation.jl | 2 +- .../src/derivative_utils.jl | 2 +- .../src/OrdinaryDiffEqNonlinearSolve.jl | 4 ++-- .../src/newton.jl | 20 ++++++++++++++++--- .../src/nlsolve.jl | 4 ++-- lib/OrdinaryDiffEqNonlinearSolve/src/type.jl | 1 + lib/OrdinaryDiffEqNonlinearSolve/src/utils.jl | 12 +++++++---- 8 files changed, 33 insertions(+), 13 deletions(-) diff --git a/lib/OrdinaryDiffEqCore/src/misc_utils.jl b/lib/OrdinaryDiffEqCore/src/misc_utils.jl index c3b6082802..04314a4e9c 100644 --- a/lib/OrdinaryDiffEqCore/src/misc_utils.jl +++ b/lib/OrdinaryDiffEqCore/src/misc_utils.jl @@ -133,6 +133,7 @@ function get_differential_vars(f, u) end isnewton(::Any) = false +isnonlinearsolve(::Any) = false function _bool_to_ADType(::Val{true}, ::Val{CS}, _) where {CS} Base.depwarn( diff --git a/lib/OrdinaryDiffEqDifferentiation/src/OrdinaryDiffEqDifferentiation.jl b/lib/OrdinaryDiffEqDifferentiation/src/OrdinaryDiffEqDifferentiation.jl index 677b81e880..4be1ae92cb 100644 --- a/lib/OrdinaryDiffEqDifferentiation/src/OrdinaryDiffEqDifferentiation.jl +++ b/lib/OrdinaryDiffEqDifferentiation/src/OrdinaryDiffEqDifferentiation.jl @@ -38,7 +38,7 @@ using OrdinaryDiffEqCore: OrdinaryDiffEqAlgorithm, OrdinaryDiffEqAdaptiveImplici OrdinaryDiffEqAdaptiveExponentialAlgorithm, @unpack, AbstractNLSolver, nlsolve_f, issplit, concrete_jac, unwrap_alg, OrdinaryDiffEqCache, _vec, standardtag, - isnewton, _unwrap_val, + isnewton, isnonlinearsolve, _unwrap_val, set_new_W!, set_W_γdt!, alg_difftype, unwrap_cache, diffdir, get_W, isfirstcall, isfirststage, isJcurrent, get_new_W_γdt_cutoff, diff --git a/lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl b/lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl index 553e128857..7ae66f05c9 100644 --- a/lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl +++ b/lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl @@ -468,7 +468,7 @@ function do_newJW(integrator, alg, nlsolver, repeat_step)::NTuple{2, Bool} return true, true end # TODO: add `isJcurrent` support for Rosenbrock solvers - if !isnewton(nlsolver) + if !isnewton(nlsolver) && !isnonlinearsolve(nlsolver) isfreshJ = !(integrator.alg isa CompositeAlgorithm) && (integrator.iter > 1 && errorfail && !integrator.u_modified) return !isfreshJ, true diff --git a/lib/OrdinaryDiffEqNonlinearSolve/src/OrdinaryDiffEqNonlinearSolve.jl b/lib/OrdinaryDiffEqNonlinearSolve/src/OrdinaryDiffEqNonlinearSolve.jl index 4e58484ffb..38a9d4924b 100644 --- a/lib/OrdinaryDiffEqNonlinearSolve/src/OrdinaryDiffEqNonlinearSolve.jl +++ b/lib/OrdinaryDiffEqNonlinearSolve/src/OrdinaryDiffEqNonlinearSolve.jl @@ -50,12 +50,12 @@ using OrdinaryDiffEqCore: resize_nlsolver!, _initialize_dae!, import OrdinaryDiffEqCore: _initialize_dae!, isnewton, get_W, isfirstcall, isfirststage, isJcurrent, get_new_W_γdt_cutoff, resize_nlsolver!, apply_step!, - postamble! + postamble!, isnonlinearsolve import OrdinaryDiffEqDifferentiation: update_W!, is_always_new, build_uf, build_J_W, WOperator, StaticWOperator, wrapprecs, build_jac_config, dolinsolve, alg_autodiff, - resize_jac_config! + resize_jac_config!, do_newJW import StaticArrays: SArray, MVector, SVector, @SVector, StaticArray, MMatrix, SA, StaticMatrix diff --git a/lib/OrdinaryDiffEqNonlinearSolve/src/newton.jl b/lib/OrdinaryDiffEqNonlinearSolve/src/newton.jl index d12eae3f20..4b881a2f8f 100644 --- a/lib/OrdinaryDiffEqNonlinearSolve/src/newton.jl +++ b/lib/OrdinaryDiffEqNonlinearSolve/src/newton.jl @@ -96,8 +96,15 @@ end @unpack z, tmp, ztmp, γ, α, cache, method = nlsolver @unpack tstep, invγdt = cache + new_jac, new_W = do_newJW(integrator, integrator.alg, nlsolver, false) + if is_always_new(nlsolver) || new_jac || new_W + recompute_jacobian = true + else + recompute_jacobian = false + end + nlcache = nlsolver.cache.cache - step!(nlcache) + step!(nlcache; recompute_jacobian) nlsolver.ztmp = nlcache.u ustep = compute_ustep(tmp, γ, z, method) @@ -118,9 +125,16 @@ end @unpack z, tmp, ztmp, γ, α, cache, method = nlsolver @unpack tstep, invγdt, atmp, ustep = cache - nlstep_data = integrator.f.nlstep_data + new_jac, new_W = do_newJW(integrator, integrator.alg, nlsolver, false) + if is_always_new(nlsolver) || new_jac || new_W + recompute_jacobian = true + else + recompute_jacobian = false + end + nlcache = nlsolver.cache.cache - step!(nlcache) + nlstep_data = integrator.f.nlstep_data + step!(nlcache; recompute_jacobian) if nlstep_data !== nothing nlstepsol = SciMLBase.build_solution( diff --git a/lib/OrdinaryDiffEqNonlinearSolve/src/nlsolve.jl b/lib/OrdinaryDiffEqNonlinearSolve/src/nlsolve.jl index 3e389d4270..4bec8ffe3b 100644 --- a/lib/OrdinaryDiffEqNonlinearSolve/src/nlsolve.jl +++ b/lib/OrdinaryDiffEqNonlinearSolve/src/nlsolve.jl @@ -105,7 +105,7 @@ function nlsolve!(nlsolver::NL, integrator::SciMLBase.DEIntegrator, # don't trust θ for non-adaptive on first iter because the solver doesn't provide feedback # for us to know whether our previous nlsolve converged sufficiently well check_η_convergence = (iter > 1 || - (isnewton(nlsolver) && isadaptive(integrator.alg))) + ((isnewton(nlsolver) || isnonlinearsolve(nlsolver)) && isadaptive(integrator.alg))) if (iter == 1 && ndz < 1e-5) || (check_η_convergence && η >= zero(η) && η * ndz < κ) nlsolver.status = Convergence @@ -114,7 +114,7 @@ function nlsolve!(nlsolver::NL, integrator::SciMLBase.DEIntegrator, end end - if isnewton(nlsolver) && nlsolver.status == Divergence && + if (isnewton(nlsolver) || isnonlinearsolve(nlsolver)) && nlsolver.status == Divergence && !isJcurrent(nlsolver, integrator) nlsolver.status = TryAgain nlsolver.nfails += 1 diff --git a/lib/OrdinaryDiffEqNonlinearSolve/src/type.jl b/lib/OrdinaryDiffEqNonlinearSolve/src/type.jl index 7b41c54903..05d89a349b 100644 --- a/lib/OrdinaryDiffEqNonlinearSolve/src/type.jl +++ b/lib/OrdinaryDiffEqNonlinearSolve/src/type.jl @@ -218,4 +218,5 @@ mutable struct NonlinearSolveCache{uType, tType, rateType, tType2, P, C} <: invγdt::tType2 prob::P cache::C + new_W::Bool end diff --git a/lib/OrdinaryDiffEqNonlinearSolve/src/utils.jl b/lib/OrdinaryDiffEqNonlinearSolve/src/utils.jl index 49a0db411d..466c18dbd3 100644 --- a/lib/OrdinaryDiffEqNonlinearSolve/src/utils.jl +++ b/lib/OrdinaryDiffEqNonlinearSolve/src/utils.jl @@ -14,6 +14,10 @@ isnewton(nlsolver::AbstractNLSolver) = isnewton(nlsolver.cache) isnewton(::AbstractNLSolverCache) = false isnewton(::Union{NLNewtonCache, NLNewtonConstantCache}) = true +isnonlinearsolve(nlsolver::AbstractNLSolver) = isnonlinearsolve(nlsolver.cache) +isnonlinearsolve(::AbstractNLSolverCache) = false +isnonlinearsolve(::NonlinearSolveCache) = true + is_always_new(nlsolver::AbstractNLSolver) = is_always_new(nlsolver.alg) check_div(nlsolver::AbstractNLSolver) = check_div(nlsolver.alg) check_div(alg) = isdefined(alg, :check_div) ? alg.check_div : true @@ -32,9 +36,9 @@ getnfails(_) = 0 getnfails(nlsolver::AbstractNLSolver) = nlsolver.nfails set_new_W!(nlsolver::AbstractNLSolver, val::Bool)::Bool = set_new_W!(nlsolver.cache, val) -set_new_W!(nlcache::Union{NLNewtonCache, NLNewtonConstantCache}, val::Bool)::Bool = nlcache.new_W = val +set_new_W!(nlcache::Union{NLNewtonCache, NLNewtonConstantCache, NonlinearSolveCache}, val::Bool)::Bool = nlcache.new_W = val get_new_W!(nlsolver::AbstractNLSolver)::Bool = get_new_W!(nlsolver.cache) -get_new_W!(nlcache::Union{NLNewtonCache, NLNewtonConstantCache})::Bool = nlcache.new_W +get_new_W!(nlcache::Union{NLNewtonCache, NLNewtonConstantCache, NonlinearSolveCache})::Bool = nlcache.new_W get_new_W!(::AbstractNLSolverCache)::Bool = true get_W(nlsolver::AbstractNLSolver) = get_W(nlsolver.cache) @@ -239,7 +243,7 @@ function build_nlsolver( NonlinearProblem(NonlinearFunction{true}(nlf), ztmp, nlp_params) end cache = init(prob, nlalg.alg) - nlcache = NonlinearSolveCache(ustep, tstep, k, atmp, invγdt, prob, cache) + nlcache = NonlinearSolveCache(ustep, tstep, k, atmp, invγdt, prob, cache, true) else nlcache = NLNewtonCache(ustep, tstep, k, atmp, dz, J, W, true, true, true, tType(dt), du1, uf, jac_config, @@ -327,7 +331,7 @@ function build_nlsolver( prob = NonlinearProblem(NonlinearFunction{false}(nlf), copy(ztmp), nlp_params) cache = init(prob, nlalg.alg) nlcache = NonlinearSolveCache( - nothing, tstep, nothing, nothing, invγdt, prob, cache) + nothing, tstep, nothing, nothing, invγdt, prob, cache, true) else nlcache = NLNewtonConstantCache(tstep, J, W, true, true, true, tType(dt), uf, invγdt, tType(nlalg.new_W_dt_cutoff), t) From 514799c5293f36147dc605ed55cd863cb624138f Mon Sep 17 00:00:00 2001 From: Oscar Smith Date: Tue, 7 Oct 2025 00:19:14 -0400 Subject: [PATCH 2/2] use less isnewton --- .../src/OrdinaryDiffEqBDF.jl | 2 +- lib/OrdinaryDiffEqBDF/src/bdf_perform_step.jl | 8 ++-- .../src/derivative_utils.jl | 2 +- .../src/newton.jl | 6 ++- .../src/nlsolve.jl | 2 +- lib/OrdinaryDiffEqNonlinearSolve/src/type.jl | 7 +++- lib/OrdinaryDiffEqNonlinearSolve/src/utils.jl | 13 +++--- .../src/kencarp_kvaerno_perform_step.jl | 18 ++++---- .../src/sdirk_perform_step.jl | 41 ++++++++----------- 9 files changed, 50 insertions(+), 49 deletions(-) diff --git a/lib/OrdinaryDiffEqBDF/src/OrdinaryDiffEqBDF.jl b/lib/OrdinaryDiffEqBDF/src/OrdinaryDiffEqBDF.jl index fb7ab3d305..e60ca0fd57 100644 --- a/lib/OrdinaryDiffEqBDF/src/OrdinaryDiffEqBDF.jl +++ b/lib/OrdinaryDiffEqBDF/src/OrdinaryDiffEqBDF.jl @@ -36,7 +36,7 @@ using ArrayInterface: ismutable import OrdinaryDiffEqCore using OrdinaryDiffEqDifferentiation: UJacobianWrapper using OrdinaryDiffEqNonlinearSolve: NLNewton, du_alias_or_new, build_nlsolver, - nlsolve!, nlsolvefail, isnewton, markfirststage!, + nlsolve!, nlsolvefail, markfirststage!, set_new_W!, DIRK, compute_step!, COEFFICIENT_MULTISTEP, NonlinearSolveAlg import ADTypes: AutoForwardDiff, AutoFiniteDiff, AbstractADType diff --git a/lib/OrdinaryDiffEqBDF/src/bdf_perform_step.jl b/lib/OrdinaryDiffEqBDF/src/bdf_perform_step.jl index b421a987dd..3be0294f56 100644 --- a/lib/OrdinaryDiffEqBDF/src/bdf_perform_step.jl +++ b/lib/OrdinaryDiffEqBDF/src/bdf_perform_step.jl @@ -985,7 +985,7 @@ end ### STEP 2 nlsolver.tmp = z₁ nlsolver.c = 2 - isnewton(nlsolver) && set_new_W!(nlsolver, false) + set_new_W!(nlsolver, false) z = nlsolve!(nlsolver, integrator, cache, repeat_step) nlsolvefail(nlsolver) && return z₂ = z₁ + z @@ -993,7 +993,7 @@ end tmp2 = 0.5uprev + z₁ - 0.5z₂ nlsolver.tmp = tmp2 nlsolver.c = 1 - isnewton(nlsolver) && set_new_W!(nlsolver, false) + set_new_W!(nlsolver, false) z = nlsolve!(nlsolver, integrator, cache, repeat_step) nlsolvefail(nlsolver) && return u = tmp2 + z @@ -1039,7 +1039,7 @@ end ### STEP 2 nlsolver.tmp = z₁ nlsolver.c = 2 - isnewton(nlsolver) && set_new_W!(nlsolver, false) + set_new_W!(nlsolver, false) z = nlsolve!(nlsolver, integrator, cache, repeat_step) nlsolvefail(nlsolver) && return @.. broadcast=false z₂=z₁+z @@ -1048,7 +1048,7 @@ end @.. broadcast=false tmp2=0.5uprev+z₁-0.5z₂ nlsolver.tmp = tmp2 nlsolver.c = 1 - isnewton(nlsolver) && set_new_W!(nlsolver, false) + set_new_W!(nlsolver, false) z = nlsolve!(nlsolver, integrator, cache, repeat_step) nlsolvefail(nlsolver) && return @.. broadcast=false u=tmp2+z diff --git a/lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl b/lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl index 7ae66f05c9..e2a9315a0b 100644 --- a/lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl +++ b/lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl @@ -606,7 +606,7 @@ function calc_W!(W, integrator, nlsolver::Union{Nothing, AbstractNLSolver}, cach new_jac, new_W = newJW end - if new_jac && isnewton(lcache) + if new_jac && (isnewton(lcache)) lcache.J_t = t if isdae lcache.uf.α = nlsolver.α diff --git a/lib/OrdinaryDiffEqNonlinearSolve/src/newton.jl b/lib/OrdinaryDiffEqNonlinearSolve/src/newton.jl index 4b881a2f8f..3f9b441905 100644 --- a/lib/OrdinaryDiffEqNonlinearSolve/src/newton.jl +++ b/lib/OrdinaryDiffEqNonlinearSolve/src/newton.jl @@ -126,13 +126,17 @@ end @unpack tstep, invγdt, atmp, ustep = cache new_jac, new_W = do_newJW(integrator, integrator.alg, nlsolver, false) + cache.new_W = new_W + @show new_jac, new_W if is_always_new(nlsolver) || new_jac || new_W + cache.W_γdt = γ*dt + cache.J_t = t recompute_jacobian = true else recompute_jacobian = false end - nlcache = nlsolver.cache.cache + nlcache = cache.cache nlstep_data = integrator.f.nlstep_data step!(nlcache; recompute_jacobian) diff --git a/lib/OrdinaryDiffEqNonlinearSolve/src/nlsolve.jl b/lib/OrdinaryDiffEqNonlinearSolve/src/nlsolve.jl index 4bec8ffe3b..e4c209e916 100644 --- a/lib/OrdinaryDiffEqNonlinearSolve/src/nlsolve.jl +++ b/lib/OrdinaryDiffEqNonlinearSolve/src/nlsolve.jl @@ -154,7 +154,7 @@ function postamble!(nlsolver::NLSolver, integrator::SciMLBase.DEIntegrator) end integrator.force_stepfail = nlsolvefail(nlsolver) setfirststage!(nlsolver, false) - isnewton(nlsolver) && (nlsolver.cache.firstcall = false) + (isnewton(nlsolver) || isnonlinearsolve(nlsolver)) && (nlsolver.cache.firstcall = false) nlsolver.z end diff --git a/lib/OrdinaryDiffEqNonlinearSolve/src/type.jl b/lib/OrdinaryDiffEqNonlinearSolve/src/type.jl index 05d89a349b..fae10dac43 100644 --- a/lib/OrdinaryDiffEqNonlinearSolve/src/type.jl +++ b/lib/OrdinaryDiffEqNonlinearSolve/src/type.jl @@ -215,8 +215,13 @@ mutable struct NonlinearSolveCache{uType, tType, rateType, tType2, P, C} <: tstep::tType k::rateType atmp::uType - invγdt::tType2 prob::P cache::C new_W::Bool + firststage::Bool + firstcall::Bool + W_γdt::tType + invγdt::tType2 + new_W_γdt_cutoff::tType + J_t::tType end diff --git a/lib/OrdinaryDiffEqNonlinearSolve/src/utils.jl b/lib/OrdinaryDiffEqNonlinearSolve/src/utils.jl index 466c18dbd3..2941498b7e 100644 --- a/lib/OrdinaryDiffEqNonlinearSolve/src/utils.jl +++ b/lib/OrdinaryDiffEqNonlinearSolve/src/utils.jl @@ -26,7 +26,7 @@ isJcurrent(nlsolver::AbstractNLSolver, integrator) = integrator.t == nlsolver.ca isfirstcall(nlsolver::AbstractNLSolver) = nlsolver.cache.firstcall isfirststage(nlsolver::AbstractNLSolver) = nlsolver.cache.firststage setfirststage!(nlsolver::AbstractNLSolver, val::Bool) = setfirststage!(nlsolver.cache, val) -function setfirststage!(nlcache::Union{NLNewtonCache, NLNewtonConstantCache}, val::Bool) +function setfirststage!(nlcache::Union{NLNewtonCache, NLNewtonConstantCache, NonlinearSolveCache}, val::Bool) (nlcache.firststage = val) end setfirststage!(::Any, val::Bool) = nothing @@ -37,9 +37,9 @@ getnfails(nlsolver::AbstractNLSolver) = nlsolver.nfails set_new_W!(nlsolver::AbstractNLSolver, val::Bool)::Bool = set_new_W!(nlsolver.cache, val) set_new_W!(nlcache::Union{NLNewtonCache, NLNewtonConstantCache, NonlinearSolveCache}, val::Bool)::Bool = nlcache.new_W = val +set_new_W!(nlcache::AbstractNLSolverCache, val::Bool)::Bool = nothing get_new_W!(nlsolver::AbstractNLSolver)::Bool = get_new_W!(nlsolver.cache) -get_new_W!(nlcache::Union{NLNewtonCache, NLNewtonConstantCache, NonlinearSolveCache})::Bool = nlcache.new_W -get_new_W!(::AbstractNLSolverCache)::Bool = true +get_new_W!(::AbstractNLSolverCache)::Bool = nlcache.new_W get_W(nlsolver::AbstractNLSolver) = get_W(nlsolver.cache) get_W(nlcache::Union{NLNewtonCache, NLNewtonConstantCache}) = nlcache.W @@ -243,7 +243,8 @@ function build_nlsolver( NonlinearProblem(NonlinearFunction{true}(nlf), ztmp, nlp_params) end cache = init(prob, nlalg.alg) - nlcache = NonlinearSolveCache(ustep, tstep, k, atmp, invγdt, prob, cache, true) + nlcache = NonlinearSolveCache(ustep, tstep, k, atmp, prob, cache, + true, true, true, tType(dt), invγdt, tType(nlalg.new_W_dt_cutoff), t) else nlcache = NLNewtonCache(ustep, tstep, k, atmp, dz, J, W, true, true, true, tType(dt), du1, uf, jac_config, @@ -330,8 +331,8 @@ function build_nlsolver( end prob = NonlinearProblem(NonlinearFunction{false}(nlf), copy(ztmp), nlp_params) cache = init(prob, nlalg.alg) - nlcache = NonlinearSolveCache( - nothing, tstep, nothing, nothing, invγdt, prob, cache, true) + nlcache = NonlinearSolveCache(nothing, tstep, nothing, nothing, prob, cache, + true, true, true, tType(dt), invγdt, tType(nlalg.new_W_dt_cutoff), t) else nlcache = NLNewtonConstantCache(tstep, J, W, true, true, true, tType(dt), uf, invγdt, tType(nlalg.new_W_dt_cutoff), t) diff --git a/lib/OrdinaryDiffEqSDIRK/src/kencarp_kvaerno_perform_step.jl b/lib/OrdinaryDiffEqSDIRK/src/kencarp_kvaerno_perform_step.jl index 20d497abff..9f8ecb0d57 100644 --- a/lib/OrdinaryDiffEqSDIRK/src/kencarp_kvaerno_perform_step.jl +++ b/lib/OrdinaryDiffEqSDIRK/src/kencarp_kvaerno_perform_step.jl @@ -88,7 +88,7 @@ end nlsolver.c = γ z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) nlsolvefail(nlsolver) && return - isnewton(nlsolver) && set_new_W!(nlsolver, false) + set_new_W!(nlsolver, false) ################################## Solve Step 3 @@ -315,7 +315,7 @@ end nlsolver.c = 2γ z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) nlsolvefail(nlsolver) && return - isnewton(nlsolver) && set_new_W!(nlsolver, false) + set_new_W!(nlsolver, false) ################################## Solve Step 3 @@ -554,7 +554,7 @@ end nlsolver.c = c2 z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) nlsolvefail(nlsolver) && return - isnewton(nlsolver) && set_new_W!(nlsolver, false) + set_new_W!(nlsolver, false) ################################## Solve Step 3 @@ -719,7 +719,7 @@ end nlsolver.c = γ z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) nlsolvefail(nlsolver) && return - isnewton(nlsolver) && set_new_W!(nlsolver, false) + set_new_W!(nlsolver, false) ################################## Solve Step 3 @@ -1008,7 +1008,7 @@ end markfirststage!(nlsolver) z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) nlsolvefail(nlsolver) && return - isnewton(nlsolver) && set_new_W!(nlsolver, false) + set_new_W!(nlsolver, false) ################################## Solve Step 3 @@ -1262,7 +1262,7 @@ end nlsolver.c = γ z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) nlsolvefail(nlsolver) && return - isnewton(nlsolver) && set_new_W!(nlsolver, false) + set_new_W!(nlsolver, false) ################################## Solve Step 3 @@ -1612,7 +1612,7 @@ end nlsolver.c = 2γ z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) nlsolvefail(nlsolver) && return - isnewton(nlsolver) && set_new_W!(nlsolver, false) + set_new_W!(nlsolver, false) ################################## Solve Step 3 @@ -2028,7 +2028,7 @@ end nlsolver.c = 2γ z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) nlsolvefail(nlsolver) && return - isnewton(nlsolver) && set_new_W!(nlsolver, false) + set_new_W!(nlsolver, false) ################################## Solve Step 3 @@ -2449,7 +2449,7 @@ end nlsolver.c = 2γ z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) nlsolvefail(nlsolver) && return - isnewton(nlsolver) && set_new_W!(nlsolver, false) + set_new_W!(nlsolver, false) ################################## Solve Step 3 diff --git a/lib/OrdinaryDiffEqSDIRK/src/sdirk_perform_step.jl b/lib/OrdinaryDiffEqSDIRK/src/sdirk_perform_step.jl index a0f33020dc..59fa139f52 100644 --- a/lib/OrdinaryDiffEqSDIRK/src/sdirk_perform_step.jl +++ b/lib/OrdinaryDiffEqSDIRK/src/sdirk_perform_step.jl @@ -393,7 +393,6 @@ end @unpack t, dt, uprev, u, f, p = integrator @unpack zprev, zᵧ, atmp, nlsolver, step_limiter! = cache @unpack z, tmp = nlsolver - W = isnewton(nlsolver) ? get_W(nlsolver) : nothing b = nlsolver.ztmp @unpack γ, d, ω, btilde1, btilde2, btilde3, α1, α2 = cache.tab alg = unwrap_alg(integrator, true) @@ -418,7 +417,7 @@ end @.. broadcast=false z=α1 * zprev + α2 * zᵧ @.. broadcast=false tmp=uprev + ω * zprev + ω * zᵧ nlsolver.c = 1 - isnewton(nlsolver) && set_new_W!(nlsolver, false) + set_new_W!(nlsolver, false) nlsolve!(nlsolver, integrator, cache, repeat_step) nlsolvefail(nlsolver) && return @@ -451,7 +450,6 @@ end @unpack t, dt, uprev, u, f, p = integrator @unpack zprev, zᵧ, atmp, nlsolver, step_limiter! = cache @unpack z, tmp = nlsolver - W = isnewton(nlsolver) ? get_W(nlsolver) : nothing b = nlsolver.ztmp @unpack γ, d, ω, btilde1, btilde2, btilde3, α1, α2 = cache.tab alg = unwrap_alg(integrator, true) @@ -484,7 +482,7 @@ end tmp[i] = uprev[i] + ω * zprev[i] + ω * zᵧ[i] end nlsolver.c = 1 - isnewton(nlsolver) && set_new_W!(nlsolver, false) + set_new_W!(nlsolver, false) nlsolve!(nlsolver, integrator, cache, repeat_step) nlsolvefail(nlsolver) && return @@ -576,7 +574,6 @@ end @unpack t, dt, uprev, u, f, p = integrator @unpack z₁, z₂, atmp, nlsolver, step_limiter! = cache @unpack tmp = nlsolver - W = isnewton(nlsolver) ? get_W(nlsolver) : nothing alg = unwrap_alg(integrator, true) markfirststage!(nlsolver) @@ -602,7 +599,7 @@ end ### Initial Guess Is α₁ = c₂/γ, c₂ = 0 => z₂ = α₁z₁ = 0 z₂ .= zero(eltype(u)) nlsolver.z = z₂ - isnewton(nlsolver) && set_new_W!(nlsolver, false) + set_new_W!(nlsolver, false) @.. broadcast=false tmp=uprev - z₁ nlsolver.tmp = tmp z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) @@ -873,7 +870,7 @@ end @.. broadcast=false tmp=uprev + z₁ / 2 nlsolver.tmp = tmp - isnewton(nlsolver) && set_new_W!(nlsolver, false) + set_new_W!(nlsolver, false) z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) nlsolvefail(nlsolver) && return @@ -989,7 +986,6 @@ end @unpack t, dt, uprev, u, f, p = integrator @unpack z₁, z₂, z₃, z₄, z₅, atmp, nlsolver = cache @unpack tmp = nlsolver - W = isnewton(nlsolver) ? get_W(nlsolver) : nothing @unpack γ, a21, a31, a32, a41, a42, a43, a51, a52, a53, a54, c2, c3, c4 = cache.tab @unpack b1hat1, b2hat1, b3hat1, b4hat1, b1hat2, b2hat2, b3hat2, b4hat2 = cache.tab alg = unwrap_alg(integrator, true) @@ -1015,7 +1011,7 @@ end @.. broadcast=false tmp=uprev + a21 * z₁ nlsolver.tmp = tmp - isnewton(nlsolver) && set_new_W!(nlsolver, false) + set_new_W!(nlsolver, false) nlsolver.c = c2 z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) nlsolvefail(nlsolver) && return @@ -1154,7 +1150,6 @@ end @unpack t, dt, uprev, u, f, p = integrator @unpack z₁, z₂, z₃, z₄, nlsolver = cache @unpack tmp = nlsolver - W = isnewton(nlsolver) ? get_W(nlsolver) : nothing @unpack γ, a21, a31, a32, a41, a42, a43, a51, a52, a53, a54, c2, c3, c4 = cache.tab alg = unwrap_alg(integrator, true) markfirststage!(nlsolver) @@ -1178,7 +1173,7 @@ end @.. broadcast=false tmp=uprev + a21 * z₁ nlsolver.tmp = tmp - isnewton(nlsolver) && set_new_W!(nlsolver, false) + set_new_W!(nlsolver, false) nlsolver.c = c2 z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) nlsolvefail(nlsolver) && return @@ -1287,7 +1282,6 @@ end @unpack t, dt, uprev, u, f, p = integrator @unpack z₁, z₂, z₃, z₄, z₅, nlsolver = cache @unpack tmp = nlsolver - W = isnewton(nlsolver) ? get_W(nlsolver) : nothing @unpack γ, a21, a31, a32, a41, a42, a43, a51, a52, a53, a54, a61, a62, a63, a64, a65, c2, c3, c4, c5 = cache.tab alg = unwrap_alg(integrator, true) markfirststage!(nlsolver) @@ -1311,7 +1305,7 @@ end @.. broadcast=false tmp=uprev + a21 * z₁ nlsolver.tmp = tmp - isnewton(nlsolver) && set_new_W!(nlsolver, false) + set_new_W!(nlsolver, false) nlsolver.c = c2 z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) nlsolvefail(nlsolver) && return @@ -1441,7 +1435,6 @@ end @unpack t, dt, uprev, u, f, p = integrator @unpack z₁, z₂, z₃, z₄, z₅, z₆, nlsolver = cache @unpack tmp = nlsolver - W = isnewton(nlsolver) ? get_W(nlsolver) : nothing @unpack γ, a21, a31, a32, a41, a42, a43, a51, a52, a53, a54, a61, a62, a63, a64, a65, a71, a72, a73, a74, a75, a76, c2, c3, c4, c5, c6 = cache.tab alg = unwrap_alg(integrator, true) markfirststage!(nlsolver) @@ -1465,7 +1458,7 @@ end @.. broadcast=false tmp=uprev + a21 * z₁ nlsolver.tmp = tmp - isnewton(nlsolver) && set_new_W!(nlsolver, false) + set_new_W!(nlsolver, false) nlsolver.c = c2 z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) nlsolvefail(nlsolver) && return @@ -1619,7 +1612,6 @@ end @unpack t, dt, uprev, u, f, p = integrator @unpack z₁, z₂, z₃, z₄, z₅, z₆, z₇, nlsolver = cache @unpack tmp = nlsolver - W = isnewton(nlsolver) ? get_W(nlsolver) : nothing @unpack γ, a21, a31, a32, a41, a42, a43, a51, a52, a53, a54, a61, a62, a63, a64, a65, a71, a72, a73, a74, a75, a76, a81, a82, a83, a84, a85, a86, a87, c2, c3, c4, c5, c6, c7 = cache.tab alg = unwrap_alg(integrator, true) markfirststage!(nlsolver) @@ -1643,7 +1635,7 @@ end @.. broadcast=false tmp=uprev + a21 * z₁ nlsolver.tmp = tmp - isnewton(nlsolver) && set_new_W!(nlsolver, false) + set_new_W!(nlsolver, false) nlsolver.c = c2 z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) nlsolvefail(nlsolver) && return @@ -1822,7 +1814,6 @@ end @unpack t, dt, uprev, u, f, p = integrator @unpack z₁, z₂, z₃, z₄, z₅, z₆, z₇, z₈, nlsolver = cache @unpack tmp = nlsolver - W = isnewton(nlsolver) ? get_W(nlsolver) : nothing @unpack γ, a21, a31, a32, a41, a42, a43, a51, a52, a53, a54, a61, a62, a63, a64, a65, a71, a72, a73, a74, a75, a76, a81, a82, a83, a84, a85, a86, a87, a91, a92, a93, a94, a95, a96, a97, a98, c2, c3, c4, c5, c6, c7, c8 = cache.tab alg = unwrap_alg(integrator, true) markfirststage!(nlsolver) @@ -1846,7 +1837,7 @@ end @.. broadcast=false tmp=uprev + a21 * z₁ nlsolver.tmp = tmp - isnewton(nlsolver) && set_new_W!(nlsolver, false) + set_new_W!(nlsolver, false) nlsolver.c = c2 z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) nlsolvefail(nlsolver) && return @@ -2040,7 +2031,7 @@ end @.. broadcast=false tmp=uprev + a21 * z₁ nlsolver.tmp = tmp nlsolver.c = c2 - isnewton(nlsolver) && set_new_W!(nlsolver, false) + set_new_W!(nlsolver, false) z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) nlsolvefail(nlsolver) && return @@ -2239,7 +2230,7 @@ end nlsolver.c = 2γ z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) nlsolvefail(nlsolver) && return - isnewton(nlsolver) && set_new_W!(nlsolver, false) + set_new_W!(nlsolver, false) ################################## Solve Step 3 @@ -2433,7 +2424,7 @@ end nlsolver.c = 2γ z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) nlsolvefail(nlsolver) && return - isnewton(nlsolver) && set_new_W!(nlsolver, false) + set_new_W!(nlsolver, false) ################################## Solve Step 3 @@ -2618,7 +2609,7 @@ end nlsolver.c = 2γ z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) nlsolvefail(nlsolver) && return - isnewton(nlsolver) && set_new_W!(nlsolver, false) + set_new_W!(nlsolver, false) ################################## Solve Step 3 @@ -2813,7 +2804,7 @@ end nlsolver.c = 2γ z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) nlsolvefail(nlsolver) && return - isnewton(nlsolver) && set_new_W!(nlsolver, false) + set_new_W!(nlsolver, false) ################################## Solve Step 3 @@ -3029,7 +3020,7 @@ end nlsolver.c = 2γ z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) nlsolvefail(nlsolver) && return - isnewton(nlsolver) && set_new_W!(nlsolver, false) + set_new_W!(nlsolver, false) ################################## Solve Step 3