
User-defined Traces #715

@termi-official


Is your feature request related to a problem? Please describe.

As a user of this library, I would find it really useful to be able to add a custom trace to the solver. I have two common use cases for this.

  1. Convergence rate estimation (see Adaptive IDSolve OrdinaryDiffEq.jl#2881). This can be used to adapt the time step of the solver.
  2. Debugging PDE problems. Here it can be immensely helpful to dump the grid/mesh and the solution vector to a file to visually inspect how the Newton iteration diverges. This typically gives insight into why the solver diverges, for example localized blowup or faulty boundary conditions.
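As a rough illustration of what a convergence-rate trace would need to record, here is a minimal, self-contained sketch. It assumes the common estimate θₖ = ‖δuₖ‖ / ‖δuₖ₋₁‖ (the ratio of successive Newton increment norms, as stiff ODE solvers commonly use to monitor Newton convergence). The helpers `convergence_rates` and `newton_increment_norms` are hypothetical and not part of NonlinearSolve.jl.

```julia
using LinearAlgebra: norm

# Hypothetical helper: contraction factors θₖ = ‖δuₖ‖ / ‖δuₖ₋₁‖ computed from
# a history of increment norms (θₖ < 1 means the iteration is contracting).
convergence_rates(incnorms) = [incnorms[k] / incnorms[k - 1] for k in 2:length(incnorms)]

# Hypothetical helper: run a few Newton steps on the separable system
# f(u) = u.^2 .- 2 (diagonal Jacobian J = 2u) and record ‖δu‖ after each step.
function newton_increment_norms(u; iters = 5)
    incnorms = Float64[]
    for _ in 1:iters
        fu = u .^ 2 .- 2
        δu = -fu ./ (2 .* u)   # Newton step δu = -J \ fu with J = Diagonal(2u)
        u = u .+ δu
        push!(incnorms, norm(δu))
    end
    return incnorms
end

θ = convergence_rates(newton_increment_norms([1.0, 3.0]))
```

For this quadratically convergent example the ratios shrink rapidly toward zero. Ratios that stay near or above one would indicate slow convergence or divergence, which is the signal a time-step controller would react to.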

Describe the solution you’d like

I would like to define a struct, dispatch some functions on it, and pass it to the library to customize the trace. For example, something like this:

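```julia
using ConcreteStructs: @concrete
using NonlinearSolveBase

# Trace level chosen by the user; wraps another trace level (e.g. the default one).
@concrete struct ConvergenceRateTracing
    inner_tracing
end

# Trace object recording the L2 norms of the increments and residuals,
# forwarding everything else to the wrapped trace.
@concrete struct ConvergenceRateTraceTrick
    incrementL2norms
    residualL2norms
    trace_wrapper
end

function NonlinearSolveBase.init_nonlinearsolve_trace(
        prob, alg::IDSolve, u, fu, J, δu;
        # This kind of dispatch does not work: Julia does not dispatch on keyword
        # arguments. Need to figure out a different way.
        trace_level::ConvergenceRateTracing, kwargs...
)
    # Pass the wrapped level explicitly as `trace_level`; a bare
    # `trace_level.inner_tracing` would be forwarded as `inner_tracing = ...`.
    inner_trace = NonlinearSolveBase.init_nonlinearsolve_trace(
        prob, alg, u, fu, J, δu;
        trace_level = trace_level.inner_tracing, kwargs...
    )

    return ConvergenceRateTraceTrick(eltype(δu)[], eltype(fu)[], inner_trace)
end
```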
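Because Julia does not dispatch on keyword arguments, one common workaround is to accept the keyword in a thin entry point and immediately forward its value as a positional argument to an internal function, where user types can add methods. The sketch below is self-contained and does not touch NonlinearSolveBase; `AbstractTraceLevel`, `NoTrace`, `init_trace`, `_init_trace`, `record_step!`, and `toy_newton` are hypothetical names that only illustrate the pattern, reusing the struct names from the proposal above.

```julia
using LinearAlgebra: norm

# Hypothetical trace-level hierarchy; not the NonlinearSolveBase API.
abstract type AbstractTraceLevel end
struct NoTrace <: AbstractTraceLevel end

struct ConvergenceRateTracing{T <: AbstractTraceLevel} <: AbstractTraceLevel
    inner_tracing::T
end

struct ConvergenceRateTraceTrick{I}
    incrementL2norms::Vector{Float64}
    residualL2norms::Vector{Float64}
    trace_wrapper::I
end

# Keyword entry point: keywords do not take part in dispatch, so forward the
# keyword value to a positional argument, where dispatch does work.
init_trace(u, fu; trace_level::AbstractTraceLevel = NoTrace()) = _init_trace(trace_level, u, fu)

_init_trace(::NoTrace, u, fu) = nothing
_init_trace(t::ConvergenceRateTracing, u, fu) =
    ConvergenceRateTraceTrick(Float64[], Float64[], _init_trace(t.inner_tracing, u, fu))

# Per-iteration hook, also dispatched positionally on the trace object.
record_step!(::Nothing, δu, fu) = nothing
function record_step!(tr::ConvergenceRateTraceTrick, δu, fu)
    push!(tr.incrementL2norms, norm(δu))
    push!(tr.residualL2norms, norm(fu))
    record_step!(tr.trace_wrapper, δu, fu)   # forward to the wrapped trace
    return tr
end

# Hypothetical toy solver: componentwise Newton for f(u) = u.^2 .- 2.
function toy_newton(u; maxiters = 5, trace_level = NoTrace())
    fu = u .^ 2 .- 2
    trace = init_trace(u, fu; trace_level)
    for _ in 1:maxiters
        δu = -fu ./ (2 .* u)
        u = u .+ δu
        fu = u .^ 2 .- 2
        record_step!(trace, δu, fu)
    end
    return u, trace
end

u, trace = toy_newton([1.0, 3.0]; trace_level = ConvergenceRateTracing(NoTrace()))
```

The keyword remains the user-facing interface, while the methods of `_init_trace` and `record_step!` are what a user would extend. Wrapping an `inner_tracing` level keeps the default trace behavior available alongside the custom one.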

Describe alternatives you’ve considered

Right now I copy and paste the contents of `solve!`:

```julia
function CommonSolve.solve!(cache::AbstractNonlinearSolveCache)
    if cache.retcode == ReturnCode.InitialFailure
        return SciMLBase.build_solution(
            cache.prob, cache.alg, get_u(cache), get_fu(cache);
            cache.retcode, cache.stats, cache.trace
        )
    end
    while not_terminated(cache)
        CommonSolve.step!(cache)
    end
    # The solver might have set a different `retcode`
    if cache.retcode == ReturnCode.Default
        cache.retcode = ifelse(
            cache.nsteps ≥ cache.maxiters, ReturnCode.MaxIters, ReturnCode.Success
        )
    end
    update_from_termination_cache!(cache.termination_cache, cache)
    update_trace!(
        cache.trace, cache.nsteps, get_u(cache), get_fu(cache), nothing, nothing, nothing;
        last = Val(true)
    )
    return SciMLBase.build_solution(
        cache.prob, cache.alg, get_u(cache), get_fu(cache);
        cache.retcode, cache.stats, cache.trace
    )
end
```

and manually add some statements around the trace function calls.

An alternative would be a monitor API, in addition to the tracing API, that can also communicate with the nonlinear solver, for example to force it to terminate the solve.
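A monitor differs from a trace in that its return value feeds back into the solver. A minimal sketch, assuming a hypothetical `monitor!` hook that the solve loop calls after every step and that returns `true` to request termination. `AbstractMonitor`, `DivergenceMonitor`, `toy_newton`, and the `:Terminated` status symbol are illustrative only, not an existing NonlinearSolve.jl API. The example monitor stops as soon as the estimated contraction factor ‖δuₖ‖ / ‖δuₖ₋₁‖ reaches `θmax`.

```julia
using LinearAlgebra: norm

# Hypothetical monitor interface: `monitor!` returns `true` to stop the solve.
abstract type AbstractMonitor end

struct NoMonitor <: AbstractMonitor end
monitor!(::NoMonitor, iter, u, δu, fu) = false

# Stops once θ = ‖δuₖ‖ / ‖δuₖ₋₁‖ reaches θmax (a simple divergence heuristic).
mutable struct DivergenceMonitor <: AbstractMonitor
    θmax::Float64
    prev::Float64   # norm of the previous increment
end
DivergenceMonitor(θmax) = DivergenceMonitor(θmax, Inf)

function monitor!(m::DivergenceMonitor, iter, u, δu, fu)
    n = norm(δu)
    θ = n / m.prev   # first step: n / Inf == 0, so it never triggers
    m.prev = n
    return θ >= m.θmax
end

# Hypothetical toy Newton loop that consults the monitor after each step.
function toy_newton(f, J, u; maxiters = 20, abstol = 1e-12, monitor = NoMonitor())
    for iter in 1:maxiters
        fu = f(u)
        norm(fu) <= abstol && return u, :Success, iter - 1
        δu = -(J(u) \ fu)
        u = u + δu
        monitor!(monitor, iter, u, δu, f(u)) && return u, :Terminated, iter
    end
    return u, :MaxIters, maxiters
end

# Newton on atan(u) diverges from u₀ = 2 but converges to 0 from u₀ = 0.5.
datan(u) = 1 / (1 + u^2)
u_div, status_div, iters_div = toy_newton(atan, datan, 2.0; monitor = DivergenceMonitor(1.0))
u_conv, status_conv, _ = toy_newton(atan, datan, 0.5; monitor = DivergenceMonitor(1.0))
```

Here the diverging run is stopped after the second step, when the increment grows by a factor of roughly three, instead of running until `maxiters` with ever larger iterates.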

Additional context

This could also simplify OrdinaryDiffEqNonlinearSolve.jl quite a bit in the long run.
