Skip to content

Commit bda144a

Browse files
authored
Merge pull request #323 from SciML/qqy/merge_kwargs
Allow more kwargs
2 parents 4e5a359 + d13daba commit bda144a

File tree

7 files changed

+102
-97
lines changed

7 files changed

+102
-97
lines changed

docs/src/basics/error_control.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Error Control Adaptivity
1+
# [Error Control Adaptivity](@id error_control)
22

33
Adaptivity helps ensure the quality of the our numerical solution, and when our solution exhibits significant estimating errors, adaptivity automatically refine the mesh based on the error distribution, and providing a final satisfying solution.
44

docs/src/basics/solve.md

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
# [Common Solver Options (Solve Keyword Arguments)](@id solver_options)
22

3-
## Iteration Controls
4-
53
- `abstol::Number`: The absolute tolerance. Defaults to `1e-6`.
4+
- `adaptive::Bool`: Whether the error control adaptivity is on, default as `true`.
5+
- `controller`: Error controller for collocation methods, default as `DefectControl()`, more controller options in [Error Control Adaptivity](@ref error_control).
66
- `defect_threshold`: Monitor of the size of defect norm. Defaults to `0.1`.
7-
- `odesolve_kwargs`: OrdinaryDiffEq.jl solvers kwargs for passing to ODE solving in shooting methods.
8-
- `nlsolve_kwargs`: NonlinearSolve.jl solvers kwargs for passing to nonlinear solving in collocation methods and shootingn methods.
7+
- `odesolve_kwargs`: OrdinaryDiffEq.jl solvers kwargs for passing to ODE solving in shooting methods. For more information, see the documentation for OrdinaryDiffEq: [Common Solver Options](https://docs.sciml.ai/DiffEqDocs/latest/basics/common_solver_opts/).
8+
- `nlsolve_kwargs`: NonlinearSolve.jl solvers kwargs for passing to nonlinear solving in collocation methods and shooting methods. For more information, see the documentation for NonlinearSolve: [Commom Solver Options](https://docs.sciml.ai/NonlinearSolve/stable/basics/solve/). The default absolute tolerance of nonlinear solving in collocaio
9+
- `verbose`: Toggles whether warnings are thrown when the solver exits early. Defaults to `true`.
10+
- `ensemblealg`: Whether `MultipleShooting` uses multithreading, default as `EnsembleThreads()`. For more information, see the documentation for OrdinaryDiffEq: [EnsembleAlgorithms](https://docs.sciml.ai/DiffEqDocs/latest/features/ensemble/#EnsembleAlgorithms).

lib/BoundaryValueDiffEqAscher/src/ascher.jl

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
ipvtw
3939
TU
4040
valstr
41+
nlsolve_kwargs
4142
kwargs
4243
end
4344

@@ -57,8 +58,9 @@ function get_fixed_points(prob::BVProblem, alg::AbstractAscher)
5758
end
5859
end
5960

60-
function SciMLBase.__init(prob::BVProblem, alg::AbstractAscher; dt = 0.0,
61-
controller = GlobalErrorControl(), adaptive = true, abstol = 1e-4, kwargs...)
61+
function SciMLBase.__init(
62+
prob::BVProblem, alg::AbstractAscher; dt = 0.0, controller = GlobalErrorControl(),
63+
adaptive = true, abstol = 1e-4, nlsolve_kwargs = (; abstol = abstol), kwargs...)
6264
(; tspan, p) = prob
6365
_, T, ncy, n, u0 = __extract_problem_details(prob; dt, check_positive_dt = true)
6466
t₀, t₁ = tspan
@@ -143,25 +145,24 @@ function SciMLBase.__init(prob::BVProblem, alg::AbstractAscher; dt = 0.0,
143145

144146
g = build_almost_block_diagonals(zeta, ncomp, mesh, T)
145147
cache = AscherCache{iip, T}(
146-
prob, f, jac, bc, bcjac, k, copy(mesh), mesh, mesh_dt, ncomp, ny, p,
147-
zeta, fixpnt, alg, prob.problem_type, bcresid_prototype, residual,
148-
zval, yval, gval, err, g, w, v, lz, ly, dmz, delz, deldmz, dqdmz, dmv,
149-
pvtg, pvtw, TU, valst, (; abstol, dt, adaptive, controller, kwargs...))
148+
prob, f, jac, bc, bcjac, k, copy(mesh), mesh, mesh_dt, ncomp, ny, p, zeta,
149+
fixpnt, alg, prob.problem_type, bcresid_prototype, residual, zval, yval,
150+
gval, err, g, w, v, lz, ly, dmz, delz, deldmz, dqdmz, dmv, pvtg, pvtw, TU,
151+
valst, nlsolve_kwargs, (; abstol, dt, adaptive, controller, kwargs...))
150152
return cache
151153
end
152154

153155
function SciMLBase.solve!(cache::AscherCache{iip, T}) where {iip, T}
154-
(abstol, adaptive, _), kwargs = __split_kwargs(; cache.kwargs...)
156+
(abstol, adaptive, _), _ = __split_kwargs(; cache.kwargs...)
155157
info::ReturnCode.T = ReturnCode.Success
156158

157159
# We do the first iteration outside the loop to preserve type-stability of the
158160
# `original` field of the solution
159-
z, y, info, error_norm = __perform_ascher_iteration(cache, abstol, adaptive; kwargs...)
161+
z, y, info, error_norm = __perform_ascher_iteration(cache, abstol, adaptive)
160162

161163
if adaptive
162164
while SciMLBase.successful_retcode(info) && norm(error_norm) > abstol
163-
z, y, info, error_norm = __perform_ascher_iteration(
164-
cache, abstol, adaptive; kwargs...)
165+
z, y, info, error_norm = __perform_ascher_iteration(cache, abstol, adaptive)
165166
end
166167
end
167168
u = [vcat(zᵢ, yᵢ) for (zᵢ, yᵢ) in zip(z, y)]
@@ -170,21 +171,19 @@ function SciMLBase.solve!(cache::AscherCache{iip, T}) where {iip, T}
170171
cache.prob, cache.alg, cache.original_mesh, u; retcode = info)
171172
end
172173

173-
function __perform_ascher_iteration(cache::AscherCache{iip, T}, abstol, adaptive::Bool;
174-
nlsolve_kwargs = (;), kwargs...) where {iip, T}
174+
function __perform_ascher_iteration(
175+
cache::AscherCache{iip, T}, abstol, adaptive::Bool) where {iip, T}
175176
info::ReturnCode.T = ReturnCode.Success
176177
nlprob = __construct_nlproblem(cache)
177178
nlsolve_alg = __concrete_nonlinearsolve_algorithm(nlprob, cache.alg.nlsolve)
178-
nlsol = __solve(nlprob, nlsolve_alg; abstol, kwargs..., nlsolve_kwargs...)
179+
nlsol = __solve(nlprob, nlsolve_alg; cache.nlsolve_kwargs...)
179180
error_norm = 2 * abstol
180181
info = nlsol.retcode
181182

182-
N = length(cache.mesh)
183-
184183
z = copy(cache.z)
185184
y = copy(cache.y)
186-
for i in 1:N
187-
@views approx(cache, cache.mesh[i], z[i], y[i])
185+
for (i, m) in enumerate(cache.mesh)
186+
@views approx(cache, m, z[i], y[i])
188187
end
189188

190189
# Preserve dmz, and mesh for the mesh selection
@@ -203,7 +202,7 @@ function __perform_ascher_iteration(cache::AscherCache{iip, T}, abstol, adaptive
203202
__expand_cache_for_error!(cache)
204203

205204
_nlprob = __construct_nlproblem(cache)
206-
nlsol = __solve(_nlprob, nlsolve_alg; abstol, kwargs..., nlsolve_kwargs...)
205+
nlsol = __solve(_nlprob, nlsolve_alg; cache.nlsolve_kwargs...)
207206

208207
error_norm = error_estimate!(cache)
209208
if norm(error_norm) > abstol

lib/BoundaryValueDiffEqFIRK/src/firk.jl

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
defect
2626
nest_prob
2727
resid_size
28+
nlsolve_kwargs
2829
kwargs
2930
end
3031

@@ -56,6 +57,7 @@ Base.eltype(::FIRKCacheNested{iip, T}) where {iip, T} = T
5657
fᵢ₂_cache
5758
defect
5859
resid_size
60+
nlsolve_kwargs
5961
kwargs
6062
end
6163

@@ -79,19 +81,21 @@ function shrink_y(y, N, stage)
7981
return y_shrink
8082
end
8183

82-
function SciMLBase.__init(prob::BVProblem, alg::AbstractFIRK; dt = 0.0, abstol = 1e-6,
83-
adaptive = true, controller = DefectControl(), kwargs...)
84+
function SciMLBase.__init(
85+
prob::BVProblem, alg::AbstractFIRK; dt = 0.0, abstol = 1e-6, adaptive = true,
86+
controller = DefectControl(), nlsolve_kwargs = (; abstol = abstol), kwargs...)
8487
if alg.nested_nlsolve
85-
return init_nested(prob, alg; dt = dt, abstol = abstol,
86-
adaptive = adaptive, controller = controller, kwargs...)
88+
return init_nested(prob, alg; dt = dt, abstol = abstol, adaptive = adaptive,
89+
controller = controller, nlsolve_kwargs = nlsolve_kwargs, kwargs...)
8790
else
88-
return init_expanded(prob, alg; dt = dt, abstol = abstol,
89-
adaptive = adaptive, controller = controller, kwargs...)
91+
return init_expanded(prob, alg; dt = dt, abstol = abstol, adaptive = adaptive,
92+
controller = controller, nlsolve_kwargs = nlsolve_kwargs, kwargs...)
9093
end
9194
end
9295

93-
function init_nested(prob::BVProblem, alg::AbstractFIRK; dt = 0.0, abstol = 1e-6,
94-
adaptive = true, controller = DefectControl(), kwargs...)
96+
function init_nested(
97+
prob::BVProblem, alg::AbstractFIRK; dt = 0.0, abstol = 1e-6, adaptive = true,
98+
controller = DefectControl(), nlsolve_kwargs = (; abstol = abstol), kwargs...)
9599
@set! alg.jac_alg = concrete_jacobian_algorithm(alg.jac_alg, prob, alg)
96100

97101
iip = isinplace(prob)
@@ -178,13 +182,14 @@ function init_nested(prob::BVProblem, alg::AbstractFIRK; dt = 0.0, abstol = 1e-6
178182

179183
return FIRKCacheNested{iip, T, typeof(diffcache)}(
180184
alg_order(alg), stage, M, size(X), f, bc, prob_, prob.problem_type,
181-
prob.p, alg, TU, ITU, bcresid_prototype, mesh, mesh_dt,
182-
k_discrete, y, y₀, residual, fᵢ_cache, fᵢ₂_cache, defect, nestprob,
183-
resid₁_size, (; abstol, dt, adaptive, controller, kwargs...))
185+
prob.p, alg, TU, ITU, bcresid_prototype, mesh, mesh_dt, k_discrete,
186+
y, y₀, residual, fᵢ_cache, fᵢ₂_cache, defect, nestprob, resid₁_size,
187+
nlsolve_kwargs, (; abstol, dt, adaptive, controller, kwargs...))
184188
end
185189

186-
function init_expanded(prob::BVProblem, alg::AbstractFIRK; dt = 0.0, abstol = 1e-6,
187-
adaptive = true, controller = DefectControl(), kwargs...)
190+
function init_expanded(
191+
prob::BVProblem, alg::AbstractFIRK; dt = 0.0, abstol = 1e-6, adaptive = true,
192+
controller = DefectControl(), nlsolve_kwargs = (; abstol = abstol), kwargs...)
188193
@set! alg.jac_alg = concrete_jacobian_algorithm(alg.jac_alg, prob, alg)
189194

190195
if adaptive && isa(alg, FIRKNoAdaptivity)
@@ -260,9 +265,10 @@ function init_expanded(prob::BVProblem, alg::AbstractFIRK; dt = 0.0, abstol = 1e
260265
prob_ = !(prob.u0 isa AbstractArray) ? remake(prob; u0 = X) : prob
261266

262267
return FIRKCacheExpand{iip, T, typeof(diffcache)}(
263-
alg_order(alg), stage, M, size(X), f, bc, prob_, prob.problem_type, prob.p, alg,
264-
TU, ITU, bcresid_prototype, mesh, mesh_dt, k_discrete, y, y₀, residual, fᵢ_cache,
265-
fᵢ₂_cache, defect, resid₁_size, (; abstol, dt, adaptive, controller, kwargs...))
268+
alg_order(alg), stage, M, size(X), f, bc, prob_, prob.problem_type,
269+
prob.p, alg, TU, ITU, bcresid_prototype, mesh, mesh_dt, k_discrete,
270+
y, y₀, residual, fᵢ_cache, fᵢ₂_cache, defect, resid₁_size,
271+
nlsolve_kwargs, (; abstol, dt, adaptive, controller, kwargs...))
266272
end
267273

268274
"""
@@ -297,13 +303,12 @@ function SciMLBase.solve!(cache::FIRKCacheExpand{iip, T}) where {iip, T}
297303

298304
# We do the first iteration outside the loop to preserve type-stability of the
299305
# `original` field of the solution
300-
sol_nlprob, info, defect_norm = __perform_firk_iteration(
301-
cache, abstol, adaptive; kwargs...)
306+
sol_nlprob, info, defect_norm = __perform_firk_iteration(cache, abstol, adaptive)
302307

303308
if adaptive
304309
while SciMLBase.successful_retcode(info) && defect_norm > abstol
305310
sol_nlprob, info, defect_norm = __perform_firk_iteration(
306-
cache, abstol, adaptive; kwargs...)
311+
cache, abstol, adaptive)
307312
end
308313
end
309314

@@ -323,13 +328,12 @@ function SciMLBase.solve!(cache::FIRKCacheNested{iip, T}) where {iip, T}
323328

324329
# We do the first iteration outside the loop to preserve type-stability of the
325330
# `original` field of the solution
326-
sol_nlprob, info, defect_norm = __perform_firk_iteration(
327-
cache, abstol, adaptive; kwargs...)
331+
sol_nlprob, info, defect_norm = __perform_firk_iteration(cache, abstol, adaptive)
328332

329333
if adaptive
330334
while SciMLBase.successful_retcode(info) && defect_norm > abstol
331335
sol_nlprob, info, defect_norm = __perform_firk_iteration(
332-
cache, abstol, adaptive; kwargs...)
336+
cache, abstol, adaptive)
333337
end
334338
end
335339

@@ -342,12 +346,11 @@ function SciMLBase.solve!(cache::FIRKCacheNested{iip, T}) where {iip, T}
342346
return __build_solution(cache.prob, odesol, sol_nlprob)
343347
end
344348

345-
function __perform_firk_iteration(cache::Union{FIRKCacheExpand, FIRKCacheNested}, abstol,
346-
adaptive::Bool; nlsolve_kwargs = (;), kwargs...)
349+
function __perform_firk_iteration(
350+
cache::Union{FIRKCacheExpand, FIRKCacheNested}, abstol, adaptive::Bool)
347351
nlprob = __construct_nlproblem(cache, vec(cache.y₀), copy(cache.y₀))
348352
nlsolve_alg = __concrete_nonlinearsolve_algorithm(nlprob, cache.alg.nlsolve)
349-
sol_nlprob = __solve(
350-
nlprob, nlsolve_alg; abstol, kwargs..., nlsolve_kwargs..., alias_u0 = true)
353+
sol_nlprob = __solve(nlprob, nlsolve_alg; cache.nlsolve_kwargs..., alias_u0 = true)
351354
recursive_unflatten!(cache.y₀, sol_nlprob.u)
352355

353356
defect_norm = 2 * abstol

lib/BoundaryValueDiffEqMIRK/src/adaptivity.jl

Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -351,17 +351,16 @@ error_estimate for the hybrid error control uses the linear combination of defec
351351
error to estimate the error norm.
352352
"""
353353
# Defect control
354-
@views function error_estimate!(
355-
cache::MIRKCache{iip, T}, controller::GlobalErrorControl, errors,
356-
sol, nlsolve_alg, abstol, kwargs, nlsolve_kwargs) where {iip, T}
357-
return error_estimate!(cache::MIRKCache{iip, T}, controller, controller.method,
358-
errors, sol, nlsolve_alg, abstol, kwargs, nlsolve_kwargs)
354+
@views function error_estimate!(cache::MIRKCache{iip, T}, controller::GlobalErrorControl,
355+
errors, sol, nlsolve_alg, abstol) where {iip, T}
356+
return error_estimate!(
357+
cache, controller, controller.method, errors, sol, nlsolve_alg, abstol)
359358
end
360359

361360
# Global error control
362361
@views function error_estimate!(
363362
cache::MIRKCache{iip, T, use_both, DiffCacheNeeded}, controller::DefectControl,
364-
errors, sol, nlsolve_alg, abstol, kwargs, nlsolve_kwargs) where {iip, T, use_both}
363+
errors, sol, nlsolve_alg, abstol) where {iip, T, use_both}
365364
(; f, alg, mesh, mesh_dt) = cache
366365
(; τ_star) = cache.ITU
367366

@@ -407,7 +406,7 @@ end
407406
end
408407
@views function error_estimate!(
409408
cache::MIRKCache{iip, T, use_both, NoDiffCacheNeeded}, controller::DefectControl,
410-
errors, sol, nlsolve_alg, abstol, kwargs, nlsolve_kwargs) where {iip, T, use_both}
409+
errors, sol, nlsolve_alg, abstol) where {iip, T, use_both}
411410
(; f, alg, mesh, mesh_dt) = cache
412411
(; τ_star) = cache.ITU
413412

@@ -454,42 +453,41 @@ end
454453

455454
# Sequential error control
456455
@views function error_estimate!(
457-
cache::MIRKCache{iip, T}, controller::SequentialErrorControl, errors,
458-
sol, nlsolve_alg, abstol, kwargs, nlsolve_kwargs) where {iip, T}
459-
defect_norm, info = error_estimate!(cache::MIRKCache{iip, T}, controller.defect, errors,
460-
sol, nlsolve_alg, abstol, kwargs, nlsolve_kwargs)
456+
cache::MIRKCache{iip, T}, controller::SequentialErrorControl,
457+
errors, sol, nlsolve_alg, abstol) where {iip, T}
458+
defect_norm, info = error_estimate!(
459+
cache::MIRKCache{iip, T}, controller.defect, errors, sol, nlsolve_alg, abstol)
461460
error_norm = defect_norm
462461
if defect_norm <= abstol
463462
global_error_norm, info = error_estimate!(
464463
cache::MIRKCache{iip, T}, controller.global_error,
465-
controller.global_error.method, errors, sol,
466-
nlsolve_alg, abstol, kwargs, nlsolve_kwargs)
464+
controller.global_error.method, errors, sol, nlsolve_alg, abstol)
467465
error_norm = global_error_norm
468466
return error_norm, info
469467
end
470468
return error_norm, info
471469
end
472470

