diff --git a/src/OrdinaryDiffEqOperatorSplitting.jl b/src/OrdinaryDiffEqOperatorSplitting.jl index 94f3152..5aa1a2f 100644 --- a/src/OrdinaryDiffEqOperatorSplitting.jl +++ b/src/OrdinaryDiffEqOperatorSplitting.jl @@ -7,7 +7,7 @@ import Unrolled: @unroll import SciMLBase, DiffEqBase, DataStructures -import OrdinaryDiffEqCore +import OrdinaryDiffEqCore: OrdinaryDiffEqCore, isadaptive, alg_order import UnPack: @unpack import DiffEqBase: init, TimeChoiceIterator @@ -16,14 +16,13 @@ abstract type AbstractOperatorSplitFunction <: DiffEqBase.AbstractODEFunction{tr abstract type AbstractOperatorSplittingAlgorithm end abstract type AbstractOperatorSplittingCache end -@inline DiffEqBase.isadaptive(::AbstractOperatorSplittingAlgorithm) = false - include("function.jl") include("problem.jl") include("integrator.jl") include("solver.jl") include("utils.jl") +include("controller.jl") -export GenericSplitFunction, OperatorSplittingProblem, LieTrotterGodunov +export GenericSplitFunction, OperatorSplittingProblem, LieTrotterGodunov, PalindromicPairLieTrotterGodunov end diff --git a/src/controller.jl b/src/controller.jl new file mode 100644 index 0000000..6356e33 --- /dev/null +++ b/src/controller.jl @@ -0,0 +1,98 @@ +@inline OrdinaryDiffEqCore.ispredictive(::AbstractOperatorSplittingAlgorithm) = false +@inline OrdinaryDiffEqCore.isstandard(::AbstractOperatorSplittingAlgorithm) = false +function OrdinaryDiffEqCore.beta2_default(alg::AbstractOperatorSplittingAlgorithm) + isadaptive(alg) ? 2 // (5alg_order(alg)) : 0 // 1 +end +function OrdinaryDiffEqCore.beta1_default(alg::AbstractOperatorSplittingAlgorithm, beta2) + isadaptive(alg) ? 7 // (10alg_order(alg)) : 0 // 1 +end + +function OrdinaryDiffEqCore.qmin_default(alg::AbstractOperatorSplittingAlgorithm) + isadaptive(alg) ? 1 // 5 : 0 // 1 +end +OrdinaryDiffEqCore.qmax_default(alg::AbstractOperatorSplittingAlgorithm) = 10 // 1 +function OrdinaryDiffEqCore.gamma_default(alg::AbstractOperatorSplittingAlgorithm) + isadaptive(alg) ? 9 // 10 : 0 // 1 +end +OrdinaryDiffEqCore.qsteady_min_default(alg::AbstractOperatorSplittingAlgorithm) = 1 // 1 +OrdinaryDiffEqCore.qsteady_max_default(alg::AbstractOperatorSplittingAlgorithm) = 1 // 1 + +mutable struct PIController{T} <: OrdinaryDiffEqCore.AbstractController + qmin::T + qmax::T + qsteady_min::T + qsteady_max::T + qoldinit::T + beta1::T + beta2::T + gamma::T + # Internal + q11::T + qold::T + q::T +end +PIController(; qmin, qmax, qsteady_min, qsteady_max, qoldinit, beta1, beta2, gamma, q11) = PIController(qmin, qmax, qsteady_min, qsteady_max, qoldinit, beta1, beta2, gamma, q11, qoldinit, qoldinit) + +function default_controller(alg, cache) + if !isadaptive(alg) + @warn "Trying to construct a controller for $alg, but the algorithm is not adaptive." + return nothing + end + + beta2 = OrdinaryDiffEqCore.beta2_default(alg) + beta1 = OrdinaryDiffEqCore.beta1_default(alg, beta2) + qmin = OrdinaryDiffEqCore.qmin_default(alg) + qmax = OrdinaryDiffEqCore.qmax_default(alg) + gamma = OrdinaryDiffEqCore.gamma_default(alg) + qsteady_min = OrdinaryDiffEqCore.qsteady_min_default(alg) + qsteady_max = OrdinaryDiffEqCore.qsteady_max_default(alg) + qoldinit = 1 // 10^4 + q11 = 1 // 1 + PIController(; + beta1, beta2, + qmin, qmax, + gamma, + qsteady_min, qsteady_max, + qoldinit, q11 + ) +end + +@inline DiffEqBase.isadaptive(::AbstractOperatorSplittingAlgorithm) = false + +@inline function stepsize_controller!(integrator::OperatorSplittingIntegrator, controller::PIController, alg) + (; qold, qmin, qmax, gamma) = controller + (; beta1, beta2) = controller + EEst = DiffEqBase.value(integrator.EEst) + + if iszero(EEst) + q = inv(qmax) + else + q11 = OrdinaryDiffEqCore.fastpower(EEst, convert(typeof(EEst), beta1)) + q = q11 / OrdinaryDiffEqCore.fastpower(qold, convert(typeof(EEst), beta2)) + controller.q11 = q11 + @fastmath q = max(inv(qmax), min(inv(qmin), q / gamma)) + end + controller.q = q # Return Q for temporary compat with OrdinaryDiffEqCore +end + +function step_accept_controller!(integrator::OperatorSplittingIntegrator, controller::PIController, alg) + (; q, qsteady_min, qsteady_max, qoldinit) = controller + EEst = DiffEqBase.value(integrator.EEst) + + if qsteady_min <= q <= qsteady_max + q = one(q) + end + controller.qold = max(EEst, qoldinit) + integrator.dt /= q + return nothing +end + +function step_reject_controller!(integrator::OperatorSplittingIntegrator, controller::PIController, alg) + (; q11, qmin, gamma) = controller + integrator.dt /= min(inv(qmin), q11 / gamma) + return nothing +end + +@inline function should_accept_step(integrator, controller::OrdinaryDiffEqCore.AbstractController) + return integrator.EEst <= 1 +end diff --git a/src/integrator.jl b/src/integrator.jl index 24167e7..546c99c 100644 --- a/src/integrator.jl +++ b/src/integrator.jl @@ -6,12 +6,13 @@ end IntegratorStats() = IntegratorStats(0, 0) -Base.@kwdef mutable struct IntegratorOptions{tType, fType, F3} +Base.@kwdef mutable struct IntegratorOptions{tType, fType, F2, F3} adaptive::Bool dtmin::tType = eps(Float64) dtmax::tType = Inf failfactor::fType = 4.0 verbose::Bool = false + internalnorm::F2 = DiffEqBase.ODE_DEFAULT_NORM isoutofdomain::F3 = DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN end @@ -70,6 +71,7 @@ mutable struct OperatorSplittingIntegrator{ synchronizer_tree::syncTreeType iter::Int controller::controllerType + EEst::Float64 # TODO integrate with controller cache opts::optionsType stats::IntegratorStats tdir::tType @@ -129,6 +131,8 @@ function DiffEqBase.__init( callback = DiffEqBase.CallbackSet(callback) + opts = IntegratorOptions(; verbose, adaptive, kwargs...) + subintegrator_tree, cache = build_subintegrator_tree_with_cache( prob, alg, @@ -136,9 +140,13 @@ function DiffEqBase.__init( 1:length(u), t0, dt, tf, tstops, saveat, d_discontinuities, callback, - adaptive, verbose + opts, ) + if controller === nothing && adaptive + controller = default_controller(alg, cache) + end + integrator = OperatorSplittingIntegrator( prob.f, alg, @@ -168,7 +176,8 @@ function DiffEqBase.__init( build_synchronizer_tree(prob.f), 0, controller, - IntegratorOptions(; verbose, adaptive), + NaN, + opts, IntegratorStats(), tType(tstops_internal.ordering isa DataStructures.FasterForward ? 1 : -1) ) @@ -280,13 +289,17 @@ increment_iteration(integrator::OperatorSplittingIntegrator) = integrator.iter + # Controller interface function reject_step!(integrator::OperatorSplittingIntegrator) OrdinaryDiffEqCore.increment_reject!(integrator.stats) - reject_step!(integrator, integrator.cache, integrator.controller) + reject_step!(integrator, integrator.controller) end -function reject_step!(integrator::OperatorSplittingIntegrator, cache, controller) +function reject_step!(integrator::OperatorSplittingIntegrator, controller) integrator.u .= integrator.uprev - # TODO what do we need to do with the subintegrators? + if !integrator.force_stepfail + step_reject_controller!(integrator, controller, integrator.alg) + end + # We need to roll-back the sub-integrators + prepare_subintegrators_to_redo_step!(integrator) end -function reject_step!(integrator::OperatorSplittingIntegrator, cache, ::Nothing) +function reject_step!(integrator::OperatorSplittingIntegrator, ::Nothing) if length(integrator.uprev) == 0 error("Cannot roll back integrator. Aborting time integration step at $(integrator.t).") end @@ -297,9 +310,9 @@ function should_accept_step(integrator::OperatorSplittingIntegrator) if integrator.force_stepfail || integrator.isout return false end - return should_accept_step(integrator, integrator.cache, integrator.controller) + return should_accept_step(integrator, integrator.controller) end -function should_accept_step(integrator::OperatorSplittingIntegrator, cache, ::Nothing) +function should_accept_step(integrator::OperatorSplittingIntegrator, ::Nothing) return !(integrator.force_stepfail) end function accept_step!(integrator::OperatorSplittingIntegrator) @@ -366,7 +379,7 @@ function step_footer!(integrator::OperatorSplittingIntegrator) integrator.t = ttmp#OrdinaryDiffEqCore.fixed_t_for_floatingpoint_error!(integrator, ttmp) # OrdinaryDiffEqCore.handle_callbacks!(integrator) step_accept_controller!(integrator) # Noop for non-adaptive algorithms - elseif integrator.force_stepfail + elseif integrator.force_stepfail # Rejected by solver if SciMLBase.isadaptive(integrator) step_reject_controller!(integrator) OrdinaryDiffEqCore.post_newton_controller!(integrator, integrator.alg) @@ -525,9 +538,8 @@ end Updates the controller using the current state of the integrator if the operator splitting algorithm is adaptive. """ @inline function stepsize_controller!(integrator::OperatorSplittingIntegrator) - algorithm = integrator.alg - DiffEqBase.isadaptive(algorithm) || return nothing - stepsize_controller!(integrator, algorithm) + DiffEqBase.isadaptive(integrator) || return nothing + stepsize_controller!(integrator, integrator.controller, integrator.alg) end """ @@ -536,9 +548,8 @@ end Updates `dtcache` of the integrator if the step is accepted and the operator splitting algorithm is adaptive. """ @inline function step_accept_controller!(integrator::OperatorSplittingIntegrator) - algorithm = integrator.alg - DiffEqBase.isadaptive(algorithm) || return nothing - step_accept_controller!(integrator, algorithm, nothing) + DiffEqBase.isadaptive(integrator) || return nothing + step_accept_controller!(integrator, integrator.controller, integrator.alg) end """ @@ -547,9 +558,8 @@ end Updates `dtcache` of the integrator if the step is rejected and the the operator splitting algorithm is adaptive. """ @inline function step_reject_controller!(integrator::OperatorSplittingIntegrator) - algorithm = integrator.alg - DiffEqBase.isadaptive(algorithm) || return nothing - step_reject_controller!(integrator, algorithm, nothing) + DiffEqBase.isadaptive(integrator) || return nothing + step_reject_controller!(integrator, integrator.controller, integrator.alg) end # helper functions for dealing with time-reversed integrators in the same way @@ -646,7 +656,7 @@ end function synchronize_subintegrator!( subintegrator::SciMLBase.DEIntegrator, integrator::OperatorSplittingIntegrator) @unpack t, dt = integrator - @assert subintegrator.t == t + @assert subintegrator.t == t "Integrators out of sync. The outer integrator is at $t, but inner integrator is at $(subintegrator.t)" if !DiffEqBase.isadaptive(subintegrator) SciMLBase.set_proposed_dt!(subintegrator, dt) end @@ -662,72 +672,37 @@ end # Dispatch for tree node construction function build_subintegrator_tree_with_cache( prob::OperatorSplittingProblem, alg::AbstractOperatorSplittingAlgorithm, + f::GenericSplitFunction, p::Tuple, uprevouter::AbstractVector, uouter::AbstractVector, solution_indices, t0, dt, tf, - tstops, saveat, d_discontinuities, callback, - adaptive, verbose + args..., ) - (; f, p) = prob - subintegrator_tree_with_caches = ntuple( - i -> build_subintegrator_tree_with_cache( - prob, - alg.inner_algs[i], - get_operator(f, i), - p[i], - uprevouter, uouter, - f.solution_indices[i], - t0, dt, tf, - tstops, saveat, d_discontinuities, callback, - adaptive, verbose - ), - length(f.functions) - ) + # subintegrator_tree_with_caches = ntuple( + # i -> build_subintegrator_tree_with_cache( + # OperatorSplittingProblem(), + # alg.inner_algs[i], + # get_operator(f, i), + # p[i], + # uprevouter, uouter, + # f.solution_indices[i], + # t0, dt, tf, + # args..., + # ), + # length(f.functions) + # ) + + # subintegrator_tree_leafs = first.(subintegrator_tree_with_caches) + # inner_caches = last.(subintegrator_tree_with_caches) subintegrator_tree = ntuple( - i -> subintegrator_tree_with_caches[i][1], length(f.functions)) - caches = ntuple(i -> subintegrator_tree_with_caches[i][2], length(f.functions)) - - # TODO fix mixed device type problems we have to be smarter - return subintegrator_tree, - init_cache(f, alg; - uprev = uprevouter, u = uouter, alias_u = true, - inner_caches = caches - ) -end - -function build_subintegrator_tree_with_cache( - prob::OperatorSplittingProblem, alg::AbstractOperatorSplittingAlgorithm, - f::GenericSplitFunction, p::Tuple, - uprevouter::AbstractVector, uouter::AbstractVector, - solution_indices, - t0, dt, tf, - tstops, saveat, d_discontinuities, callback, - adaptive, verbose, - save_end = false, - controller = nothing -) - subintegrator_tree_with_caches = ntuple( - i -> build_subintegrator_tree_with_cache( - prob, - alg.inner_algs[i], - get_operator(f, i), - p[i], - uprevouter, uouter, - f.solution_indices[i], - t0, dt, tf, - tstops, saveat, d_discontinuities, callback, - adaptive, verbose - ), + i-> DiffEqBase.__init() length(f.functions) ) - subintegrator_tree = first.(subintegrator_tree_with_caches) - inner_caches = last.(subintegrator_tree_with_caches) - # TODO fix mixed device type problems we have to be smarter uprev = @view uprevouter[solution_indices] - u = @view uouter[solution_indices] + u = @view uouter[solution_indices] return subintegrator_tree, init_cache(f, alg; uprev = uprev, u = u, @@ -741,13 +716,11 @@ function build_subintegrator_tree_with_cache( uprevouter::S, uouter::S, solution_indices, t0::T, dt::T, tf::T, - tstops, saveat, d_discontinuities, callback, - adaptive, verbose, - save_end = false, - controller = nothing + opts, + args..., ) where {S, T, P, F} uprev = @view uprevouter[solution_indices] - u = @view uouter[solution_indices] + u = @view uouter[solution_indices] integrator = DiffEqBase.__init( SciMLBase.ODEProblem(f, u, (t0, min(t0 + dt, tf)), p), @@ -757,9 +730,9 @@ function build_subintegrator_tree_with_cache( d_discontinuities, save_everystep = false, advance_to_tstop = false, - adaptive, - controller, - verbose + opts.adaptive, + opts.verbose, + args..., ) return integrator, integrator.cache diff --git a/src/solver.jl b/src/solver.jl index 16837ca..d45675f 100644 --- a/src/solver.jl +++ b/src/solver.jl @@ -46,8 +46,137 @@ end if !(subinteg isa Tuple) && subinteg.sol.retcode ∉ (SciMLBase.ReturnCode.Default, SciMLBase.ReturnCode.Success) + outer_integrator.force_stepfail = true + end + outer_integrator.force_stepfail && return + backward_sync_subintegrator!(outer_integrator, subinteg, idxs, synchronizer) + end +end + +OrdinaryDiffEqCore.alg_order(alg::LieTrotterGodunov) = 1 + +# Adaptive Lie-Trotter-Godunov Splitting Implementation +""" + PalindromicPairLieTrotterGodunov <: AbstractOperatorSplittingAlgorithm + +A second order sequential operator splitting algorithm using the midpoint rule. +""" +struct PalindromicPairLieTrotterGodunov{AlgTupleType} <: AbstractOperatorSplittingAlgorithm + inner_algs::AlgTupleType # Tuple of timesteppers for inner problems + # transfer_algs::TransferTupleType # Tuple of transfer algorithms from the master solution into the individual ones +end + +struct PalindromicPairLieTrotterGodunovCache{uType, uprevType, iiType} <: AbstractOperatorSplittingCache + u::uType + u0::uType + u2::uType + udiff::uType + uprev::uprevType + inner_caches::iiType +end + +function init_cache(f::GenericSplitFunction, alg::PalindromicPairLieTrotterGodunov; + uprev::AbstractArray, u::AbstractVector, + inner_caches, + alias_uprev = true, + alias_u = false +) + @assert length(inner_caches) == 2 "PP-LTG works only for two operators, but $(length(inner_caches)) have been provided." + + _uprev = alias_uprev ? uprev : SciMLBase.recursivecopy(uprev) + _u = alias_u ? u : SciMLBase.recursivecopy(u) + PalindromicPairLieTrotterGodunovCache(_u, copy(u), copy(u), copy(u), _uprev, inner_caches) +end + +@inline function advance_solution_to!( + outer_integrator::OperatorSplittingIntegrator, + subintegrators::Tuple, solution_indices::Tuple, + synchronizers::Tuple, cache::PalindromicPairLieTrotterGodunovCache, tnext) + advance_solution_to_palindromic!( + outer_integrator, subintegrators, reverse(subintegrators), + solution_indices, synchronizers, cache, tnext, + ) +end + +@inline @unroll function advance_solution_to_palindromic!( + outer_integrator::OperatorSplittingIntegrator, + subintegrators::Tuple, rsubintegrators::Tuple, solution_indices::Tuple, + synchronizers::Tuple, cache::PalindromicPairLieTrotterGodunovCache, tnext) + # @unpack u0, u2, udiff, uprev, inner_caches = cache + @unpack udiff, uprev, inner_caches = cache + + # FIXME + # u0 .= outer_integrator.u + u0 = copy(outer_integrator.u) + + # For each inner operator + i = 0 + @unroll for subinteg in subintegrators + i += 1 + synchronizer = synchronizers[i] + idxs = solution_indices[i] + cache = inner_caches[i] + + @timeit_debug "sync ->" forward_sync_subintegrator!(outer_integrator, subinteg, idxs, synchronizer) + @timeit_debug "time solve" advance_solution_to!( + outer_integrator, subinteg, idxs, synchronizer, cache, tnext) + if !(subinteg isa Tuple) && + subinteg.sol.retcode ∉ + (SciMLBase.ReturnCode.Default, SciMLBase.ReturnCode.Success) + integrator.force_stepfail = true + end + outer_integrator.force_stepfail && return + backward_sync_subintegrator!(outer_integrator, subinteg, idxs, synchronizer) + end + + # Store solution + # FIXME + # u2 .= outer_integrator.u + u2 = copy(outer_integrator.u) + + # Roll back + outer_integrator.u .= u0 + + @unroll for subinteg in rsubintegrators + synchronizer = synchronizers[i] + idxs = solution_indices[i] + cache = inner_caches[i] + + @timeit_debug "sync ->" forward_sync_subintegrator!(outer_integrator, subinteg, idxs, synchronizer) + @timeit_debug "time solve" advance_solution_to!( + outer_integrator, subinteg, idxs, synchronizer, cache, tnext) + if !(subinteg isa Tuple) && + subinteg.sol.retcode ∉ + (SciMLBase.ReturnCode.Default, SciMLBase.ReturnCode.Success) + integrator.force_stepfail = true return end backward_sync_subintegrator!(outer_integrator, subinteg, idxs, synchronizer) + i -= 1 + end + + if outer_integrator.opts.adaptive + # FIXME + # udiff .= outer_integrator.u - u2 + udiff = outer_integrator.u - u2 + outer_integrator.EEst = outer_integrator.opts.internalnorm(udiff, tnext) end + + outer_integrator.u .+= u2 + outer_integrator.u ./= 2 end + +OrdinaryDiffEqCore.isadaptive(alg::PalindromicPairLieTrotterGodunov) = true +OrdinaryDiffEqCore.alg_order(alg::PalindromicPairLieTrotterGodunov) = 2 + +# @inline function stepsize_controller!(integrator::OperatorSplittingIntegrator, alg::PalindromicPairLieTrotterGodunov) +# return nothing +# end + +# @inline function step_accept_controller!(integrator::OperatorSplittingIntegrator, alg::PalindromicPairLieTrotterGodunov, q) +# integrator.dt = integrator.dtcache +# return nothing +# end +# @inline function step_reject_controller!(integrator::OperatorSplittingIntegrator, alg::PalindromicPairLieTrotterGodunov, q) +# return nothing # Do nothing +# end diff --git a/src/utils.jl b/src/utils.jl index 7c29950..0123603 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -92,6 +92,7 @@ function forward_sync_internal!(outer_integrator::OperatorSplittingIntegrator, @views uouter = outer_integrator.u[solution_indices] sync_vectors!(inner_integrator.uprev, uouter) sync_vectors!(inner_integrator.u, uouter) + inner_integrator.t = outer_integrator.t SciMLBase.u_modified!(inner_integrator, true) end function backward_sync_internal!(outer_integrator::OperatorSplittingIntegrator, @@ -163,3 +164,18 @@ end function build_synchronizer_tree_recursion(f, synchronizer) return synchronizer end + +function prepare_subintegrators_to_redo_step!(integrator) + prepare_subintegrators_to_redo_step!(integrator.subintegrator_tree, integrator) +end + +@unroll function prepare_subintegrators_to_redo_step!(subintegrator_tree::Tuple, outer_integrator) + @unroll for subintegrator in subintegrator_tree + prepare_subintegrators_to_redo_step!(subintegrator, outer_integrator) + end +end +function prepare_subintegrators_to_redo_step!(subintegrator, outer_integrator) + subintegrator.t = outer_integrator.t + subintegrator.u .= subintegrator.uprev + DiffEqBase.u_modified!(subintegrator, true) +end diff --git a/test/operator_splitting_api.jl b/test/operator_splitting_api.jl index 0f3aac5..c632003 100644 --- a/test/operator_splitting_api.jl +++ b/test/operator_splitting_api.jl @@ -45,77 +45,22 @@ function ode3(du, u, p, t) du[2] = -0.005u[1] end f3 = ODEFunction(ode3) -# The time stepper carries the individual solver information. - -# Test whether adaptive code path works in principle -struct FakeAdaptiveAlgorithm{T} <: OS.AbstractOperatorSplittingAlgorithm - alg::T -end -struct FakeAdaptiveAlgorithmCache{T} <: OS.AbstractOperatorSplittingCache - cache::T -end -@inline DiffEqBase.isadaptive(::FakeAdaptiveAlgorithm) = true - -@inline function OS.stepsize_controller!(integrator::OS.OperatorSplittingIntegrator, alg::FakeAdaptiveAlgorithm) - return nothing -end - -@inline function OS.step_accept_controller!(integrator::OS.OperatorSplittingIntegrator, alg::FakeAdaptiveAlgorithm, q) - integrator.dt = integrator.dtcache - return nothing -end -@inline function OS.step_reject_controller!(integrator::OS.OperatorSplittingIntegrator, alg::FakeAdaptiveAlgorithm, q) - error("The tests should never run into this scenario!") - return nothing # Do nothing -end -function OS.build_subintegrator_tree_with_cache( - prob::OS.OperatorSplittingProblem, alg::FakeAdaptiveAlgorithm, - uprevouter::AbstractVector, uouter::AbstractVector, - solution_indices, - t0, dt, tf, - tstops, saveat, d_discontinuities, callback, - adaptive, verbose, -) - subintegrators, inner_cache = OS.build_subintegrator_tree_with_cache( - prob, alg.alg, uprevouter, uouter, solution_indices, - t0, dt, tf, - tstops, saveat, d_discontinuities, callback, - adaptive, verbose, - ) - return subintegrators, FakeAdaptiveAlgorithmCache( - inner_cache, - ) -end -function OS.build_subintegrator_tree_with_cache( - prob::OS.OperatorSplittingProblem, alg::FakeAdaptiveAlgorithm, - f::GenericSplitFunction, p::Tuple, - uprevouter::AbstractVector, uouter::AbstractVector, - solution_indices, - t0, dt, tf, - tstops, saveat, d_discontinuities, callback, - adaptive, verbose, - save_end = false, - controller = nothing +A₁nc = [-2.0 0.1 0.0; 0.1 -2.0 1.0; 0.0 0.1 -1.0] +A₂nc = [0.0 0.3 0.0; 0.0 0.0 0.2; 0.333 0.0 0.0] +fsplitnc = GenericSplitFunction( + ( + ODEFunction((du, u, p, t) -> du .= A₁nc * u), + ODEFunction((du, u, p, t) -> du .= A₂nc * u), + ), + ([1, 2, 3], [1, 2, 3]) ) - subintegrators, inner_cache = OS.build_subintegrator_tree_with_cache( - prob, alg.alg, f, p, uprevouter, uouter, solution_indices, - t0, dt, tf, - tstops, saveat, d_discontinuities, callback, - adaptive, verbose, - ) - return subintegrators, FakeAdaptiveAlgorithmCache( - inner_cache, - ) -end -FakeAdaptiveLTG(inner) = FakeAdaptiveAlgorithm(LieTrotterGodunov(inner)) - -@inline DiffEqBase.get_tmp_cache(integrator::OS.OperatorSplittingIntegrator, alg::OS.AbstractOperatorSplittingAlgorithm, cache::FakeAdaptiveAlgorithmCache) = DiffEqBase.get_tmp_cache(integrator, alg, cache.cache) -@inline function OS.advance_solution_to!(outer_integrator::OS.OperatorSplittingIntegrator, subintegrators::Tuple, solution_indices::Tuple, synchronizers::Tuple, cache::FakeAdaptiveAlgorithmCache, tnext) - OS.advance_solution_to!(outer_integrator, subintegrators, solution_indices, synchronizers, cache.cache, tnext) -end +# Now the usual setup just with our new problem type. +prob_nc = OperatorSplittingProblem(fsplitnc, u0, tspan) +trueunc = exp((tspan[2] - tspan[1]) * (A₁nc+A₂nc)) * u0 +# The time stepper carries the individual solver information. @testset "reinit and convergence" begin dt = 0.01π @@ -137,8 +82,12 @@ end fsplit2_inner = GenericSplitFunction((f3, f3), (f3dofs, f3dofs)) fsplit2_outer = GenericSplitFunction((f1, fsplit2_inner), (f1dofs, f2dofs)) + num_stages_a(alg) = 1 + num_stages_a(alg::Type{PalindromicPairLieTrotterGodunov}) = 2 + prob2 = OperatorSplittingProblem(fsplit2_outer, u0, tspan) - for TimeStepperType in (LieTrotterGodunov, FakeAdaptiveLTG) + + for TimeStepperType in (LieTrotterGodunov, PalindromicPairLieTrotterGodunov) @testset "Solver type $TimeStepperType | $tstepper" for (prob, tstepper) in ( (prob1, TimeStepperType((Euler(), Euler()))), (prob1, TimeStepperType((Tsit5(), Euler()))), @@ -163,7 +112,7 @@ end @test integrator.subintegrator_tree[1].t ≈ tspan[2] @test integrator.dtcache ≈ dt @test integrator.iter == ceil(Int, (tspan[2]-tspan[1])/dt) - @test integrator.subintegrator_tree[1].iter == ceil(Int, (tspan[2]-tspan[1])/dt) + @test integrator.subintegrator_tree[1].iter == ceil(Int, (tspan[2]-tspan[1])/dt)*num_stages_a(TimeStepperType) DiffEqBase.reinit!(integrator) @test integrator.sol.retcode == DiffEqBase.ReturnCode.Default @@ -174,7 +123,7 @@ end @test integrator.subintegrator_tree[1].t ≈ tspan[2] @test integrator.dtcache ≈ dt @test integrator.iter == ceil(Int, (tspan[2]-tspan[1])/dt) - @test integrator.subintegrator_tree[1].iter == ceil(Int, (tspan[2]-tspan[1])/dt) + @test integrator.subintegrator_tree[1].iter == ceil(Int, (tspan[2]-tspan[1])/dt)*num_stages_a(TimeStepperType) DiffEqBase.reinit!(integrator) @test integrator.sol.retcode == DiffEqBase.ReturnCode.Default @@ -185,7 +134,7 @@ end @test integrator.subintegrator_tree[1].t ≈ tspan[2] @test integrator.dtcache ≈ dt @test integrator.iter == ceil(Int, (tspan[2]-tspan[1])/dt) - @test integrator.subintegrator_tree[1].iter == ceil(Int, (tspan[2]-tspan[1])/dt) + @test integrator.subintegrator_tree[1].iter == ceil(Int, (tspan[2]-tspan[1])/dt)*num_stages_a(TimeStepperType) DiffEqBase.reinit!(integrator) @test integrator.sol.retcode == DiffEqBase.ReturnCode.Default @@ -195,14 +144,15 @@ end @test integrator.subintegrator_tree[1].t ≈ tspan[2] @test integrator.dtcache ≈ dt @test integrator.iter == ceil(Int, (tspan[2]-tspan[1])/dt) - @test integrator.subintegrator_tree[1].iter == ceil(Int, (tspan[2]-tspan[1])/dt) + @test integrator.subintegrator_tree[1].iter == ceil(Int, (tspan[2]-tspan[1])/dt)*num_stages_a(TimeStepperType) end end - for TimeStepperType in (FakeAdaptiveLTG,) - @testset "Adaptive solver type $TimeStepperType | $tstepper" for (prob, tstepper) in ( - (prob1, TimeStepperType((Tsit5(), Tsit5()))), - (prob2, TimeStepperType((Tsit5(), TimeStepperType((Tsit5(), Tsit5()))))) + for TimeStepperType in (PalindromicPairLieTrotterGodunov,) + @testset "Solver type $TimeStepperType | $tstepper" for (prob, trueu, tstepper) in ( + (prob1, trueu, TimeStepperType((Tsit5(), Tsit5()))), + (prob2, trueu, TimeStepperType((Tsit5(), TimeStepperType((Tsit5(), Tsit5()))))), + (prob_nc, trueunc, TimeStepperType((Tsit5(), Tsit5()))), ) # The remaining code works as usual. integrator = DiffEqBase.init( @@ -213,10 +163,8 @@ end ufinal = copy(integrator.u) @test isapprox(ufinal, trueu, atol = 1e-6) @test integrator.t ≈ tspan[2] + @test integrator.iter < 20 @test integrator.subintegrator_tree[1].t ≈ tspan[2] - @test integrator.dtcache ≈ dt - @test integrator.iter == ceil(Int, (tspan[2]-tspan[1])/dt) - @test integrator.subintegrator_tree[1].iter == ceil(Int, (tspan[2]-tspan[1])/dt) DiffEqBase.reinit!(integrator) integrator.dt = dt @@ -225,10 +173,8 @@ end end @test isapprox(ufinal, integrator.u, atol = 1e-12) @test integrator.t ≈ tspan[2] + @test integrator.iter < 20 @test integrator.subintegrator_tree[1].t ≈ tspan[2] - @test integrator.dtcache ≈ dt - @test integrator.iter == ceil(Int, (tspan[2]-tspan[1])/dt) - @test integrator.subintegrator_tree[1].iter == ceil(Int, (tspan[2]-tspan[1])/dt)+1 # We need one extra step after reinit for some reason... DiffEqBase.reinit!(integrator) integrator.dt = dt @@ -237,10 +183,8 @@ end end @test isapprox(ufinal, integrator.u, atol = 1e-12) @test integrator.t ≈ tspan[2] + @test integrator.iter < 20 @test integrator.subintegrator_tree[1].t ≈ tspan[2] - @test integrator.dtcache ≈ dt - @test integrator.iter == ceil(Int, (tspan[2]-tspan[1])/dt) - @test integrator.subintegrator_tree[1].iter == ceil(Int, (tspan[2]-tspan[1])/dt)+1 DiffEqBase.reinit!(integrator) integrator.dt = dt @@ -248,10 +192,8 @@ end DiffEqBase.solve!(integrator) @test integrator.sol.retcode == DiffEqBase.ReturnCode.Success @test integrator.t ≈ tspan[2] + @test integrator.iter < 20 @test integrator.subintegrator_tree[1].t ≈ tspan[2] - @test integrator.dtcache ≈ dt - @test integrator.iter == ceil(Int, (tspan[2]-tspan[1])/dt) - @test integrator.subintegrator_tree[1].iter == ceil(Int, (tspan[2]-tspan[1])/dt)+1 end end