diff --git a/src/integrators.jl b/src/integrators.jl index dc5c0165..a2d08c39 100644 --- a/src/integrators.jl +++ b/src/integrators.jl @@ -21,22 +21,22 @@ end # called by DiffEqBase.init and solve (see below) function DiffEqBase.__init( - prob::DiffEqBase.AbstractODEProblem, - alg::DistributedODEAlgorithm, - args...; + prob::DiffEqBase.AbstractODEProblem, + alg::DistributedODEAlgorithm, + args...; dt, # required - stepstop=-1, + stepstop=-1, adjustfinal=false, callback=nothing, - kwargs...) - + kwargs...) + u = prob.u0 t = prob.tspan[1] tstop = prob.tspan[2] callbackset = DiffEqBase.CallbackSet(callback) isempty(callbackset.continuous_callbacks) || error("Continuous callbacks are not supported") - integrator = DistributedODEIntegrator(prob, alg, u, dt, t, tstop, 0, stepstop, adjustfinal, callbackset, false, cache(prob, alg; dt=dt, kwargs...)) + integrator = DistributedODEIntegrator(prob, alg, u, dt, t, tstop, 0, stepstop, adjustfinal, callbackset, false, init_cache(prob, alg; dt=dt, kwargs...)) DiffEqBase.initialize!(callbackset,u,t,integrator) return integrator @@ -46,10 +46,10 @@ end # called by DiffEqBase.solve function DiffEqBase.__solve( prob::DiffEqBase.AbstractODEProblem, - alg::DistributedODEAlgorithm, + alg::DistributedODEAlgorithm, args...; kwargs...) - + integrator = DiffEqBase.__init(prob, alg, args...; kwargs...) DiffEqBase.solve!(integrator) return integrator.u # ODEProblem returns a Solution objec @@ -61,6 +61,10 @@ function DiffEqBase.solve!(integrator::DistributedODEIntegrator) if integrator.adjustfinal && integrator.t + integrator.dt > integrator.tstop adjust_dt!(integrator, integrator.tstop - integrator.t) end + if !integrator.adjustfinal && integrator.t + integrator.dt/2 > integrator.tstop + break + end + DiffEqBase.step!(integrator) if integrator.step == integrator.stepstop @@ -90,13 +94,81 @@ function DiffEqBase.step!(integrator::DistributedODEIntegrator) end # solvers need to define this interface -step_u!(integrator) = step_u!(integrator, integrator.cache) +step_u!(integrator) = step_u!(integrator, integrator.cache) + +""" + adjust_dt!(integrator::DistributedODEIntegrator, dt[, dt_cache=nothing]) -function adjust_dt!(integrator::DistributedODEIntegrator, dt) +Adjust the time step of the integrator to `dt`. The optional `dt_cache` object +can be passed when the integrator has a `dt`-dependent component that needs to +be updated (such as a linear solver). +""" +function adjust_dt!(integrator::DistributedODEIntegrator, dt, dt_cache=nothing) # TODO: figure out interface for recomputing other objects (linear operators, etc) integrator.dt = dt + adjust_dt!(integrator.cache, dt, dt_cache) end +# interfaces + +""" + init_cache(prob, alg::A; kwargs...)::AC + +Construct an algorithm cache for the algorithm `alg`. This should be defined +for any algorithm type `A`, and should return an object of an appropriate cache +type `AC` that can be dispatched on for [`step_u!`](@ref) and/or +[`init_inner`](@ref)/[`update_inner!`](@ref). +""" +function init_cache end + +""" + step_u!(integrator, cache::AC) + +Perform a single step that updates the state `integrator.u` using accordint to +the algorithm corresponding to `cache`. + +This should be defined for any algorithm cache type `AC` that can be used +directly or as an inner timestepper. For outer timesteppers, +[`init_inner`](@ref) and [`update_inner!`](@ref) need to be defined instead. +""" +step_u!(integrator, cache) + +""" + init_dt_cache(cache::AC, prob, dt) + +Construct a `dt`-dependent subcache of `cache` for the ODE problem `prob`. This +should _not_ modify `cache` itself, but return an object that can be passed as +the `dt_cache` argument to [`adjust_dt!`](@ref). + +By default this returns `nothing`. This should be defined for any algorithm +cache type `AC` which has `dt`-dependent components. + +For example, an implicit solver can use this to return a factorized Euler +operator ``I-dt*L`` that is used as part of the implicit solve. + +This initialization will typically be done as part of [`init_cache`](@ref) +itself: this interface is provided for multirate schemes which need to modify +the `dt` of the inner solver at each outer stage. +""" +function init_dt_cache(cache, prob, dt) + return nothing +end + + +function get_dt_cache(cache) + return nothing +end + +""" + adjust_dt!(cache::AC, dt, dt_cache) + +Adjust the time step of the algorithm cache `cache`. This should be defined for +any algorithm cache type `AC`, where `dt_cache` is an object returned by +[`init_dt_cache`](@ref). +""" +adjust_dt!(cache, dt, dt_cache) + + # not sure what this should do? # defined as default initialize: https://github.com/SciML/DiffEqBase.jl/blob/master/src/callbacks.jl#L3 diff --git a/src/solvers/ark.jl b/src/solvers/ark.jl index e00917f4..13fb49e9 100644 --- a/src/solvers/ark.jl +++ b/src/solvers/ark.jl @@ -26,7 +26,8 @@ struct AdditiveRungeKuttaTableau{Nstages, Nstages², RT} C::NTuple{Nstages, RT} end -struct AdditiveRungeKuttaFullCache{Nstages, RT, A, O, L} +struct AdditiveRungeKuttaFullCache{Nstages,RT, A, G, O, L} + alg::G "stage value of the state variable" U::A #Qstages "evaluated linear part of each stage ``f_L(U^{(i)})``" @@ -38,8 +39,32 @@ struct AdditiveRungeKuttaFullCache{Nstages, RT, A, O, L} linsolve!::L end +function implicit_part(f::DiffEqBase.ODEFunction) + f.jvp === nothing && error("IMEX solvers require a `SplitODEFunction` or an `ODEFunction` with a `jvp` component.") + return f.jvp +end +implicit_part(f::DiffEqBase.SplitFunction) = f.f1 +implicit_part(f::OffsetODEFunction) = implicit_part(f.f) -function cache( +function init_dt_cache(cache::AdditiveRungeKuttaFullCache, prob, dt) + _init_dt_cache(cache.alg, cache.tableau, prob, dt) +end +function _init_dt_cache(alg::AdditiveRungeKutta, tab, prob, dt) + f_impl = implicit_part(prob.f) + W = EulerOperator(f_impl , -dt*tab.Aimpl[2,2], prob.p, prob.tspan[1]) + linsolve! = alg.linsolve(Val{:init}, W, prob.u0) + return (W, linsolve!) +end + +function get_dt_cache(cache::AdditiveRungeKuttaFullCache) + return (cache.W, cache.linsolve!) +end +function adjust_dt!(cache::AdditiveRungeKuttaFullCache, dt, (W, linsolve!)::Tuple) + cache.W = W + cache.linsolve! = linsolve! +end + +function init_cache( prob::DiffEqBase.AbstractODEProblem{uType, tType, true}, alg::AdditiveRungeKutta; dt, kwargs...) where {uType,tType} @@ -49,14 +74,8 @@ function cache( L = ntuple(i -> zero(prob.u0), Nstages) R = ntuple(i -> zero(prob.u0), Nstages) - if prob.f isa DiffEqBase.ODEFunction - W = EulerOperator(prob.f.jvp, -dt*tab.Aimpl[2,2], prob.p, prob.tspan[1]) - elseif prob.f isa DiffEqBase.SplitFunction - W = EulerOperator(prob.f.f1, -dt*tab.Aimpl[2,2], prob.p, prob.tspan[1]) - end - linsolve! = alg.linsolve(Val{:init}, W, prob.u0; kwargs...) - - AdditiveRungeKuttaFullCache(U, L, R, tab, W, linsolve!) + W, linsolve! = _init_dt_cache(alg, tab, prob, dt) + AdditiveRungeKuttaFullCache(alg, U, L, R, tab, W, linsolve!) end diff --git a/src/solvers/lsrk.jl b/src/solvers/lsrk.jl index e5493283..66a640e3 100644 --- a/src/solvers/lsrk.jl +++ b/src/solvers/lsrk.jl @@ -33,7 +33,7 @@ struct LowStorageRungeKutta2NIncCache{Nstages, RT, A} du::A end -function cache(prob::DiffEqBase.ODEProblem, alg::LowStorageRungeKutta2N; kwargs...) +function init_cache(prob::DiffEqBase.ODEProblem, alg::LowStorageRungeKutta2N; kwargs...) # @assert prob.problem_type isa DiffEqBase.IncrementingODEProblem || # prob.f isa DiffEqBase.IncrementingODEFunction du = zero(prob.u0) @@ -59,8 +59,19 @@ function step_u!(int, cache::LowStorageRungeKutta2NIncCache) end end +adjust_dt!(cache::LowStorageRungeKutta2NIncCache, dt, ::Nothing) = nothing + # for Multirate -function init_inner(prob, outercache::LowStorageRungeKutta2NIncCache, dt) +function inner_dts(outercache::LowStorageRungeKutta2NIncCache, dt, fast_dt) + N = nstages(outercache) + tab = outercache.tableau + ntuple(N) do i + Δt = (i == N ? 1-tab.C[i] : tab.C[i+1] - tab.C[i]) * dt + Δt / round(Δt / fast_dt) + end +end + +function init_inner_fun(prob, outercache::LowStorageRungeKutta2NIncCache, dt) OffsetODEFunction(prob.f.f1, zero(dt), one(dt), zero(dt), outercache.du) end function update_inner!(innerinteg, outercache::LowStorageRungeKutta2NIncCache, diff --git a/src/solvers/mis.jl b/src/solvers/mis.jl index 26068567..34341732 100644 --- a/src/solvers/mis.jl +++ b/src/solvers/mis.jl @@ -51,7 +51,7 @@ end nstages(::MultirateInfinitesimalStepCache{Nstages}) where {Nstages} = Nstages -function cache( +function init_cache( prob::DiffEqBase.AbstractODEProblem{uType, tType, true}, alg::MultirateInfinitesimalStep; kwargs...) where {uType,tType} @@ -66,8 +66,15 @@ function cache( return MultirateInfinitesimalStepCache(ΔU, F, tab) end +function inner_dts(outercache::MultirateInfinitesimalStepCache, dt, fast_dt) + tab = outercache.tableau + map(tab.d) do d_i + Δt = d_i*dt + Δt / round(Δt / fast_dt) + end +end -function init_inner(prob, outercache::MultirateInfinitesimalStepCache, dt) +function init_inner_fun(prob, outercache::MultirateInfinitesimalStepCache, dt) OffsetODEFunction(prob.f.f1, zero(dt), one(dt), one(dt), outercache.ΔU[end]) end diff --git a/src/solvers/multirate.jl b/src/solvers/multirate.jl index 0e379423..d018ec5d 100644 --- a/src/solvers/multirate.jl +++ b/src/solvers/multirate.jl @@ -19,12 +19,13 @@ struct Multirate{F,S} <: DistributedODEAlgorithm end -struct MultirateCache{OC,II} +struct MultirateCache{OC,II,SD} outercache::OC innerinteg::II + dt_cache::SD end -function cache( +function init_cache( prob::DiffEqBase.AbstractODEProblem, alg::Multirate; dt, fast_dt, kwargs...) @@ -33,13 +34,49 @@ function cache( # subproblems outerprob = DiffEqBase.remake(prob; f=prob.f.f2) - outercache = cache(outerprob, alg.slow) + outercache = init_cache(outerprob, alg.slow) - innerfun = init_inner(prob, outercache, dt) + sub_dts = inner_dts(outercache, dt, fast_dt) + unique_sub_dts = unique(sub_dts) + + innerfun = init_inner_fun(prob, outercache, dt) innerprob = DiffEqBase.remake(prob; f=innerfun) - innerinteg = DiffEqBase.init(innerprob, alg.fast; dt=fast_dt, kwargs...) - return MultirateCache(outercache, innerinteg) + innerinteg = DiffEqBase.init(innerprob, alg.fast; dt=unique_sub_dts[1], adjustfinal=false, kwargs...) + + # build dt_cache + unique_dt_caches = [ + i == 1 ? get_dt_cache(innerinteg.cache) : init_dt_cache(innerinteg.cache, innerinteg.prob, unique_sub_dts[i]) + for i = 1:length(unique_sub_dts)] + + dt_cache = map(sub_dts) do sub_dt + i = findfirst(==(sub_dt), unique_sub_dts) + unique_sub_dts[i] => unique_dt_caches[i] + end + + return MultirateCache(outercache, innerinteg, dt_cache) +end + +get_dt_cache(cache::Multirate) = cache.dt_cache +function init_dt_cache(cache::Multirate, prob, dt) + outercache = cache.outercache + innerinteg = cache.innerinteg + + fast_dt = innerinteg.dt # TODO: get the original fast_dt from somewhere + + sub_dts = inner_dts(outercache, dt, fast_dt) + unique_sub_dts = unique(sub_dts) + + unique_dt_caches = [ + init_dt_cache(innerinteg.cache, innerinteg.prob, unique_sub_dts[i]) + for i = 1:length(unique_sub_dts)] + + dt_cache = map(sub_dts) do sub_dt + i = findfirst(==(sub_dt), unique_sub_dts) + unique_sub_dts[i] => unique_dt_caches[i] + end + return dt_cache end +adjust_dt!(cache::Multirate, dt, dt_cache::Tuple) = cache.dt_cache function step_u!(int, cache::MultirateCache) @@ -54,23 +91,54 @@ function step_u!(int, cache::MultirateCache) innerinteg = cache.innerinteg fast_dt = innerinteg.dt - N = nstages(outercache) - for stage in 1:N + for i in 1:nstages(outercache) + sub_dt, sub_dt_cache = cache.dt_cache[i] + adjust_dt!(innerinteg, sub_dt, sub_dt_cache) + update_inner!(innerinteg, outercache, int.prob.f.f2, u, p, t, dt, i) + DiffEqBase.solve!(innerinteg) + end +end + +# interface +""" + nstages(outercache::AC) - update_inner!(innerinteg, outercache, int.prob.f.f2, u, p, t, dt, stage) +The number of stages of the algorithm determined by cache type `AC`. This should +be defined for any algorithm cache type `AC` used as an outer solver. +""" +function nstages end - # solve inner problem - # dv/dτ .= B[s]/(C[s+1] - C[s]) .* du .+ f_fast(v,τ) τ ∈ [τ0,τ1] - # TODO: make this more generic - # there are 2 strategies we can use here: - # a. use same fast_dt for all slow stages, use `adjustfinal=true` - # - problems for ARK (e.g. requires expensive LU factorization) - # b. use different fast_dt, cache expensive ops +""" + inner_dts(outercache::AC, dt, fast_dt) - innerinteg.adjustfinal = true - DiffEqBase.solve!(innerinteg) - innerinteg.dt = fast_dt # reset - end -end +The inner timesteps that will be used at each stage of the multirate procedure. +This should be defined for any algorithm cache type `AC` that will be used as an +outer solver, and should return a tuple of the length of the number of stages. +Each value will be approximately `fast_dt`, but rounded so that an integer +number of steps can be used at each outer stage (where `dt` is the slow time +step). +""" +function inner_dts end + +""" + init_inner_fun(prob, outercache::AC, dt) + +Construct the inner `ODEFunction` that will be used with inner solver. This +should be defined for any algorithm cache type `AC` that will be used as an +outer solver. +""" +function init_inner_fun end + +""" + update_inner!(innerinteg, outercache::AC, f_slow, u, p, t, dt, i) + +Update the inner integrator `innerinteg` for stage `i` of the outer algorithm. +This should be defined for any `outercache` type `AC`, and will typically modify: +- `innerinteg.prob.f` +- `innerinteg.u` +- `innerinteg.t` +- `innerinteg.tstop` +""" +function update_inner! end \ No newline at end of file diff --git a/src/solvers/ssprk.jl b/src/solvers/ssprk.jl index 46daf40e..6f7bfa65 100644 --- a/src/solvers/ssprk.jl +++ b/src/solvers/ssprk.jl @@ -34,7 +34,7 @@ struct StrongStabilityPreservingRungeKuttaCache{Nstages, RT, A} U::A end -function cache( +function init_cache( prob::DiffEqBase.AbstractODEProblem{uType, tType, true}, alg::StrongStabilityPreservingRungeKutta; kwargs...) where {uType,tType} @@ -44,7 +44,7 @@ function cache( U = zero(prob.u0) return StrongStabilityPreservingRungeKuttaCache(tab, fU, U) end - +adjust_dt!(cache::StrongStabilityPreservingRungeKutta, dt, ::Nothing) = nothing function step_u!(int, cache::StrongStabilityPreservingRungeKuttaCache{Nstages, RT, A}) where {Nstages, RT, A} tab = cache.tableau diff --git a/src/solvers/wickerskamarock.jl b/src/solvers/wickerskamarock.jl index 6a1c5341..3a808e6b 100644 --- a/src/solvers/wickerskamarock.jl +++ b/src/solvers/wickerskamarock.jl @@ -24,7 +24,7 @@ struct WickerSkamarockRungeKuttaCache{Nstages, RT, A} U::A F::A end -function cache(prob::DiffEqBase.ODEProblem, alg::WickerSkamarockRungeKutta; kwargs...) +function init_cache(prob::DiffEqBase.ODEProblem, alg::WickerSkamarockRungeKutta; kwargs...) U = similar(prob.u0) F = similar(prob.u0) return WickerSkamarockRungeKuttaCache(tableau(alg, eltype(F)), U, F) @@ -32,8 +32,18 @@ end nstages(::WickerSkamarockRungeKuttaCache{Nstages}) where {Nstages} = Nstages +function inner_dts(outercache::WickerSkamarockRungeKuttaCache, dt, fast_dt) + tab = outercache.tableau + if length(tab.c) == 2 # WSRK2 + Δt = dt/2 + else # WSRK3 + Δt = dt/6 + end + sub_dt = Δt / round(Δt / fast_dt) + return map(c -> sub_dt, tab.c) +end -function init_inner(prob, outercache::WickerSkamarockRungeKuttaCache, dt) +function init_inner_fun(prob, outercache::WickerSkamarockRungeKuttaCache, dt) OffsetODEFunction(prob.f.f1, zero(dt), one(dt), one(dt), outercache.F) end function update_inner!(innerinteg, outercache::WickerSkamarockRungeKuttaCache,