473471
# Hybrid error control
474-
function error_estimate!(cache::MIRKCache{iip, T}, controller::HybridErrorControl, errors,
475-
sol, nlsolve_alg, abstol, kwargs, nlsolve_kwargs) where {iip, T}
472+
function error_estimate!(cache::MIRKCache{iip, T}, controller::HybridErrorControl,
473+
errors, sol, nlsolve_alg, abstol) where {iip, T}
476474
L = length(cache.mesh) - 1
477475
defect = errors[:, 1:L]
478476
global_error = errors[:, (L + 1):end]
479-
defect_norm, _ = error_estimate!(cache::MIRKCache{iip, T}, controller.defect, defect,
480-
sol, nlsolve_alg, abstol, kwargs, nlsolve_kwargs)
477+
defect_norm, _ = error_estimate!(
478+
cache::MIRKCache{iip, T}, controller.defect, defect, sol, nlsolve_alg, abstol)
481479
global_error_norm, _ = error_estimate!(
482-
cache::MIRKCache{iip, T}, controller.global_error, controller.global_error.method,
483-
global_error, sol, nlsolve_alg, abstol, kwargs, nlsolve_kwargs)
480+
cache, controller.global_error, controller.global_error.method,
481+
global_error, sol, nlsolve_alg, abstol)
484482

