Description
Is your feature request related to a problem? Please describe.
As a user of this library, I would find it very useful to be able to attach a custom trace to the solver. I have two common use cases:
- Convergence rate estimation (see Adaptive IDSolve OrdinaryDiffEq.jl#2881). This can be used to adapt the time step of the solver.
- Debugging PDE problems. Here it can be immensely helpful to dump the grid/mesh and the solution vector to a file at each iteration, to visually inspect how the Newton iteration diverges. This typically gives insight into why the solver diverges, which can show up, for example, as localized blow-up or as faults in the boundary conditions.
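To make the first use case concrete, here is a minimal, self-contained sketch of what such a trace could record and how a convergence rate could be estimated from it. `ConvergenceHistory`, `record!` and `contraction_rates` are hypothetical names, not part of NonlinearSolve.jl, and the estimate θ_k = ‖δu_k‖ / ‖δu_{k-1}‖ is a common contraction-rate heuristic, assumed here rather than taken from any specific solver. The toy Newton loop on u² = 2 stands in for the library's iteration.

```julia
using LinearAlgebra: norm

# Hypothetical trace object: stores the L2 norm of every Newton increment δu
# and of every residual fu seen during one nonlinear solve.
struct ConvergenceHistory{T}
    increment_norms::Vector{T}
    residual_norms::Vector{T}
end
ConvergenceHistory{T}() where {T} = ConvergenceHistory{T}(T[], T[])

# Record one iteration; a tracing hook inside the solver would call this.
function record!(h::ConvergenceHistory, δu, fu)
    push!(h.increment_norms, norm(δu))
    push!(h.residual_norms, norm(fu))
    return h
end

# Estimated contraction rates θ_k = ‖δu_k‖ / ‖δu_{k-1}‖ for k ≥ 2.
# θ well below 1 suggests fast convergence; θ near or above 1 could be used
# as a signal to reduce the time step (assumed policy, not a library rule).
function contraction_rates(h::ConvergenceHistory)
    n = h.increment_norms
    return [n[k] / n[k-1] for k in 2:length(n)]
end

# Toy usage: five Newton steps for u^2 - 2 = 0, starting from u = 1.
h = ConvergenceHistory{Float64}()
u = [1.0]
for _ in 1:5
    fu = u .^ 2 .- 2
    δu = -fu ./ (2 .* u)   # Newton increment for this scalar problem
    record!(h, δu, fu)
    u .+= δu
end
rates = contraction_rates(h)
```

For this quadratically convergent toy problem, the estimated rates shrink rapidly from one iteration to the next.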
Describe the solution you’d like
I would like to define a struct, together with methods dispatched on it, that I can pass to the library to control tracing. For example, something like this:
```julia
using ConcreteStructs: @concrete
import NonlinearSolveBase

@concrete struct ConvergenceRateTracing
    inner_tracing
end

@concrete struct ConvergenceRateTraceTrick
    incrementL2norms
    residualL2norms
    trace_wrapper
end

function NonlinearSolveBase.init_nonlinearsolve_trace(
        prob, alg::IDSolve, u, fu, J, δu;
        trace_level::ConvergenceRateTracing, kwargs... # This kind of dispatch does not work. Need to figure out a different way.
)
    # Build the wrapped (inner) trace with the user-supplied inner trace level.
    inner_trace = NonlinearSolveBase.init_nonlinearsolve_trace(
        prob, alg, u, fu, J, δu;
        trace_level = trace_level.inner_tracing, kwargs...
    )
    return ConvergenceRateTraceTrick(eltype(δu)[], eltype(fu)[], inner_trace)
end
```

Describe alternatives you’ve considered
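Regarding the comment in the sketch above: Julia does not dispatch on keyword arguments, so one common workaround is to have the keyword-taking entry point forward `trace_level` as a positional argument to an internal function and dispatch there. The toy below illustrates only that pattern; `AbstractTraceLevel`, `NoTracing`, `NoTrace`, `ConvergenceRateTrace`, `init_trace` and `_init_trace` are hypothetical stand-ins, not the actual `NonlinearSolveBase` API, and whether `init_nonlinearsolve_trace` could be restructured this way is an assumption.

```julia
# Hypothetical trace-level types (stand-ins, NOT the NonlinearSolveBase types).
abstract type AbstractTraceLevel end
struct NoTracing <: AbstractTraceLevel end
struct ConvergenceRateTracing{I<:AbstractTraceLevel} <: AbstractTraceLevel
    inner_tracing::I
end

# Hypothetical trace objects produced by the initializer.
struct NoTrace end
struct ConvergenceRateTrace{TU,TF,I}
    increment_norms::Vector{TU}
    residual_norms::Vector{TF}
    inner_trace::I
end

# Public entry point: keyword arguments do not take part in dispatch,
# so the keyword is forwarded as the first positional argument.
init_trace(u, fu; trace_level::AbstractTraceLevel = NoTracing(), kwargs...) =
    _init_trace(trace_level, u, fu; kwargs...)

# Methods are selected on the positional `trace_level` argument.
_init_trace(::NoTracing, u, fu; kwargs...) = NoTrace()

function _init_trace(tl::ConvergenceRateTracing, u, fu; kwargs...)
    # Initialize the wrapped trace first, then wrap it with empty norm buffers.
    inner = init_trace(u, fu; trace_level = tl.inner_tracing, kwargs...)
    return ConvergenceRateTrace(eltype(u)[], eltype(fu)[], inner)
end

trace = init_trace([1.0, 2.0], [0.5, 0.1];
                   trace_level = ConvergenceRateTracing(NoTracing()))
```

Because the public method keeps the keyword interface, callers would not need to change; only the internal method table grows when a new trace level is added.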
Right now I copy-paste the contents of `solve!` (from `NonlinearSolve.jl/lib/NonlinearSolveBase/src/solve.jl`, lines 231 to 261 in ac9344f):

```julia
function CommonSolve.solve!(cache::AbstractNonlinearSolveCache)
    if cache.retcode == ReturnCode.InitialFailure
        return SciMLBase.build_solution(
            cache.prob, cache.alg, get_u(cache), get_fu(cache);
            cache.retcode, cache.stats, cache.trace
        )
    end

    while not_terminated(cache)
        CommonSolve.step!(cache)
    end

    # The solver might have set a different `retcode`
    if cache.retcode == ReturnCode.Default
        cache.retcode = ifelse(
            cache.nsteps ≥ cache.maxiters, ReturnCode.MaxIters, ReturnCode.Success
        )
    end

    update_from_termination_cache!(cache.termination_cache, cache)

    update_trace!(
        cache.trace, cache.nsteps, get_u(cache), get_fu(cache), nothing, nothing, nothing;
        last = Val(true)
    )

    return SciMLBase.build_solution(
        cache.prob, cache.alg, get_u(cache), get_fu(cache);
        cache.retcode, cache.stats, cache.trace
    )
end
```
An alternative would be a monitor API, in addition to the tracing API, that can also communicate with the nonlinear solver, for example to force it to terminate the solve.
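A rough sketch of how such a monitor could interact with the iteration, using a toy Newton loop rather than NonlinearSolve.jl itself: the monitor is called once per iteration and returns `true` to request early termination. `DivergenceMonitor`, `newton_with_monitor` and the `:Success`/`:Terminated`/`:MaxIters` status symbols are hypothetical; a real API would presumably report a `ReturnCode` instead.

```julia
using LinearAlgebra: norm

# Hypothetical monitor: stop if the residual norm grows by more than
# `max_growth` between two consecutive iterations.
struct DivergenceMonitor
    max_growth::Float64
end

function (m::DivergenceMonitor)(iter, u, fu, fu_prev)
    return iter > 1 && norm(fu) > m.max_growth * norm(fu_prev)
end

# Toy Newton loop for a problem with a diagonal Jacobian (df returns its
# diagonal). The monitor is consulted before every update.
function newton_with_monitor(f, df, u0; monitor, maxiters = 50, abstol = 1e-10)
    u = copy(u0)
    fu = f(u)
    fu_prev = fu
    for iter in 1:maxiters
        if norm(fu) < abstol
            return u, :Success
        end
        if monitor(iter, u, fu, fu_prev)
            return u, :Terminated   # the monitor forced termination
        end
        u .-= fu ./ df(u)           # Newton update
        fu_prev, fu = fu, f(u)
    end
    return u, :MaxIters
end

# Converging case: u^2 - 2 = 0 from u = 1.
sol, status = newton_with_monitor(u -> u .^ 2 .- 2, u -> 2 .* u, [1.0];
                                  monitor = DivergenceMonitor(1.0))
# Diverging case: Newton on atan(u) = 0 from u = 2 overshoots and the residual grows.
_, status_div = newton_with_monitor(u -> atan.(u), u -> 1 ./ (1 .+ u .^ 2), [2.0];
                                    monitor = DivergenceMonitor(1.0))
```

Unlike a pure trace, the monitor's return value feeds back into the control flow, which is the extra capability such an API would add.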
Additional context
This could also simplify OrdinaryDiffEqNonlinearSolve.jl quite a bit in the long run.