Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/OrdinaryDiffEqBDF/src/OrdinaryDiffEqBDF.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ using ArrayInterface: ismutable
import OrdinaryDiffEqCore
using OrdinaryDiffEqDifferentiation: UJacobianWrapper
using OrdinaryDiffEqNonlinearSolve: NLNewton, du_alias_or_new, build_nlsolver,
nlsolve!, nlsolvefail, isnewton, markfirststage!,
nlsolve!, nlsolvefail, markfirststage!,
set_new_W!, DIRK, compute_step!, COEFFICIENT_MULTISTEP,
NonlinearSolveAlg
import ADTypes: AutoForwardDiff, AutoFiniteDiff, AbstractADType
Expand Down
8 changes: 4 additions & 4 deletions lib/OrdinaryDiffEqBDF/src/bdf_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -985,15 +985,15 @@ end
### STEP 2
nlsolver.tmp = z₁
nlsolver.c = 2
isnewton(nlsolver) && set_new_W!(nlsolver, false)
set_new_W!(nlsolver, false)
z = nlsolve!(nlsolver, integrator, cache, repeat_step)
nlsolvefail(nlsolver) && return
z₂ = z₁ + z
### STEP 3
tmp2 = 0.5uprev + z₁ - 0.5z₂
nlsolver.tmp = tmp2
nlsolver.c = 1
isnewton(nlsolver) && set_new_W!(nlsolver, false)
set_new_W!(nlsolver, false)
z = nlsolve!(nlsolver, integrator, cache, repeat_step)
nlsolvefail(nlsolver) && return
u = tmp2 + z
Expand Down Expand Up @@ -1039,7 +1039,7 @@ end
### STEP 2
nlsolver.tmp = z₁
nlsolver.c = 2
isnewton(nlsolver) && set_new_W!(nlsolver, false)
set_new_W!(nlsolver, false)
z = nlsolve!(nlsolver, integrator, cache, repeat_step)
nlsolvefail(nlsolver) && return
@.. broadcast=false z₂=z₁+z
Expand All @@ -1048,7 +1048,7 @@ end
@.. broadcast=false tmp2=0.5uprev+z₁-0.5z₂
nlsolver.tmp = tmp2
nlsolver.c = 1
isnewton(nlsolver) && set_new_W!(nlsolver, false)
set_new_W!(nlsolver, false)
z = nlsolve!(nlsolver, integrator, cache, repeat_step)
nlsolvefail(nlsolver) && return
@.. broadcast=false u=tmp2+z
Expand Down
1 change: 1 addition & 0 deletions lib/OrdinaryDiffEqCore/src/misc_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ function get_differential_vars(f, u)
end

isnewton(::Any) = false
isnonlinearsolve(::Any) = false