485483
error_norm = controller.DE * defect_norm + controller.GE * global_error_norm
486484
copyto!(errors, VectorOfArray(vcat(defect.u, global_error.u)))
487485
return error_norm, ReturnCode.Success
488486
end
489487

490488
@views function error_estimate!(cache::MIRKCache{iip, T}, controller::GlobalErrorControl,
491-
global_error_control::REErrorControl, errors, sol,
492-
nlsolve_alg, abstol, kwargs, nlsolve_kwargs) where {iip, T}
489+
global_error_control::REErrorControl, errors,
490+
sol, nlsolve_alg, abstol) where {iip, T}
493491
(; prob, alg) = cache
494492

495493
# Use the previous solution as the initial guess
@@ -500,16 +498,16 @@ end
500498
high_nlprob = __construct_nlproblem(
501499
high_cache, vec(high_sol), VectorOfArray(high_sol.u))
502500
high_sol_original = __solve(
503-
high_nlprob, nlsolve_alg; abstol, kwargs..., nlsolve_kwargs..., alias_u0 = true)
501+
high_nlprob, nlsolve_alg; cache.nlsolve_kwargs..., alias_u0 = true)
504502
recursive_unflatten!(high_sol, high_sol_original.u)
505503
error_norm = global_error(
506504
VectorOfArray(copy(high_sol.u[1:2:end])), copy(cache.y₀), errors)
507505
return error_norm * 2^cache.order / (2^cache.order - 1), ReturnCode.Success
508506
end
509507

510508
@views function error_estimate!(cache::MIRKCache{iip, T}, controller::GlobalErrorControl,
511-
global_error_control::HOErrorControl, errors, sol,
512-
nlsolve_alg, abstol, kwargs, nlsolve_kwargs) where {iip, T}
509+
global_error_control::HOErrorControl, errors,
510+
sol, nlsolve_alg, abstol) where {iip, T}
513511
(; prob, alg) = cache
514512

515513
# Use the previous solution as the initial guess
@@ -519,7 +517,7 @@ end
519517

520518
high_nlprob = __construct_nlproblem(high_cache, sol.u, high_sol)
521519
high_sol_nlprob = __solve(
522-
high_nlprob, nlsolve_alg; abstol, kwargs..., nlsolve_kwargs..., alias_u0 = true)
520+
high_nlprob, nlsolve_alg; cache.nlsolve_kwargs..., alias_u0 = true)
523521
recursive_unflatten!(high_sol, high_sol_nlprob)
524522
error_norm = global_error(VectorOfArray(high_sol.u), cache.y₀, errors)
525523
return error_norm, ReturnCode.Success

0 commit comments

Comments
 (0)