diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml deleted file mode 100644 index 3494a9f..0000000 --- a/.JuliaFormatter.toml +++ /dev/null @@ -1,3 +0,0 @@ -style = "sciml" -format_markdown = true -format_docstrings = true diff --git a/.github/workflows/FormatCheck.yml b/.github/workflows/FormatCheck.yml new file mode 100644 index 0000000..6762c6f --- /dev/null +++ b/.github/workflows/FormatCheck.yml @@ -0,0 +1,19 @@ +name: format-check + +on: + push: + branches: + - 'master' + - 'main' + - 'release-' + tags: '*' + pull_request: + +jobs: + runic: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: fredrikekre/runic-action@v1 + with: + version: '1' diff --git a/docs/make.jl b/docs/make.jl index 3f9a05d..0f968bf 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,8 +1,10 @@ using OrdinaryDiffEqOperatorSplitting using Documenter, DocumenterCitations -DocMeta.setdocmeta!(OrdinaryDiffEqOperatorSplitting, :DocTestSetup, - :(using OrdinaryDiffEqOperatorSplitting); recursive = true) +DocMeta.setdocmeta!( + OrdinaryDiffEqOperatorSplitting, :DocTestSetup, + :(using OrdinaryDiffEqOperatorSplitting); recursive = true +) const is_ci = haskey(ENV, "GITHUB_ACTIONS") @@ -16,7 +18,7 @@ makedocs( format = Documenter.HTML( assets = [ "assets/citations.css", - # "assets/favicon.ico" + # "assets/favicon.ico" ], # canonical = "https://localhost/", collapselevel = 1 @@ -31,7 +33,7 @@ makedocs( "Theory Manual" => "topics/time-integration.md", "api-reference/index.md", "devdocs/index.md", - "references.md" + "references.md", ], plugins = [ bibtex_plugin, @@ -45,6 +47,6 @@ deploydocs( devbranch = "main", versions = [ "stable" => "v^", - "dev" => "dev" + "dev" => "dev", ] ) diff --git a/src/function.jl b/src/function.jl index 99e45e1..31e18ac 100644 --- a/src/function.jl +++ b/src/function.jl @@ -5,7 +5,7 @@ This type of function describes a set of connected inner functions in mass-matrix form, as usually found in operator splitting procedures. """ struct GenericSplitFunction{fSetType <: Tuple, idxSetType <: Tuple, sSetType <: Tuple} <: - AbstractOperatorSplitFunction + AbstractOperatorSplitFunction # Tuple containing the atomic ode functions or further nested split functions. functions::fSetType # The ranges for the values in the solution vector. @@ -14,7 +14,7 @@ struct GenericSplitFunction{fSetType <: Tuple, idxSetType <: Tuple, sSetType <: synchronizers::sSetType function GenericSplitFunction(fs::Tuple, drs::Tuple, syncers::Tuple) @assert length(fs) == length(drs) == length(syncers) - new{typeof(fs), typeof(drs), typeof(syncers)}(fs, drs, syncers) + return new{typeof(fs), typeof(drs), typeof(syncers)}(fs, drs, syncers) end end @@ -28,7 +28,7 @@ Indicator that no synchronization between parameters and solution vectors is nec struct NoExternalSynchronization end function GenericSplitFunction(fs::Tuple, drs::Tuple) - GenericSplitFunction(fs, drs, ntuple(_->NoExternalSynchronization(), length(fs))) + return GenericSplitFunction(fs, drs, ntuple(_ -> NoExternalSynchronization(), length(fs))) end @inline get_operator(f::GenericSplitFunction, i::Integer) = f.functions[i] diff --git a/src/integrator.jl b/src/integrator.jl index 339137b..1490d56 100644 --- a/src/integrator.jl +++ b/src/integrator.jl @@ -23,23 +23,23 @@ A variant of [`ODEIntegrator`](https://github.com/SciML/OrdinaryDiffEq.jl/blob/6 Derived from https://github.com/CliMA/ClimaTimeSteppers.jl/blob/ef3023747606d2750e674d321413f80638136632/src/integrators.jl. """ mutable struct OperatorSplittingIntegrator{ - fType, - algType, - uType, - tType, - pType, - heapType, - tstopsType, - saveatType, - callbackType, - cacheType, - solType, - subintTreeType, - solidxTreeType, - syncTreeType, - controllerType, - optionsType -} <: SciMLBase.AbstractODEIntegrator{algType, true, uType, tType} + fType, + algType, + uType, + tType, + pType, + heapType, + tstopsType, + saveatType, + callbackType, + cacheType, + solType, + subintTreeType, + solidxTreeType, + syncTreeType, + controllerType, + optionsType, + } <: SciMLBase.AbstractODEIntegrator{algType, true, uType, tType} const f::fType const alg::algType u::uType # Master Solution @@ -92,7 +92,7 @@ function SciMLBase.__init( alias_u0 = false, verbose = true, kwargs... -) + ) (; u0, p) = prob t0, tf = prob.tspan @@ -115,10 +115,12 @@ function SciMLBase.__init( # Setup tstop logic tstops_internal = OrdinaryDiffEqCore.initialize_tstops( - tType, tstops, d_discontinuities, prob.tspan) + tType, tstops, d_discontinuities, prob.tspan + ) saveat_internal = OrdinaryDiffEqCore.initialize_saveat(tType, saveat, prob.tspan) d_discontinuities_internal = OrdinaryDiffEqCore.initialize_d_discontinuities( - tType, d_discontinuities, prob.tspan) + tType, d_discontinuities, prob.tspan + ) u = setup_u(prob, alg, alias_u0) uprev = setup_u(prob, alg, false) @@ -130,7 +132,7 @@ function SciMLBase.__init( callback = DiffEqBase.CallbackSet(callback) subintegrator_tree, - cache = build_subintegrator_tree_with_cache( + cache = build_subintegrator_tree_with_cache( prob, alg, uprev, u, 1:length(u), @@ -188,7 +190,7 @@ function DiffEqBase.reinit!( saveat = integrator._saveat, reinit_callbacks = true, reinit_retcode = true -) + ) integrator.u .= u0 integrator.uprev .= u0 integrator.t = t0 @@ -210,10 +212,11 @@ function DiffEqBase.reinit!( end if reinit_retcode integrator.sol = SciMLBase.solution_new_retcode( - integrator.sol, ReturnCode.Default) + integrator.sol, ReturnCode.Default + ) end - subreinit!( + return subreinit!( integrator.f, u0, 1:length(u0), @@ -234,12 +237,12 @@ function subreinit!( subintegrator::DEIntegrator; dt, kwargs... -) + ) # dt is not reset as expected in reinit! if dt !== nothing subintegrator.dt = dt end - DiffEqBase.reinit!(subintegrator, u0[solution_indices]; kwargs...) + return DiffEqBase.reinit!(subintegrator, u0[solution_indices]; kwargs...) end @unroll function subreinit!( @@ -248,7 +251,7 @@ end solution_indices, subintegrators::Tuple; kwargs... -) + ) i = 1 @unroll for subintegrator in subintegrators subreinit!(get_operator(f, i), u0, f.solution_indices[i], subintegrator; kwargs...) @@ -264,14 +267,16 @@ function OrdinaryDiffEqCore.handle_tstop!(integrator::OperatorSplittingIntegrato while tdir_t == tdir_tstop #remove all redundant copies res = SciMLBase.pop_tstop!(integrator) SciMLBase.has_tstop(integrator) ? - (tdir_tstop = SciMLBase.first_tstop(integrator)) : break + (tdir_tstop = SciMLBase.first_tstop(integrator)) : break end notify_integrator_hit_tstop!(integrator) elseif tdir_t > tdir_tstop if !integrator.dtchangeable - SciMLBase.change_t_via_interpolation!(integrator, + SciMLBase.change_t_via_interpolation!( + integrator, tdir(integrator) * - SciMLBase.pop_tstop!(integrator), Val{true}) + SciMLBase.pop_tstop!(integrator), Val{true} + ) notify_integrator_hit_tstop!(integrator) else error("Something went wrong. Integrator stepped past tstops but the algorithm was dtchangeable. Please report this error.") @@ -289,14 +294,14 @@ increment_iteration(integrator::OperatorSplittingIntegrator) = integrator.iter + # Controller interface function reject_step!(integrator::OperatorSplittingIntegrator) OrdinaryDiffEqCore.increment_reject!(integrator.stats) - reject_step!(integrator, integrator.cache, integrator.controller) + return reject_step!(integrator, integrator.cache, integrator.controller) end function reject_step!(integrator::OperatorSplittingIntegrator, cache, controller) - integrator.u .= integrator.uprev + return integrator.u .= integrator.uprev # TODO what do we need to do with the subintegrators? end function reject_step!(integrator::OperatorSplittingIntegrator, cache, ::Nothing) - if length(integrator.uprev) == 0 + return if length(integrator.uprev) == 0 error("Cannot roll back integrator. Aborting time integration step at $(integrator.t).") end end @@ -313,20 +318,20 @@ function should_accept_step(integrator::OperatorSplittingIntegrator, cache, ::No end function accept_step!(integrator::OperatorSplittingIntegrator) OrdinaryDiffEqCore.increment_accept!(integrator.stats) - accept_step!(integrator, integrator.cache, integrator.controller) + return accept_step!(integrator, integrator.cache, integrator.controller) end function accept_step!(integrator::OperatorSplittingIntegrator, cache, controller) - store_previous_info!(integrator) + return store_previous_info!(integrator) end function store_previous_info!(integrator::OperatorSplittingIntegrator) - if length(integrator.uprev) > 0 # Integrator can rollback + return if length(integrator.uprev) > 0 # Integrator can rollback update_uprev!(integrator) end end function update_uprev!(integrator::OperatorSplittingIntegrator) RecursiveArrayTools.recursivecopy!(integrator.uprev, integrator.u) - nothing + return nothing end function step_header!(integrator::OperatorSplittingIntegrator) @@ -346,19 +351,19 @@ function step_header!(integrator::OperatorSplittingIntegrator) # OrdinaryDiffEqCore.choose_algorithm!(integrator, integrator.cache) OrdinaryDiffEqCore.fix_dt_at_bounds!(integrator) OrdinaryDiffEqCore.modify_dt_for_tstops!(integrator) - integrator.force_stepfail = false + return integrator.force_stepfail = false end function footer_reset_flags!(integrator) - integrator.u_modified = false + return integrator.u_modified = false end function setup_validity_flags!(integrator, t_next) - integrator.isout = false #integrator.opts.isoutofdomain(integrator.u, integrator.p, t_next) + return integrator.isout = false #integrator.opts.isoutofdomain(integrator.u, integrator.p, t_next) end function fix_solution_buffer_sizes!(integrator, sol) resize!(integrator.sol.t, integrator.saveiter) resize!(integrator.sol.u, integrator.saveiter) - if !(integrator.sol isa SciMLBase.DAESolution) + return if !(integrator.sol isa SciMLBase.DAESolution) resize!(integrator.sol.k, integrator.saveiter_dense) end end @@ -372,7 +377,7 @@ function step_footer!(integrator::OperatorSplittingIntegrator) if should_accept_step(integrator) integrator.last_step_failed = false integrator.tprev = integrator.t - integrator.t = ttmp#OrdinaryDiffEqCore.fixed_t_for_floatingpoint_error!(integrator, ttmp) + integrator.t = ttmp #OrdinaryDiffEqCore.fixed_t_for_floatingpoint_error!(integrator, ttmp) # OrdinaryDiffEqCore.handle_callbacks!(integrator) step_accept_controller!(integrator) # Noop for non-adaptive algorithms elseif integrator.force_stepfail @@ -393,10 +398,12 @@ function step_footer!(integrator::OperatorSplittingIntegrator) end # called by DiffEqBase.solve -function SciMLBase.__solve(prob::OperatorSplittingProblem, - alg::AbstractOperatorSplittingAlgorithm, args...; kwargs...) +function SciMLBase.__solve( + prob::OperatorSplittingProblem, + alg::AbstractOperatorSplittingAlgorithm, args...; kwargs... + ) integrator = SciMLBase.__init(prob, alg, args...; kwargs...) - DiffEqBase.solve!(integrator) + return DiffEqBase.solve!(integrator) end # either called directly (after init), or by DiffEqBase.solve (via __solve) @@ -405,7 +412,8 @@ function DiffEqBase.solve!(integrator::OperatorSplittingIntegrator) while tdir(integrator) * integrator.t < SciMLBase.first_tstop(integrator) step_header!(integrator) @timeit_debug "check_error" SciMLBase.check_error!(integrator) ∉ ( - ReturnCode.Success, ReturnCode.Default)&&return + ReturnCode.Success, ReturnCode.Default, + )&&return __step!(integrator) step_footer!(integrator) if !SciMLBase.has_tstop(integrator) @@ -419,7 +427,8 @@ function DiffEqBase.solve!(integrator::OperatorSplittingIntegrator) return integrator.sol end return integrator.sol = SciMLBase.solution_new_retcode( - integrator.sol, ReturnCode.Success) + integrator.sol, ReturnCode.Success + ) end function DiffEqBase.step!(integrator::OperatorSplittingIntegrator) @@ -428,7 +437,8 @@ function DiffEqBase.step!(integrator::OperatorSplittingIntegrator) while !reached_tstop(integrator, tstop) step_header!(integrator) @timeit_debug "check_error" SciMLBase.check_error!(integrator) ∉ ( - ReturnCode.Success, ReturnCode.Default)&&return + ReturnCode.Success, ReturnCode.Default, + )&&return __step!(integrator) step_footer!(integrator) if !SciMLBase.has_tstop(integrator) @@ -438,23 +448,25 @@ function DiffEqBase.step!(integrator::OperatorSplittingIntegrator) else step_header!(integrator) @timeit_debug "check_error" SciMLBase.check_error!(integrator) ∉ ( - ReturnCode.Success, ReturnCode.Default)&&return + ReturnCode.Success, ReturnCode.Default, + )&&return __step!(integrator) step_footer!(integrator) while !should_accept_step(integrator) step_header!(integrator) @timeit_debug "check_error" SciMLBase.check_error!(integrator) ∉ ( - ReturnCode.Success, ReturnCode.Default)&&return + ReturnCode.Success, ReturnCode.Default, + )&&return __step!(integrator) step_footer!(integrator) end end - OrdinaryDiffEqCore.handle_tstop!(integrator) + return OrdinaryDiffEqCore.handle_tstop!(integrator) end function SciMLBase.check_error(integrator::OperatorSplittingIntegrator) if !SciMLBase.successful_retcode(integrator.sol) && - integrator.sol.retcode != ReturnCode.Default + integrator.sol.retcode != ReturnCode.Default return integrator.sol.retcode end @@ -485,7 +497,7 @@ function check_error_subintegrators(integrator, subintegrator::DEIntegrator) end function DiffEqBase.step!(integrator::OperatorSplittingIntegrator, dt, stop_at_tdt = false) - @timeit_debug "step!" begin + return @timeit_debug "step!" begin # OridinaryDiffEq lets dt be negative if tdir is -1, but that's inconsistent dt <= zero(dt) && error("dt must be positive") stop_at_tdt && !integrator.dtchangeable && @@ -495,7 +507,8 @@ function DiffEqBase.step!(integrator::OperatorSplittingIntegrator, dt, stop_at_t while !reached_tstop(integrator, tnext, stop_at_tdt) step_header!(integrator) @timeit_debug "check_error" SciMLBase.check_error!(integrator) ∉ ( - ReturnCode.Success, ReturnCode.Default)&&return + ReturnCode.Success, ReturnCode.Default, + )&&return __step!(integrator) step_footer!(integrator) end @@ -513,7 +526,7 @@ end # TimeChoiceIterator API @inline function DiffEqBase.get_tmp_cache(integrator::OperatorSplittingIntegrator) # DiffEqBase.get_tmp_cache(integrator, integrator.alg, integrator.cache) - (integrator.tmp,) + return (integrator.tmp,) end # @inline function DiffEqBase.get_tmp_cache(integrator::OperatorSplittingIntegrator, ::AbstractOperatorSplittingAlgorithm, cache) # return (cache.tmp,) @@ -521,11 +534,12 @@ end # Interpolation # TODO via https://github.com/SciML/SciMLBase.jl/blob/master/src/interpolation.jl function linear_interpolation!(y, t, y1, y2, t1, t2) - y .= y1 + (t - t1) * (y2 - y1) / (t2 - t1) + return y .= y1 + (t - t1) * (y2 - y1) / (t2 - t1) end function (integrator::OperatorSplittingIntegrator)(tmp, t) - linear_interpolation!( - tmp, t, integrator.uprev, integrator.u, integrator.tprev, integrator.t) + return linear_interpolation!( + tmp, t, integrator.uprev, integrator.u, integrator.tprev, integrator.t + ) end """ @@ -536,7 +550,7 @@ Updates the controller using the current state of the integrator if the operator @inline function stepsize_controller!(integrator::OperatorSplittingIntegrator) algorithm = integrator.alg isadaptive(algorithm) || return nothing - stepsize_controller!(integrator, algorithm) + return stepsize_controller!(integrator, algorithm) end """ @@ -547,7 +561,7 @@ Updates `dtcache` of the integrator if the step is accepted and the operator spl @inline function step_accept_controller!(integrator::OperatorSplittingIntegrator) algorithm = integrator.alg isadaptive(algorithm) || return nothing - step_accept_controller!(integrator, algorithm, nothing) + return step_accept_controller!(integrator, algorithm, nothing) end """ @@ -558,7 +572,7 @@ Updates `dtcache` of the integrator if the step is rejected and the the operator @inline function step_reject_controller!(integrator::OperatorSplittingIntegrator) algorithm = integrator.alg isadaptive(algorithm) || return nothing - step_reject_controller!(integrator, algorithm, nothing) + return step_reject_controller!(integrator, algorithm, nothing) end # helper functions for dealing with time-reversed integrators in the same way @@ -577,36 +591,41 @@ end # Dunno stuff function SciMLBase.done(integrator::OperatorSplittingIntegrator) - if !(integrator.sol.retcode in ( - ReturnCode.Default, ReturnCode.Success)) + if !( + integrator.sol.retcode in ( + ReturnCode.Default, ReturnCode.Success, + ) + ) return true elseif isempty(integrator.tstops) SciMLBase.postamble!(integrator) return true end - false + return false end function SciMLBase.postamble!(integrator::OperatorSplittingIntegrator) - DiffEqBase.finalize!(integrator.callback, integrator.u, integrator.t, integrator) + return DiffEqBase.finalize!(integrator.callback, integrator.u, integrator.t, integrator) end function __step!(integrator) tnext = integrator.t + integrator.dt synchronize_subintegrator_tree!(integrator) advance_solution_to!(integrator, tnext) - stepsize_controller!(integrator) + return stepsize_controller!(integrator) end # solvers need to define this interface function advance_solution_to!(integrator::OperatorSplittingIntegrator, tnext) - advance_solution_to!(integrator, integrator.cache, tnext) + return advance_solution_to!(integrator, integrator.cache, tnext) end -function advance_solution_to!(outer_integrator::OperatorSplittingIntegrator, - integrator::DEIntegrator, solution_indices, sync, cache, tend) +function advance_solution_to!( + outer_integrator::OperatorSplittingIntegrator, + integrator::DEIntegrator, solution_indices, sync, cache, tend + ) dt = tend - integrator.t - SciMLBase.step!(integrator, dt, true) + return SciMLBase.step!(integrator, dt, true) end # ----------------------------------- SciMLBase.jl Integrator Interface ------------------------------------ @@ -620,21 +639,21 @@ DiffEqBase.get_dt(integrator::OperatorSplittingIntegrator) = integrator.dt function set_dt!(integrator::OperatorSplittingIntegrator, dt) # TODO: figure out interface for recomputing other objects (linear operators, etc) dt <= zero(dt) && error("dt must be positive") - integrator.dt = dt + return integrator.dt = dt end function DiffEqBase.add_tstop!(integrator::OperatorSplittingIntegrator, t) is_past_t(integrator, t) && error("Cannot add a tstop at $t because that is behind the current \ integrator time $(integrator.t)") - push!(integrator.tstops, t) + return push!(integrator.tstops, t) end function DiffEqBase.add_saveat!(integrator::OperatorSplittingIntegrator, t) is_past_t(integrator, t) && error("Cannot add a saveat point at $t because that is behind the \ current integrator time $(integrator.t)") - push!(integrator.saveat, t) + return push!(integrator.saveat, t) end # not sure what this should do? @@ -642,30 +661,35 @@ end DiffEqBase.u_modified!(i::OperatorSplittingIntegrator, bool) = nothing function synchronize_subintegrator_tree!(integrator::OperatorSplittingIntegrator) - synchronize_subintegrator!(integrator.subintegrator_tree, integrator) + return synchronize_subintegrator!(integrator.subintegrator_tree, integrator) end @unroll function synchronize_subintegrator!( - subintegrator_tree::Tuple, integrator::OperatorSplittingIntegrator) + subintegrator_tree::Tuple, integrator::OperatorSplittingIntegrator + ) @unroll for subintegrator in subintegrator_tree synchronize_subintegrator!(subintegrator, integrator) end end function synchronize_subintegrator!( - subintegrator::DEIntegrator, integrator::OperatorSplittingIntegrator) + subintegrator::DEIntegrator, integrator::OperatorSplittingIntegrator + ) @unpack t, dt = integrator @assert subintegrator.t == t - if !isadaptive(subintegrator) + return if !isadaptive(subintegrator) SciMLBase.set_proposed_dt!(subintegrator, dt) end end -function advance_solution_to!(integrator::OperatorSplittingIntegrator, - cache::AbstractOperatorSplittingCache, tnext::Number) - advance_solution_to!( +function advance_solution_to!( + integrator::OperatorSplittingIntegrator, + cache::AbstractOperatorSplittingCache, tnext::Number + ) + return advance_solution_to!( integrator, integrator.subintegrator_tree, integrator.solution_index_tree, - integrator.synchronizer_tree, cache, tnext) + integrator.synchronizer_tree, cache, tnext + ) end # Dispatch for tree node construction @@ -676,7 +700,7 @@ function build_subintegrator_tree_with_cache( t0, dt, tf, tstops, saveat, d_discontinuities, callback, adaptive, verbose -) + ) (; f, p) = prob subintegrator_tree_with_caches = ntuple( i -> build_subintegrator_tree_with_cache( @@ -694,15 +718,17 @@ function build_subintegrator_tree_with_cache( ) subintegrator_tree = ntuple( - i -> subintegrator_tree_with_caches[i][1], length(f.functions)) + i -> subintegrator_tree_with_caches[i][1], length(f.functions) + ) caches = ntuple(i -> subintegrator_tree_with_caches[i][2], length(f.functions)) # TODO fix mixed device type problems we have to be smarter return subintegrator_tree, - init_cache(f, alg; - uprev = uprevouter, u = uouter, alias_u = true, - inner_caches = caches - ) + init_cache( + f, alg; + uprev = uprevouter, u = uouter, alias_u = true, + inner_caches = caches + ) end function build_subintegrator_tree_with_cache( @@ -715,7 +741,7 @@ function build_subintegrator_tree_with_cache( adaptive, verbose, save_end = false, controller = nothing -) + ) subintegrator_tree_with_caches = ntuple( i -> build_subintegrator_tree_with_cache( prob, @@ -738,10 +764,11 @@ function build_subintegrator_tree_with_cache( uprev = @view uprevouter[solution_indices] u = @view uouter[solution_indices] return subintegrator_tree, - init_cache(f, alg; - uprev = uprev, u = u, - inner_caches = inner_caches - ) + init_cache( + f, alg; + uprev = uprev, u = u, + inner_caches = inner_caches + ) end function build_subintegrator_tree_with_cache( @@ -755,7 +782,7 @@ function build_subintegrator_tree_with_cache( adaptive, verbose, save_end = false, controller = nothing -) where {S, T, P, F} + ) where {S, T, P, F} uprev = @view uprevouter[solution_indices] u = @view uouter[solution_indices] @@ -787,10 +814,13 @@ end function forward_sync_subintegrator!( outer_integrator::OperatorSplittingIntegrator, subintegrator_tree::Tuple, - solution_indices::Tuple, synchronizers::Tuple) - nothing + solution_indices::Tuple, synchronizers::Tuple + ) + return nothing end -function backward_sync_subintegrator!(outer_integrator::OperatorSplittingIntegrator, - subintegrator_tree::Tuple, solution_indices::Tuple, synchronizer::Tuple) - nothing +function backward_sync_subintegrator!( + outer_integrator::OperatorSplittingIntegrator, + subintegrator_tree::Tuple, solution_indices::Tuple, synchronizer::Tuple + ) + return nothing end diff --git a/src/problem.jl b/src/problem.jl index 36fca68..59e5b1d 100644 --- a/src/problem.jl +++ b/src/problem.jl @@ -2,23 +2,30 @@ OperatorSplittingProblem(f::AbstractOperatorSplitFunction, u0, tspan, p::Tuple) """ mutable struct OperatorSplittingProblem{ - fType <: AbstractOperatorSplitFunction, uType, tType, pType <: Tuple, K} <: - SciMLBase.AbstractODEProblem{uType, tType, true} + fType <: AbstractOperatorSplitFunction, uType, tType, pType <: Tuple, K, + } <: + SciMLBase.AbstractODEProblem{uType, tType, true} f::fType u0::uType tspan::tType p::pType kwargs::K # TODO what to do with these? - function OperatorSplittingProblem(f::AbstractOperatorSplitFunction, + function OperatorSplittingProblem( + f::AbstractOperatorSplitFunction, u0, tspan, p = recursive_null_parameters(f); - kwargs...) - new{typeof(f), typeof(u0), + kwargs... + ) + return new{ + typeof(f), typeof(u0), typeof(tspan), typeof(p), - typeof(kwargs)}(f, + typeof(kwargs), + }( + f, u0, tspan, p, - kwargs) + kwargs + ) end end @@ -26,8 +33,8 @@ num_operators(prob::OperatorSplittingProblem) = num_operators(prob.f) recursive_null_parameters(f::AbstractOperatorSplitFunction) = @error "Not implemented" function recursive_null_parameters(f::GenericSplitFunction) - ntuple(i->recursive_null_parameters(get_operator(f, i)), length(f.functions)) + return ntuple(i -> recursive_null_parameters(get_operator(f, i)), length(f.functions)) end function recursive_null_parameters(f) # Wildcard for leafs - NullParameters() + return NullParameters() end diff --git a/src/solver.jl b/src/solver.jl index dcee8b4..c638fd5 100644 --- a/src/solver.jl +++ b/src/solver.jl @@ -15,21 +15,23 @@ struct LieTrotterGodunovCache{uType, uprevType, iiType} <: AbstractOperatorSplit inner_caches::iiType end -function init_cache(f::GenericSplitFunction, alg::LieTrotterGodunov; +function init_cache( + f::GenericSplitFunction, alg::LieTrotterGodunov; uprev::AbstractArray, u::AbstractVector, inner_caches, alias_uprev = true, alias_u = false -) + ) _uprev = alias_uprev ? uprev : RecursiveArrayTools.recursivecopy(uprev) _u = alias_u ? u : RecursiveArrayTools.recursivecopy(u) - LieTrotterGodunovCache(_u, _uprev, inner_caches) + return LieTrotterGodunovCache(_u, _uprev, inner_caches) end @inline @unroll function advance_solution_to!( outer_integrator::OperatorSplittingIntegrator, subintegrators::Tuple, solution_indices::Tuple, - synchronizers::Tuple, cache::LieTrotterGodunovCache, tnext) + synchronizers::Tuple, cache::LieTrotterGodunovCache, tnext + ) # We assume that the integrators are already synced @unpack inner_caches = cache # For each inner operator @@ -42,10 +44,11 @@ end @timeit_debug "sync ->" forward_sync_subintegrator!(outer_integrator, subinteg, idxs, synchronizer) @timeit_debug "time solve" advance_solution_to!( - outer_integrator, subinteg, idxs, synchronizer, cache, tnext) + outer_integrator, subinteg, idxs, synchronizer, cache, tnext + ) if !(subinteg isa Tuple) && - subinteg.sol.retcode ∉ - (ReturnCode.Default, ReturnCode.Success) + subinteg.sol.retcode ∉ + (ReturnCode.Default, ReturnCode.Success) return end backward_sync_subintegrator!(outer_integrator, subinteg, idxs, synchronizer) diff --git a/src/utils.jl b/src/utils.jl index 2421b0c..dcc9ee0 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -42,7 +42,7 @@ need_sync(a::SubArray, b::SubArray) = a.parent !== b.parent Copies the information in object b into object a, if synchronization is necessary. """ function sync_vectors!(a, b) - if need_sync(a, b) && a !== b + return if need_sync(a, b) && a !== b a .= b end end @@ -55,10 +55,12 @@ If the inner integrator is synchronized with other inner integrators using `sync The `sync` object is passed from the outside and is the main entry point to dispatch custom types on for parameter synchronization. The `solution_indices` are global indices in the outer integrators solution vectors. """ -function forward_sync_subintegrator!(outer_integrator::OperatorSplittingIntegrator, - inner_integrator::DEIntegrator, solution_indices, sync) +function forward_sync_subintegrator!( + outer_integrator::OperatorSplittingIntegrator, + inner_integrator::DEIntegrator, solution_indices, sync + ) forward_sync_internal!(outer_integrator, inner_integrator, solution_indices) - forward_sync_external!(outer_integrator, inner_integrator, sync) + return forward_sync_external!(outer_integrator, inner_integrator, sync) end """ @@ -69,85 +71,110 @@ If the inner integrator is synchronized with other inner integrators using `sync The `sync` object is passed from the outside and is the main entry point to dispatch custom types on for parameter synchronization. The `solution_indices` are global indices in the outer integrators solution vectors. """ -function backward_sync_subintegrator!(outer_integrator::OperatorSplittingIntegrator, - inner_integrator::DEIntegrator, solution_indices, sync) +function backward_sync_subintegrator!( + outer_integrator::OperatorSplittingIntegrator, + inner_integrator::DEIntegrator, solution_indices, sync + ) backward_sync_internal!(outer_integrator, inner_integrator, solution_indices) - backward_sync_external!(outer_integrator, inner_integrator, sync) + return backward_sync_external!(outer_integrator, inner_integrator, sync) end # This is a bit tricky, because per default the operator splitting integrators share their solution vector. However, there is also the case # when part of the problem is on a different device (thing e.g. about operator A being on CPU and B being on GPU). # This case should be handled with special synchronizers. -function forward_sync_internal!(outer_integrator::OperatorSplittingIntegrator, - inner_integrator::OperatorSplittingIntegrator, solution_indices) - nothing -end -function backward_sync_internal!(outer_integrator::OperatorSplittingIntegrator, - inner_integrator::OperatorSplittingIntegrator, solution_indices) - nothing -end - -function forward_sync_internal!(outer_integrator::OperatorSplittingIntegrator, - inner_integrator::DEIntegrator, solution_indices) +function forward_sync_internal!( + outer_integrator::OperatorSplittingIntegrator, + inner_integrator::OperatorSplittingIntegrator, solution_indices + ) + return nothing +end +function backward_sync_internal!( + outer_integrator::OperatorSplittingIntegrator, + inner_integrator::OperatorSplittingIntegrator, solution_indices + ) + return nothing +end + +function forward_sync_internal!( + outer_integrator::OperatorSplittingIntegrator, + inner_integrator::DEIntegrator, solution_indices + ) @views uouter = outer_integrator.u[solution_indices] sync_vectors!(inner_integrator.uprev, uouter) sync_vectors!(inner_integrator.u, uouter) - SciMLBase.u_modified!(inner_integrator, true) + return SciMLBase.u_modified!(inner_integrator, true) end -function backward_sync_internal!(outer_integrator::OperatorSplittingIntegrator, - inner_integrator::DEIntegrator, solution_indices) +function backward_sync_internal!( + outer_integrator::OperatorSplittingIntegrator, + inner_integrator::DEIntegrator, solution_indices + ) @views uouter = outer_integrator.u[solution_indices] - sync_vectors!(uouter, inner_integrator.u) + return sync_vectors!(uouter, inner_integrator.u) end # This is a noop, because operator splitting integrators do not have parameters for now -function forward_sync_external!(outer_integrator::OperatorSplittingIntegrator, - inner_integrator::OperatorSplittingIntegrator, sync::NoExternalSynchronization) - nothing -end -function forward_sync_external!(outer_integrator::OperatorSplittingIntegrator, - inner_integrator::DEIntegrator, sync::NoExternalSynchronization) -nothing -end -function forward_sync_external!(outer_integrator::OperatorSplittingIntegrator, - inner_integrator::DEIntegrator, sync) - synchronize_solution_with_parameters!(outer_integrator, inner_integrator.p, sync) -end - -function backward_sync_external!(outer_integrator::OperatorSplittingIntegrator, - inner_integrator::OperatorSplittingIntegrator, sync::NoExternalSynchronization) - nothing -end -function backward_sync_external!(outer_integrator::OperatorSplittingIntegrator, - inner_integrator::DEIntegrator, sync::NoExternalSynchronization) - nothing -end -function backward_sync_external!(outer_integrator::OperatorSplittingIntegrator, - inner_integrator::DEIntegrator, sync) - synchronize_solution_with_parameters!(outer_integrator, inner_integrator.p, sync) +function forward_sync_external!( + outer_integrator::OperatorSplittingIntegrator, + inner_integrator::OperatorSplittingIntegrator, sync::NoExternalSynchronization + ) + return nothing +end +function forward_sync_external!( + outer_integrator::OperatorSplittingIntegrator, + inner_integrator::DEIntegrator, sync::NoExternalSynchronization + ) + return nothing +end +function forward_sync_external!( + outer_integrator::OperatorSplittingIntegrator, + inner_integrator::DEIntegrator, sync + ) + return synchronize_solution_with_parameters!(outer_integrator, inner_integrator.p, sync) +end + +function backward_sync_external!( + outer_integrator::OperatorSplittingIntegrator, + inner_integrator::OperatorSplittingIntegrator, sync::NoExternalSynchronization + ) + return nothing +end +function backward_sync_external!( + outer_integrator::OperatorSplittingIntegrator, + inner_integrator::DEIntegrator, sync::NoExternalSynchronization + ) + return nothing +end +function backward_sync_external!( + outer_integrator::OperatorSplittingIntegrator, + inner_integrator::DEIntegrator, sync + ) + return synchronize_solution_with_parameters!(outer_integrator, inner_integrator.p, sync) end function synchronize_solution_with_parameters!(outer_integrator::OperatorSplittingIntegrator, p, sync) - @warn "Outer synchronizer not dispatched for parameter type $(typeof(p)) with synchronizer type $(typeof(sync))." maxlog=1 - nothing + @warn "Outer synchronizer not dispatched for parameter type $(typeof(p)) with synchronizer type $(typeof(sync))." maxlog = 1 + return nothing end # If we encounter NullParameters, then we have the convention for the standard sync map that no external solution is necessary. function synchronize_solution_with_parameters!( - outer_integrator::OperatorSplittingIntegrator, p::NullParameters, sync) - nothing + outer_integrator::OperatorSplittingIntegrator, p::NullParameters, sync + ) + return nothing end # TODO this should go into a custom tree data structure instead of into a tuple-tree function build_solution_index_tree(f::GenericSplitFunction) return ntuple( - i->build_solution_index_tree_recursion(f.functions[i], f.solution_indices[i]), - length(f.functions)) + i -> build_solution_index_tree_recursion(f.functions[i], f.solution_indices[i]), + length(f.functions) + ) end function build_solution_index_tree_recursion(f::GenericSplitFunction, solution_indices) return ntuple( - i->build_solution_index_tree_recursion(f.functions[i], f.solution_indices[i]), - length(f.functions)) + i -> build_solution_index_tree_recursion(f.functions[i], f.solution_indices[i]), + length(f.functions) + ) end function build_solution_index_tree_recursion(f, solution_indices) @@ -155,11 +182,11 @@ function build_solution_index_tree_recursion(f, solution_indices) end function build_synchronizer_tree(f::GenericSplitFunction) - return ntuple(i->build_synchronizer_tree_recursion(f.functions[i], f.synchronizers[i]), length(f.functions)) + return ntuple(i -> build_synchronizer_tree_recursion(f.functions[i], f.synchronizers[i]), length(f.functions)) end function build_synchronizer_tree_recursion(f::GenericSplitFunction, synchronizers) - return ntuple(i->build_synchronizer_tree_recursion(f.functions[i], f.synchronizers[i]), length(f.functions)) + return ntuple(i -> build_synchronizer_tree_recursion(f.functions[i], f.synchronizers[i]), length(f.functions)) end function build_synchronizer_tree_recursion(f, synchronizer) diff --git a/test/alias_u0.jl b/test/alias_u0.jl index aeeb8aa..3646a1a 100644 --- a/test/alias_u0.jl +++ b/test/alias_u0.jl @@ -4,14 +4,14 @@ using DiffEqBase using OrdinaryDiffEqLowOrderRK function ode1(du, u, p, t) - @. du = -0.1u + return @. du = -0.1u end f1 = ODEFunction(ode1) f1dofs = [1, 2, 3] function ode2(du, u, p, t) du[1] = -0.01u[2] - du[2] = -0.01u[1] + return du[2] = -0.01u[1] end f2 = ODEFunction(ode2) f2dofs = [1, 3] diff --git a/test/consistency.jl b/test/consistency.jl index 2f332c3..da60f10 100644 --- a/test/consistency.jl +++ b/test/consistency.jl @@ -13,8 +13,8 @@ prob1 = OperatorSplittingProblem(split_f, u0, tspan) prob2 = ODEProblem(f, u0, tspan) splitting_solver = LieTrotterGodunov((Euler(),)) -integrator1 = init(prob1, splitting_solver; dt=dt) -integrator2 = init(prob2, Euler(); dt=dt) +integrator1 = init(prob1, splitting_solver; dt = dt) +integrator2 = init(prob2, Euler(); dt = dt) for ((u1, t1), (u2, t2)) in zip(TimeChoiceIterator(integrator1, tspan[1]:(2dt):tspan[2]), TimeChoiceIterator(integrator2, tspan[1]:(2dt):tspan[2])) @test u1 ≈ u2 @test t1 ≈ t2 diff --git a/test/operator_splitting_api.jl b/test/operator_splitting_api.jl index 8067bd3..51537c5 100644 --- a/test/operator_splitting_api.jl +++ b/test/operator_splitting_api.jl @@ -10,40 +10,46 @@ using ModelingToolkit # Reference tspan = (0.0, 100.0) -u0 = [0.7611944793397108 - 0.9059606424982555 - 0.5755174199139956] -trueA = [-0.1 0.0 -0.0; - 0.0 -0.1 0.0; - -0.0 0.0 -0.1] -trueB = [-0.0 0.0 -0.01; - 0.0 -0.0 0.0; - -0.01 0.0 -0.0] +u0 = [ + 0.7611944793397108 + 0.9059606424982555 + 0.5755174199139956 +] +trueA = [ + -0.1 0.0 -0.0; + 0.0 -0.1 0.0; + -0.0 0.0 -0.1 +] +trueB = [ + -0.0 0.0 -0.01; + 0.0 -0.0 0.0; + -0.01 0.0 -0.0 +] function ode_true(du, u, p, t) du .= -0.1u du[1] -= 0.01u[3] - du[3] -= 0.01u[1] + return du[3] -= 0.01u[1] end trueu = exp((tspan[2] - tspan[1]) * (trueA + trueB)) * u0 # Setup individual functions # Diagonal components function ode1(du, u, p, t) - @. du = -0.1u + return @. du = -0.1u end f1 = ODEFunction(ode1) # Off-diagonal components function ode2(du, u, p, t) du[1] = -0.01u[2] - du[2] = -0.01u[1] + return du[2] = -0.01u[1] end f2 = ODEFunction(ode2) # Now some recursive splitting function ode3(du, u, p, t) du[1] = -0.005u[2] - du[2] = -0.005u[1] + return du[2] = -0.005u[1] end f3 = ODEFunction(ode3) # The time stepper carries the individual solver information. @@ -61,7 +67,7 @@ Dt = Differential(time) end end @named testmodel2 = TestModelODE2() -testsys2 = mtkcompile(testmodel2; sort_eqs=false) +testsys2 = mtkcompile(testmodel2; sort_eqs = false) # Test whether adaptive code path works in principle struct FakeAdaptiveAlgorithm{T} <: OS.AbstractOperatorSplittingAlgorithm @@ -85,13 +91,13 @@ end return nothing # Do nothing end function OS.build_subintegrator_tree_with_cache( - prob::OS.OperatorSplittingProblem, alg::FakeAdaptiveAlgorithm, - uprevouter::AbstractVector, uouter::AbstractVector, - solution_indices, - t0, dt, tf, - tstops, saveat, d_discontinuities, callback, - adaptive, verbose, -) + prob::OS.OperatorSplittingProblem, alg::FakeAdaptiveAlgorithm, + uprevouter::AbstractVector, uouter::AbstractVector, + solution_indices, + t0, dt, tf, + tstops, saveat, d_discontinuities, callback, + adaptive, verbose, + ) subintegrators, inner_cache = OS.build_subintegrator_tree_with_cache( prob, alg.alg, uprevouter, uouter, solution_indices, t0, dt, tf, @@ -100,20 +106,20 @@ function OS.build_subintegrator_tree_with_cache( ) return subintegrators, FakeAdaptiveAlgorithmCache( - inner_cache, - ) + inner_cache, + ) end function OS.build_subintegrator_tree_with_cache( - prob::OS.OperatorSplittingProblem, alg::FakeAdaptiveAlgorithm, - f::GenericSplitFunction, p::Tuple, - uprevouter::AbstractVector, uouter::AbstractVector, - solution_indices, - t0, dt, tf, - tstops, saveat, d_discontinuities, callback, - adaptive, verbose, - save_end = false, - controller = nothing -) + prob::OS.OperatorSplittingProblem, alg::FakeAdaptiveAlgorithm, + f::GenericSplitFunction, p::Tuple, + uprevouter::AbstractVector, uouter::AbstractVector, + solution_indices, + t0, dt, tf, + tstops, saveat, d_discontinuities, callback, + adaptive, verbose, + save_end = false, + controller = nothing + ) subintegrators, inner_cache = OS.build_subintegrator_tree_with_cache( prob, alg.alg, f, p, uprevouter, uouter, solution_indices, t0, dt, tf, @@ -122,14 +128,14 @@ function OS.build_subintegrator_tree_with_cache( ) return subintegrators, FakeAdaptiveAlgorithmCache( - inner_cache, - ) + inner_cache, + ) end FakeAdaptiveLTG(inner) = FakeAdaptiveAlgorithm(LieTrotterGodunov(inner)) @inline DiffEqBase.get_tmp_cache(integrator::OS.OperatorSplittingIntegrator, alg::OS.AbstractOperatorSplittingAlgorithm, cache::FakeAdaptiveAlgorithmCache) = DiffEqBase.get_tmp_cache(integrator, alg, cache.cache) @inline function OS.advance_solution_to!(outer_integrator::OS.OperatorSplittingIntegrator, subintegrators::Tuple, solution_indices::Tuple, synchronizers::Tuple, cache::FakeAdaptiveAlgorithmCache, tnext) - OS.advance_solution_to!(outer_integrator, subintegrators, solution_indices, synchronizers, cache.cache, tnext) + return OS.advance_solution_to!(outer_integrator, subintegrators, solution_indices, synchronizers, cache.cache, tnext) end @testset "reinit and convergence" begin @@ -158,56 +164,57 @@ end prob2 = OperatorSplittingProblem(fsplit2_outer, u0, tspan) for TimeStepperType in (LieTrotterGodunov, FakeAdaptiveLTG) @testset "Solver type $TimeStepperType | $tstepper" for (prob, tstepper) in ( - (prob1a, TimeStepperType((Euler(), Euler()))), - (prob1a, TimeStepperType((Tsit5(), Euler()))), - (prob1a, TimeStepperType((Euler(), Tsit5()))), - (prob1a, TimeStepperType((Tsit5(), Tsit5()))), - (prob1b, TimeStepperType((Euler(), Euler()))), - (prob1b, TimeStepperType((Tsit5(), Euler()))), - (prob1b, TimeStepperType((Euler(), Tsit5()))), - (prob1b, TimeStepperType((Tsit5(), Tsit5()))), - (prob2, TimeStepperType((Euler(), TimeStepperType((Euler(), Euler()))))), - (prob2, TimeStepperType((Euler(), TimeStepperType((Tsit5(), Euler()))))), - (prob2, TimeStepperType((Euler(), TimeStepperType((Euler(), Tsit5()))))), - (prob2, TimeStepperType((Tsit5(), TimeStepperType((Tsit5(), Euler()))))), - (prob2, TimeStepperType((Tsit5(), TimeStepperType((Euler(), Tsit5()))))), - (prob2, TimeStepperType((Tsit5(), TimeStepperType((Tsit5(), Tsit5()))))) - ) + (prob1a, TimeStepperType((Euler(), Euler()))), + (prob1a, TimeStepperType((Tsit5(), Euler()))), + (prob1a, TimeStepperType((Euler(), Tsit5()))), + (prob1a, TimeStepperType((Tsit5(), Tsit5()))), + (prob1b, TimeStepperType((Euler(), Euler()))), + (prob1b, TimeStepperType((Tsit5(), Euler()))), + (prob1b, TimeStepperType((Euler(), Tsit5()))), + (prob1b, TimeStepperType((Tsit5(), Tsit5()))), + (prob2, TimeStepperType((Euler(), TimeStepperType((Euler(), Euler()))))), + (prob2, TimeStepperType((Euler(), TimeStepperType((Tsit5(), Euler()))))), + (prob2, TimeStepperType((Euler(), TimeStepperType((Euler(), Tsit5()))))), + (prob2, TimeStepperType((Tsit5(), TimeStepperType((Tsit5(), Euler()))))), + (prob2, TimeStepperType((Tsit5(), TimeStepperType((Euler(), Tsit5()))))), + (prob2, TimeStepperType((Tsit5(), TimeStepperType((Tsit5(), Tsit5()))))), + ) # The remaining code works as usual. integrator = DiffEqBase.init( - prob, tstepper, dt = dt, verbose = true, alias_u0 = false, adaptive=false) + prob, tstepper, dt = dt, verbose = true, alias_u0 = false, adaptive = false + ) @test integrator.sol.retcode == DiffEqBase.ReturnCode.Default DiffEqBase.solve!(integrator) @test integrator.sol.retcode == DiffEqBase.ReturnCode.Success ufinal = copy(integrator.u) - @test isapprox(ufinal, trueu, atol = 1e-6) + @test isapprox(ufinal, trueu, atol = 1.0e-6) @test integrator.t ≈ tspan[2] @test integrator.subintegrator_tree[1].t ≈ tspan[2] @test integrator.dtcache ≈ dt - @test integrator.iter == ceil(Int, (tspan[2]-tspan[1])/dt) - @test integrator.subintegrator_tree[1].iter == ceil(Int, (tspan[2]-tspan[1])/dt) + @test integrator.iter == ceil(Int, (tspan[2] - tspan[1]) / dt) + @test integrator.subintegrator_tree[1].iter == ceil(Int, (tspan[2] - tspan[1]) / dt) DiffEqBase.reinit!(integrator; dt = dt) @test integrator.sol.retcode == DiffEqBase.ReturnCode.Default for (u, t) in DiffEqBase.TimeChoiceIterator(integrator, tspan[1]:5.0:tspan[2]) end - @test isapprox(ufinal, integrator.u, atol = 1e-12) + @test isapprox(ufinal, integrator.u, atol = 1.0e-12) @test integrator.t ≈ tspan[2] @test integrator.subintegrator_tree[1].t ≈ tspan[2] @test integrator.dtcache ≈ dt - @test integrator.iter == ceil(Int, (tspan[2]-tspan[1])/dt) - @test integrator.subintegrator_tree[1].iter == ceil(Int, (tspan[2]-tspan[1])/dt) + @test integrator.iter == ceil(Int, (tspan[2] - tspan[1]) / dt) + @test integrator.subintegrator_tree[1].iter == ceil(Int, (tspan[2] - tspan[1]) / dt) DiffEqBase.reinit!(integrator; dt = dt) @test integrator.sol.retcode == DiffEqBase.ReturnCode.Default for (uprev, tprev, u, t) in DiffEqBase.intervals(integrator) end - @test isapprox(ufinal, integrator.u, atol = 1e-12) + @test isapprox(ufinal, integrator.u, atol = 1.0e-12) @test integrator.t ≈ tspan[2] @test integrator.subintegrator_tree[1].t ≈ tspan[2] @test integrator.dtcache ≈ dt - @test integrator.iter == ceil(Int, (tspan[2]-tspan[1])/dt) - @test integrator.subintegrator_tree[1].iter == ceil(Int, (tspan[2]-tspan[1])/dt) + @test integrator.iter == ceil(Int, (tspan[2] - tspan[1]) / dt) + @test integrator.subintegrator_tree[1].iter == ceil(Int, (tspan[2] - tspan[1]) / dt) DiffEqBase.reinit!(integrator; dt = dt) @test integrator.sol.retcode == DiffEqBase.ReturnCode.Default @@ -216,51 +223,52 @@ end @test integrator.t ≈ tspan[2] @test integrator.subintegrator_tree[1].t ≈ tspan[2] @test integrator.dtcache ≈ dt - @test integrator.iter == ceil(Int, (tspan[2]-tspan[1])/dt) - @test integrator.subintegrator_tree[1].iter == ceil(Int, (tspan[2]-tspan[1])/dt) + @test integrator.iter == ceil(Int, (tspan[2] - tspan[1]) / dt) + @test integrator.subintegrator_tree[1].iter == ceil(Int, (tspan[2] - tspan[1]) / dt) end end for TimeStepperType in (FakeAdaptiveLTG,) @testset "Adaptive solver type $TimeStepperType | $tstepper" for (prob, tstepper) in ( - (prob1a, TimeStepperType((Tsit5(), Tsit5()))), - (prob2, TimeStepperType((Tsit5(), TimeStepperType((Tsit5(), Tsit5()))))) - ) + (prob1a, TimeStepperType((Tsit5(), Tsit5()))), + (prob2, TimeStepperType((Tsit5(), TimeStepperType((Tsit5(), Tsit5()))))), + ) # The remaining code works as usual. integrator = DiffEqBase.init( - prob, tstepper, dt = dt, verbose = true, alias_u0 = false, adaptive=true) + prob, tstepper, dt = dt, verbose = true, alias_u0 = false, adaptive = true + ) @test integrator.sol.retcode == DiffEqBase.ReturnCode.Default DiffEqBase.solve!(integrator) @test integrator.sol.retcode == DiffEqBase.ReturnCode.Success ufinal = copy(integrator.u) - @test isapprox(ufinal, trueu, atol = 1e-6) + @test isapprox(ufinal, trueu, atol = 1.0e-6) @test integrator.t ≈ tspan[2] @test integrator.subintegrator_tree[1].t ≈ tspan[2] @test integrator.dtcache ≈ dt - @test integrator.iter == ceil(Int, (tspan[2]-tspan[1])/dt) - @test integrator.subintegrator_tree[1].iter == ceil(Int, (tspan[2]-tspan[1])/dt) + @test integrator.iter == ceil(Int, (tspan[2] - tspan[1]) / dt) + @test integrator.subintegrator_tree[1].iter == ceil(Int, (tspan[2] - tspan[1]) / dt) DiffEqBase.reinit!(integrator; dt = dt) @test integrator.sol.retcode == DiffEqBase.ReturnCode.Default for (u, t) in DiffEqBase.TimeChoiceIterator(integrator, tspan[1]:5.0:tspan[2]) end - @test isapprox(ufinal, integrator.u, atol = 1e-12) + @test isapprox(ufinal, integrator.u, atol = 1.0e-12) @test integrator.t ≈ tspan[2] @test integrator.subintegrator_tree[1].t ≈ tspan[2] @test integrator.dtcache ≈ dt - @test integrator.iter == ceil(Int, (tspan[2]-tspan[1])/dt) - @test integrator.subintegrator_tree[1].iter == ceil(Int, (tspan[2]-tspan[1])/dt) + @test integrator.iter == ceil(Int, (tspan[2] - tspan[1]) / dt) + @test integrator.subintegrator_tree[1].iter == ceil(Int, (tspan[2] - tspan[1]) / dt) DiffEqBase.reinit!(integrator; dt = dt) @test integrator.sol.retcode == DiffEqBase.ReturnCode.Default for (uprev, tprev, u, t) in DiffEqBase.intervals(integrator) end - @test isapprox(ufinal, integrator.u, atol = 1e-12) + @test isapprox(ufinal, integrator.u, atol = 1.0e-12) @test integrator.t ≈ tspan[2] @test integrator.subintegrator_tree[1].t ≈ tspan[2] @test integrator.dtcache ≈ dt - @test integrator.iter == ceil(Int, (tspan[2]-tspan[1])/dt) - @test integrator.subintegrator_tree[1].iter == ceil(Int, (tspan[2]-tspan[1])/dt) + @test integrator.iter == ceil(Int, (tspan[2] - tspan[1]) / dt) + @test integrator.subintegrator_tree[1].iter == ceil(Int, (tspan[2] - tspan[1]) / dt) DiffEqBase.reinit!(integrator; dt = dt) @test integrator.sol.retcode == DiffEqBase.ReturnCode.Default @@ -269,8 +277,8 @@ end @test integrator.t ≈ tspan[2] @test integrator.subintegrator_tree[1].t ≈ tspan[2] @test integrator.dtcache ≈ dt - @test integrator.iter == ceil(Int, (tspan[2]-tspan[1])/dt) - @test integrator.subintegrator_tree[1].iter == ceil(Int, (tspan[2]-tspan[1])/dt) + @test integrator.iter == ceil(Int, (tspan[2] - tspan[1]) / dt) + @test integrator.subintegrator_tree[1].iter == ceil(Int, (tspan[2] - tspan[1]) / dt) end end @@ -292,17 +300,18 @@ end for TimeStepperType in (LieTrotterGodunov,) @testset "Solver type $TimeStepperType | $tstepper" for tstepper in ( - TimeStepperType((Euler(), Euler())), - TimeStepperType((Tsit5(), Euler())), - TimeStepperType((Euler(), Tsit5())), - TimeStepperType((Tsit5(), Tsit5())) - ) + TimeStepperType((Euler(), Euler())), + TimeStepperType((Tsit5(), Euler())), + TimeStepperType((Euler(), Tsit5())), + TimeStepperType((Tsit5(), Tsit5())), + ) integrator_NaN = DiffEqBase.init( - prob_NaN, tstepper, dt = dt, verbose = true, alias_u0 = false) + prob_NaN, tstepper, dt = dt, verbose = true, alias_u0 = false + ) @test integrator_NaN.sol.retcode == DiffEqBase.ReturnCode.Default DiffEqBase.solve!(integrator_NaN) @test integrator_NaN.sol.retcode ∈ - (DiffEqBase.ReturnCode.Unstable, DiffEqBase.ReturnCode.DtNaN) + (DiffEqBase.ReturnCode.Unstable, DiffEqBase.ReturnCode.DtNaN) end end end