function _bool_to_ADType(::Val{true}, ::Val{CS}, _) where {CS}
Base.depwarn(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ using OrdinaryDiffEqCore: OrdinaryDiffEqAlgorithm, OrdinaryDiffEqAdaptiveImplici
OrdinaryDiffEqAdaptiveExponentialAlgorithm, @unpack,
AbstractNLSolver, nlsolve_f, issplit,
concrete_jac, unwrap_alg, OrdinaryDiffEqCache, _vec, standardtag,
isnewton, _unwrap_val,
isnewton, isnonlinearsolve, _unwrap_val,
set_new_W!, set_W_γdt!, alg_difftype, unwrap_cache, diffdir,
get_W, isfirstcall, isfirststage, isJcurrent,
get_new_W_γdt_cutoff,
Expand Down
4 changes: 2 additions & 2 deletions lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ function do_newJW(integrator, alg, nlsolver, repeat_step)::NTuple{2, Bool}
return true, true
end
# TODO: add `isJcurrent` support for Rosenbrock solvers
if !isnewton(nlsolver)
if !isnewton(nlsolver) && !isnonlinearsolve(nlsolver)
isfreshJ = !(integrator.alg isa CompositeAlgorithm) &&
(integrator.iter > 1 && errorfail && !integrator.u_modified)
return !isfreshJ, true
Expand Down Expand Up @@ -606,7 +606,7 @@ function calc_W!(W, integrator, nlsolver::Union{Nothing, AbstractNLSolver}, cach
new_jac, new_W = newJW
end

if new_jac && isnewton(lcache)
if new_jac && (isnewton(lcache))
lcache.J_t = t
if isdae
lcache.uf.α = nlsolver.α
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,12 @@ using OrdinaryDiffEqCore: resize_nlsolver!, _initialize_dae!,

import OrdinaryDiffEqCore: _initialize_dae!, isnewton, get_W, isfirstcall, isfirststage,
isJcurrent, get_new_W_γdt_cutoff, resize_nlsolver!, apply_step!,
postamble!
postamble!, isnonlinearsolve

import OrdinaryDiffEqDifferentiation: update_W!, is_always_new, build_uf, build_J_W,
WOperator, StaticWOperator, wrapprecs,
build_jac_config, dolinsolve, alg_autodiff,
resize_jac_config!
resize_jac_config!, do_newJW

import StaticArrays: SArray, MVector, SVector, @SVector, StaticArray, MMatrix, SA,
StaticMatrix
Expand Down
24 changes: 21 additions & 3 deletions lib/OrdinaryDiffEqNonlinearSolve/src/newton.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,15 @@ end
@unpack z, tmp, ztmp, γ, α, cache, method = nlsolver
@unpack tstep, invγdt = cache

new_jac, new_W = do_newJW(integrator, integrator.alg, nlsolver, false)
if is_always_new(nlsolver) || new_jac || new_W
recompute_jacobian = true
else
recompute_jacobian = false
end

nlcache = nlsolver.cache.cache
step!(nlcache)
step!(nlcache; recompute_jacobian)
nlsolver.ztmp = nlcache.u

ustep = compute_ustep(tmp, γ, z, method)
Expand All @@ -118,9 +125,20 @@ end
@unpack z, tmp, ztmp, γ, α, cache, method = nlsolver
@unpack tstep, invγdt, atmp, ustep = cache

new_jac, new_W = do_newJW(integrator, integrator.alg, nlsolver, false)
cache.new_W = new_W
@show new_jac, new_W
if is_always_new(nlsolver) || new_jac || new_W
cache.W_γdt = γ*dt
cache.J_t = t
recompute_jacobian = true
else
recompute_jacobian = false
end

nlcache = cache.cache
nlstep_data = integrator.f.nlstep_data
nlcache = nlsolver.cache.cache
step!(nlcache)
step!(nlcache; recompute_jacobian)

if nlstep_data !== nothing
nlstepsol = SciMLBase.build_solution(
Expand Down
6 changes: 3 additions & 3 deletions lib/OrdinaryDiffEqNonlinearSolve/src/nlsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ function nlsolve!(nlsolver::NL, integrator::SciMLBase.DEIntegrator,
# don't trust θ for non-adaptive on first iter because the solver doesn't provide feedback
# for us to know whether our previous nlsolve converged sufficiently well
check_η_convergence = (iter > 1 ||
(isnewton(nlsolver) && isadaptive(integrator.alg)))
((isnewton(nlsolver) || isnonlinearsolve(nlsolver)) && isadaptive(integrator.alg)))
if (iter == 1 && ndz < 1e-5) ||
(check_η_convergence && η >= zero(η) && η * ndz < κ)
nlsolver.status = Convergence
Expand All @@ -114,7 +114,7 @@ function nlsolve!(nlsolver::NL, integrator::SciMLBase.DEIntegrator,
end
end

if isnewton(nlsolver) && nlsolver.status == Divergence &&
if (isnewton(nlsolver) || isnonlinearsolve(nlsolver)) && nlsolver.status == Divergence &&
!isJcurrent(nlsolver, integrator)
nlsolver.status = TryAgain
nlsolver.nfails += 1
Expand Down Expand Up @@ -154,7 +154,7 @@ function postamble!(nlsolver::NLSolver, integrator::SciMLBase.DEIntegrator)
end
integrator.force_stepfail = nlsolvefail(nlsolver)
setfirststage!(nlsolver, false)
isnewton(nlsolver) && (nlsolver.cache.firstcall = false)
(isnewton(nlsolver) || isnonlinearsolve(nlsolver)) && (nlsolver.cache.firstcall = false)

nlsolver.z
end
8 changes: 7 additions & 1 deletion lib/OrdinaryDiffEqNonlinearSolve/src/type.jl
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,13 @@ mutable struct NonlinearSolveCache{uType, tType, rateType, tType2, P, C} <:
tstep::tType
k::rateType
atmp::uType
invγdt::tType2
prob::P
cache::C
new_W::Bool
firststage::Bool
firstcall::Bool
W_γdt::tType
invγdt::tType2
new_W_γdt_cutoff::tType
J_t::tType
end
19 changes: 12 additions & 7 deletions lib/OrdinaryDiffEqNonlinearSolve/src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ isnewton(nlsolver::AbstractNLSolver) = isnewton(nlsolver.cache)
isnewton(::AbstractNLSolverCache) = false
isnewton(::Union{NLNewtonCache, NLNewtonConstantCache}) = true

isnonlinearsolve(nlsolver::AbstractNLSolver) = isnonlinearsolve(nlsolver.cache)
isnonlinearsolve(::AbstractNLSolverCache) = false
isnonlinearsolve(::NonlinearSolveCache) = true

is_always_new(nlsolver::AbstractNLSolver) = is_always_new(nlsolver.alg)
check_div(nlsolver::AbstractNLSolver) = check_div(nlsolver.alg)
check_div(alg) = isdefined(alg, :check_div) ? alg.check_div : true
Expand All @@ -22,7 +26,7 @@ isJcurrent(nlsolver::AbstractNLSolver, integrator) = integrator.t == nlsolver.ca
isfirstcall(nlsolver::AbstractNLSolver) = nlsolver.cache.firstcall
isfirststage(nlsolver::AbstractNLSolver) = nlsolver.cache.firststage
setfirststage!(nlsolver::AbstractNLSolver, val::Bool) = setfirststage!(nlsolver.cache, val)
function setfirststage!(nlcache::Union{NLNewtonCache, NLNewtonConstantCache}, val::Bool)
function setfirststage!(nlcache::Union{NLNewtonCache, NLNewtonConstantCache, NonlinearSolveCache}, val::Bool)
(nlcache.firststage = val)
end
setfirststage!(::Any, val::Bool) = nothing
Expand All @@ -32,10 +36,10 @@ getnfails(_) = 0
getnfails(nlsolver::AbstractNLSolver) = nlsolver.nfails

set_new_W!(nlsolver::AbstractNLSolver, val::Bool)::Bool = set_new_W!(nlsolver.cache, val)
set_new_W!(nlcache::Union{NLNewtonCache, NLNewtonConstantCache}, val::Bool)::Bool = nlcache.new_W = val
set_new_W!(nlcache::Union{NLNewtonCache, NLNewtonConstantCache, NonlinearSolveCache}, val::Bool)::Bool = nlcache.new_W = val
set_new_W!(nlcache::AbstractNLSolverCache, val::Bool)::Bool = nothing
get_new_W!(nlsolver::AbstractNLSolver)::Bool = get_new_W!(nlsolver.cache)
get_new_W!(nlcache::Union{NLNewtonCache, NLNewtonConstantCache})::Bool = nlcache.new_W
get_new_W!(::AbstractNLSolverCache)::Bool = true
get_new_W!(::AbstractNLSolverCache)::Bool = nlcache.new_W

get_W(nlsolver::AbstractNLSolver) = get_W(nlsolver.cache)
get_W(nlcache::Union{NLNewtonCache, NLNewtonConstantCache}) = nlcache.W
Expand Down Expand Up @@ -239,7 +243,8 @@ function build_nlsolver(
NonlinearProblem(NonlinearFunction{true}(nlf), ztmp, nlp_params)
end
cache = init(prob, nlalg.alg)
nlcache = NonlinearSolveCache(ustep, tstep, k, atmp, invγdt, prob, cache)
nlcache = NonlinearSolveCache(ustep, tstep, k, atmp, prob, cache,
true, true, true, tType(dt), invγdt, tType(nlalg.new_W_dt_cutoff), t)
else
nlcache = NLNewtonCache(ustep, tstep, k, atmp, dz, J, W, true,
true, true, tType(dt), du1, uf, jac_config,
Expand Down Expand Up @@ -326,8 +331,8 @@ function build_nlsolver(
end
prob = NonlinearProblem(NonlinearFunction{false}(nlf), copy(ztmp), nlp_params)
cache = init(prob, nlalg.alg)
nlcache = NonlinearSolveCache(
nothing, tstep, nothing, nothing, invγdt, prob, cache)
nlcache = NonlinearSolveCache(nothing, tstep, nothing, nothing, prob, cache,
true, true, true, tType(dt), invγdt, tType(nlalg.new_W_dt_cutoff), t)
else
nlcache = NLNewtonConstantCache(tstep, J, W, true, true, true, tType(dt), uf,
invγdt, tType(nlalg.new_W_dt_cutoff), t)
Expand Down
18 changes: 9 additions & 9 deletions lib/OrdinaryDiffEqSDIRK/src/kencarp_kvaerno_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ end
nlsolver.c = γ
z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step)
nlsolvefail(nlsolver) && return
isnewton(nlsolver) && set_new_W!(nlsolver, false)
set_new_W!(nlsolver, false)

################################## Solve Step 3

Expand Down Expand Up @@ -315,7 +315,7 @@ end
nlsolver.c = 2γ
z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step)
nlsolvefail(nlsolver) && return
isnewton(nlsolver) && set_new_W!(nlsolver, false)
set_new_W!(nlsolver, false)

################################## Solve Step 3

Expand Down Expand Up @@ -554,7 +554,7 @@ end
nlsolver.c = c2
z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step)
nlsolvefail(nlsolver) && return
isnewton(nlsolver) && set_new_W!(nlsolver, false)
set_new_W!(nlsolver, false)

################################## Solve Step 3

Expand Down Expand Up @@ -719,7 +719,7 @@ end
nlsolver.c = γ
z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step)
nlsolvefail(nlsolver) && return
isnewton(nlsolver) && set_new_W!(nlsolver, false)
set_new_W!(nlsolver, false)

################################## Solve Step 3

Expand Down Expand Up @@ -1008,7 +1008,7 @@ end
markfirststage!(nlsolver)
z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step)
nlsolvefail(nlsolver) && return
isnewton(nlsolver) && set_new_W!(nlsolver, false)
set_new_W!(nlsolver, false)

################################## Solve Step 3

Expand Down Expand Up @@ -1262,7 +1262,7 @@ end
nlsolver.c = γ
z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step)
nlsolvefail(nlsolver) && return
isnewton(nlsolver) && set_new_W!(nlsolver, false)
set_new_W!(nlsolver, false)

################################## Solve Step 3

Expand Down Expand Up @@ -1612,7 +1612,7 @@ end
nlsolver.c = 2γ
z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step)
nlsolvefail(nlsolver) && return
isnewton(nlsolver) && set_new_W!(nlsolver, false)
set_new_W!(nlsolver, false)

################################## Solve Step 3

Expand Down Expand Up @@ -2028,7 +2028,7 @@ end
nlsolver.c = 2γ
z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step)
nlsolvefail(nlsolver) && return
isnewton(nlsolver) && set_new_W!(nlsolver, false)
set_new_W!(nlsolver, false)

################################## Solve Step 3

Expand Down Expand Up @@ -2449,7 +2449,7 @@ end
nlsolver.c = 2γ
z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step)
nlsolvefail(nlsolver) && return
isnewton(nlsolver) && set_new_W!(nlsolver, false)
set_new_W!(nlsolver, false)

################################## Solve Step 3

Expand Down
Loading
Loading