Skip to content

Commit 40da85c

Browse files
authored
Merge pull request #305 from SciML/qqy/nested_nlsolve_kwargs
2 parents 8742e25 + 5629136 commit 40da85c

File tree

7 files changed

+82
-36
lines changed

7 files changed

+82
-36
lines changed

docs/src/solvers/firk.md

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,11 @@ solve(prob::BVProblem, alg, dt; kwargs...)
1212
solve(prob::TwoPointBVProblem, alg, dt; kwargs...)
1313
```
1414

15-
!!! note "Nested nonlinear solving in FIRK methods"
16-
17-
When encountered with large BVP system, setting `nested_nlsolve` to `true` enables FIRK methods to use nested nonlinear solving for the implicit FIRK step instead of solving as a part of the global residual(when default as `nested_nlsolve=false`),
15+
## Nested nonlinear solving in FIRK methods
16+
17+
When working with large boundary value problems, especially those involving stiff systems, computational efficiency and solver robustness become critical concerns. To improve the efficiency of FIRK methods on large BVPs, we can use nested nonlinear solving to obtain the implicit FIRK step instead of solving them as part of the global residual. In BoundaryValueDiffEq.jl, we can set `nested_nlsolve` as `true` to enable FIRK methods to compute the implicit FIRK steps using nested nonlinear solving(default option in FIRK methods is `nested_nlsolve=false`).
18+
19+
Moreover, the nested nonlinear problem solver can be finely tuned to meet specific accuracy requirements by providing detailed keyword arguments through the `nested_nlsolve_kwargs` option in any FIRK solver, for example, `RadauIIa5(; nested_nlsolve = true, nested_nlsolve_kwargs = (; abstol = 1e-6, reltol = 1e-6))`, where `nested_nlsolve_kwargs` can be any common keyword arguments in NonlinearSolve.jl, see [Common Solver Options in NonlinearSolve.jl](https://docs.sciml.ai/NonlinearSolve/stable/basics/solve/).
1820

1921
## Full List of Methods
2022

lib/BoundaryValueDiffEqFIRK/src/adaptivity.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ end
4949

5050
@views function interp_eval!(
5151
y::AbstractArray, cache::FIRKCacheNested{iip, T}, t, mesh, mesh_dt) where {iip, T}
52-
(; f, ITU, nest_prob, nest_tol, alg) = cache
52+
(; f, ITU, nest_prob, alg) = cache
5353
(; q_coeff) = ITU
5454

5555
j = interval(mesh, t)
@@ -82,7 +82,7 @@ end
8282
nestprob_p[3:end] .= yᵢ
8383

8484
_nestprob = remake(nest_prob, p = nestprob_p)
85-
nestsol = __solve(_nestprob, nest_nlsolve_alg; abstol = nest_tol)
85+
nestsol = __solve(_nestprob, nest_nlsolve_alg; alg.nested_nlsolve_kwargs...)
8686
K = nestsol.u
8787

8888
z₁, z₁′ = eval_q(yᵢ, 0.5, h, q_coeff, K) # Evaluate q(x) at midpoints
@@ -325,7 +325,7 @@ an interpolant
325325
end
326326

327327
@views function defect_estimate!(cache::FIRKCacheNested{iip, T}) where {iip, T}
328-
(; f, mesh, mesh_dt, defect, ITU, nest_prob, nest_tol) = cache
328+
(; f, mesh, mesh_dt, defect, ITU, nest_prob, alg) = cache
329329
(; q_coeff, τ_star) = ITU
330330

331331
nlsolve_alg = __concrete_nonlinearsolve_algorithm(nest_prob, cache.alg.nlsolve)
@@ -347,7 +347,7 @@ end
347347
nestprob_p[3:end] .= yᵢ₁
348348

349349
_nestprob = remake(nest_prob, p = nestprob_p)
350-
nest_sol = __solve(_nestprob, nlsolve_alg; abstol = nest_tol)
350+
nest_sol = __solve(_nestprob, nlsolve_alg; alg.nested_nlsolve_kwargs...)
351351

352352
# Defect estimate from q(x) at y_i + τ* * h
353353
z₁, z₁′ = eval_q(yᵢ₁, τ_star, h, q_coeff, nest_sol.u)

lib/BoundaryValueDiffEqFIRK/src/algorithms.jl

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -84,13 +84,13 @@ for stage in (1, 2, 3, 5, 7)
8484
nlsolve::N = nothing
8585
jac_alg::J = BVPJacobianAlgorithm()
8686
nested_nlsolve::Bool = false
87-
nest_tol::Union{Number, Nothing} = nothing
87+
nested_nlsolve_kwargs::NamedTuple = (;)
8888
defect_threshold::T = 0.1
8989
max_num_subintervals::Int = 3000
9090
end
91-
$(alg)(nlsolve::N, jac_alg::J; nested = false, nest_tol::Union{Number, Nothing} = nothing, defect_threshold::T = 0.1, max_num_subintervals::Int = 3000) where {N, J, T} = $(alg){
92-
N, J, T}(
93-
nlsolve, jac_alg, nested, nest_tol, defect_threshold, max_num_subintervals)
91+
$(alg)(nlsolve::N, jac_alg::J; nested = false, nested_nlsolve_kwargs::NamedTuple = (;), defect_threshold::T = 0.1, max_num_subintervals::Int = 3000) where {N, J, T} = $(alg){
92+
N, J, T}(nlsolve, jac_alg, nested, nested_nlsolve_kwargs,
93+
defect_threshold, max_num_subintervals)
9494
end
9595
end
9696

@@ -178,13 +178,13 @@ for stage in (2, 3, 4, 5)
178178
nlsolve::N = nothing
179179
jac_alg::J = BVPJacobianAlgorithm()
180180
nested_nlsolve::Bool = false
181-
nest_tol::Union{Number, Nothing} = nothing
181+
nested_nlsolve_kwargs::NamedTuple = (;)
182182
defect_threshold::T = 0.1
183183
max_num_subintervals::Int = 3000
184184
end
185-
$(alg)(nlsolve::N, jac_alg::J; nested = false, nest_tol::Union{Number, Nothing} = nothing, defect_threshold::T = 0.1, max_num_subintervals::Int = 3000) where {N, J, T} = $(alg){
186-
N, J, T}(
187-
nlsolve, jac_alg, nested, nest_tol, defect_threshold, max_num_subintervals)
185+
$(alg)(nlsolve::N, jac_alg::J; nested = false, nested_nlsolve_kwargs::NamedTuple = (;), defect_threshold::T = 0.1, max_num_subintervals::Int = 3000) where {N, J, T} = $(alg){
186+
N, J, T}(nlsolve, jac_alg, nested, nested_nlsolve_kwargs,
187+
defect_threshold, max_num_subintervals)
188188
end
189189
end
190190

@@ -272,13 +272,13 @@ for stage in (2, 3, 4, 5)
272272
nlsolve::N = nothing
273273
jac_alg::J = BVPJacobianAlgorithm()
274274
nested_nlsolve::Bool = false
275-
nest_tol::Union{Number, Nothing} = nothing
275+
nested_nlsolve_kwargs::NamedTuple = (;)
276276
defect_threshold::T = 0.1
277277
max_num_subintervals::Int = 3000
278278
end
279-
$(alg)(nlsolve::N, jac_alg::J; nested = false, nest_tol::Union{Number, Nothing} = nothing, defect_threshold::T = 0.1, max_num_subintervals::Int = 3000) where {N, J, T} = $(alg){
280-
N, J, T}(
281-
nlsolve, jac_alg, nested, nest_tol, defect_threshold, max_num_subintervals)
279+
$(alg)(nlsolve::N, jac_alg::J; nested = false, nested_nlsolve_kwargs::NamedTuple = (;), defect_threshold::T = 0.1, max_num_subintervals::Int = 3000) where {N, J, T} = $(alg){
280+
N, J, T}(nlsolve, jac_alg, nested, nested_nlsolve_kwargs,
281+
defect_threshold, max_num_subintervals)
282282
end
283283
end
284284

@@ -366,13 +366,13 @@ for stage in (2, 3, 4, 5)
366366
nlsolve::N = nothing
367367
jac_alg::J = BVPJacobianAlgorithm()
368368
nested_nlsolve::Bool = false
369-
nest_tol::Union{Number, Nothing} = nothing
369+
nested_nlsolve_kwargs::NamedTuple = (;)
370370
defect_threshold::T = 0.1
371371
max_num_subintervals::Int = 3000
372372
end
373-
$(alg)(nlsolve::N, jac_alg::J; nested = false, nest_tol::Union{Number, Nothing} = nothing, defect_threshold::T = 0.1, max_num_subintervals::Int = 3000) where {N, J, T} = $(alg){
374-
N, J, T}(
375-
nlsolve, jac_alg, nested, nest_tol, defect_threshold, max_num_subintervals)
373+
$(alg)(nlsolve::N, jac_alg::J; nested = false, nested_nlsolve_kwargs::NamedTuple = (;), defect_threshold::T = 0.1, max_num_subintervals::Int = 3000) where {N, J, T} = $(alg){
374+
N, J, T}(nlsolve, jac_alg, nested, nested_nlsolve_kwargs,
375+
defect_threshold, max_num_subintervals)
376376
end
377377
end
378378

lib/BoundaryValueDiffEqFIRK/src/collocation.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -83,11 +83,11 @@ end
8383
@views function Φ!(residual, fᵢ_cache, k_discrete, f!, TU::FIRKTableau{true},
8484
y, u, p, mesh, mesh_dt, stage::Int, cache)
8585
(; b) = TU
86-
(; nest_prob, nest_tol) = cache
86+
(; nest_prob, alg) = cache
8787

8888
T = eltype(u)
8989
nestprob_p = vcat(T(mesh[1]), T(mesh_dt[1]), get_tmp(y[1], u))
90-
nest_nlsolve_alg = __concrete_nonlinearsolve_algorithm(nest_prob, cache.alg.nlsolve)
90+
nest_nlsolve_alg = __concrete_nonlinearsolve_algorithm(nest_prob, alg.nlsolve)
9191

9292
for i in eachindex(k_discrete)
9393
residᵢ = residual[i]
@@ -103,7 +103,7 @@ end
103103
K = get_tmp(k_discrete[i], u)
104104

105105
_nestprob = remake(nest_prob, p = nestprob_p)
106-
nestsol = solve(_nestprob, nest_nlsolve_alg; abstol = nest_tol)
106+
nestsol = solve(_nestprob, nest_nlsolve_alg; alg.nested_nlsolve_kwargs...)
107107
@. K = nestsol.u
108108
@. residᵢ = yᵢ₊₁ - yᵢ
109109
__maybe_matmul!(residᵢ, nestsol.u, b, -h, T(1))
@@ -159,7 +159,7 @@ end
159159
@views function Φ(fᵢ_cache, k_discrete, f!, TU::FIRKTableau{true},
160160
y, u, p, mesh, mesh_dt, stage::Int, cache)
161161
(; b) = TU
162-
(; nest_prob, alg, nest_tol) = cache
162+
(; nest_prob, alg) = cache
163163

164164
residuals = [safe_similar(yᵢ) for yᵢ in y[1:(end - 1)]]
165165

@@ -179,7 +179,7 @@ end
179179
nestprob_p[3:end] = yᵢ
180180

181181
_nestprob = remake(nest_prob, p = nestprob_p)
182-
nestsol = solve(_nestprob, nest_nlsolve_alg, abstol = nest_tol)
182+
nestsol = solve(_nestprob, nest_nlsolve_alg; alg.nested_nlsolve_kwargs...)
183183

184184
@. residᵢ = yᵢ₊₁ - yᵢ
185185
__maybe_matmul!(residᵢ, nestsol.u, b, -h, T(1))

lib/BoundaryValueDiffEqFIRK/src/firk.jl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
fᵢ₂_cache
2525
defect
2626
nest_prob
27-
nest_tol
2827
resid_size
2928
kwargs
3029
end
@@ -165,7 +164,6 @@ function init_nested(prob::BVProblem, alg::AbstractFIRK; dt = 0.0, abstol = 1e-3
165164
K0 = __K0_on_u0(prob.u0, stage) # Somewhat arbitrary initialization of K
166165

167166
nestprob_p = zeros(T, M + 2)
168-
nest_tol = alg.nest_tol
169167

170168
if iip
171169
nestprob = NonlinearProblem(
@@ -176,10 +174,9 @@ function init_nested(prob::BVProblem, alg::AbstractFIRK; dt = 0.0, abstol = 1e-3
176174
end
177175

178176
return FIRKCacheNested{iip, T}(
179-
alg_order(alg), stage, M, size(X), f, bc, prob_, prob.problem_type,
180-
prob.p, alg, TU, ITU, bcresid_prototype, mesh, mesh_dt,
181-
k_discrete, y, y₀, residual, fᵢ_cache, fᵢ₂_cache, defect, nestprob,
182-
nest_tol, resid₁_size, (; abstol, dt, adaptive, kwargs...))
177+
alg_order(alg), stage, M, size(X), f, bc, prob_, prob.problem_type, prob.p, alg,
178+
TU, ITU, bcresid_prototype, mesh, mesh_dt, k_discrete, y, y₀, residual, fᵢ_cache,
179+
fᵢ₂_cache, defect, nestprob, resid₁_size, (; abstol, dt, adaptive, kwargs...))
183180
end
184181

185182
function init_expanded(prob::BVProblem, alg::AbstractFIRK; dt = 0.0, abstol = 1e-3,

lib/BoundaryValueDiffEqFIRK/src/interpolation.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ nodual_value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x)
180180
# Nested FIRK
181181
function (s::EvalSol{C})(tval::Number) where {C <: FIRKCacheNested}
182182
(; t, u, cache) = s
183-
(; f, nest_prob, nest_tol, alg, mesh_dt, p, ITU) = cache
183+
(; f, nest_prob, alg, mesh_dt, p, ITU) = cache
184184
(; q_coeff) = ITU
185185
stage = alg_stage(alg)
186186
# Quick handle for the case where tval is at the boundary
@@ -212,7 +212,7 @@ function (s::EvalSol{C})(tval::Number) where {C <: FIRKCacheNested}
212212
nestprob_p[3:end] .= nodual_value(yᵢ)
213213

214214
_nestprob = remake(nest_prob, p = nestprob_p)
215-
nestsol = __solve(_nestprob, nest_nlsolve_alg; abstol = nest_tol)
215+
nestsol = __solve(_nestprob, nest_nlsolve_alg; alg.nested_nlsolve_kwargs...)
216216
K = nestsol.u
217217

218218
z₁, z₁′ = eval_q(yᵢ, 0.5, h, q_coeff, K) # Evaluate q(x) at midpoints

lib/BoundaryValueDiffEqFIRK/test/nested/firk_basic_tests.jl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,3 +443,50 @@ end =#
443443
(0, pi / 2), pi / 2; bcresid_prototype = (zeros(1), zeros(1)))
444444
SciMLBase.successful_retcode(solve(bvp5, RadauIIa5(; nested_nlsolve = true), dt = 0.05))
445445
end
446+
447+
@testitem "Nested nlsolve kwargs in FIRK" setup=[FIRKNestedConvergenceTests] begin
448+
tspan = (0.0, π / 2)
449+
function simplependulum!(du, u, p, t)
450+
g, L, θ, dθ = 9.81, 1.0, u[1], u[2]
451+
du[1] =
452+
du[2] = -(g / L) * sin(θ)
453+
end
454+
455+
function bc_pendulum!(residual, u, p, t)
456+
residual[1] = u(pi / 4)[1] + π / 2
457+
residual[2] = u(pi / 2)[1] - π / 2
458+
end
459+
460+
u0 = [pi / 2, pi / 2]
461+
prob = BVProblem(simplependulum!, bc_pendulum!, u0, tspan)
462+
nested = true
463+
nested_nlsolve_kwargs = (; abstol = 1e-6, reltol = 1e-6)
464+
465+
@testset "RadauIIa$stage" for stage in (2, 3, 5, 7)
466+
@test_nowarn solve(prob,
467+
radau_solver(Val(stage); nested_nlsolve = nested,
468+
nested_nlsolve_kwargs = nested_nlsolve_kwargs);
469+
dt = 0.005)
470+
end
471+
472+
@testset "LobattoIIIa$stage" for stage in (3, 4, 5)
473+
@test_nowarn solve(prob,
474+
lobattoIIIa_solver(Val(stage); nested_nlsolve = nested,
475+
nested_nlsolve_kwargs = nested_nlsolve_kwargs);
476+
dt = 0.005)
477+
end
478+
479+
@testset "LobattoIIIb$stage" for stage in (3, 4, 5)
480+
@test_nowarn solve(prob,
481+
lobattoIIIb_solver(Val(stage); nested_nlsolve = nested,
482+
nested_nlsolve_kwargs = nested_nlsolve_kwargs);
483+
dt = 0.005)
484+
end
485+
486+
@testset "LobattoIIIc$stage" for stage in (3, 4, 5)
487+
@test_nowarn solve(prob,
488+
lobattoIIIb_solver(Val(stage); nested_nlsolve = nested,
489+
nested_nlsolve_kwargs = nested_nlsolve_kwargs);
490+
dt = 0.005)
491+
end
492+
end

0 commit comments

Comments
 (0)