diff --git a/examples/kdv_1d/kdv_1d_IMEX.jl b/examples/kdv_1d/kdv_1d_IMEX.jl new file mode 100644 index 00000000..6ca8c8c7 --- /dev/null +++ b/examples/kdv_1d/kdv_1d_IMEX.jl @@ -0,0 +1,44 @@ +using OrdinaryDiffEqSDIRK +using DispersiveShallowWater +using SummationByPartsOperators: upwind_operators, periodic_derivative_operator + +############################################################################### +# Semidiscretization of the KdV equation + +equations = KdVEquation1D(gravity = 9.81, D = 1.0) +initial_condition = initial_condition_convergence_test +boundary_conditions = boundary_condition_periodic + +# create homogeneous mesh +coordinates_min = -50.0 +coordinates_max = 50.0 +N = 512 +mesh = Mesh1D(coordinates_min, coordinates_max, N) + +# Create solver with periodic SBP operators of accuracy order 3, +# which results in a 4th order accurate semi discretizations. +# We can set the accuracy order of the upwind operators to 3 since +# we only use central versions/combinations of the upwind operators. +D1_upwind = upwind_operators(periodic_derivative_operator; + derivative_order = 1, accuracy_order = 3, + xmin = xmin(mesh), xmax = xmax(mesh), + N = nnodes(mesh)) +solver = Solver(D1_upwind) + +semi = Semidiscretization(mesh, equations, initial_condition, solver, + boundary_conditions = boundary_conditions) + +tspan = (0.0, 5.0) +ode = semidiscretize(semi, tspan) + +summary_callback = SummaryCallback() +analysis_callback = AnalysisCallback(semi; interval = 100, + extra_analysis_errors = (:conservation_error,), + extra_analysis_integrals = (waterheight_total, + waterheight)) +callbacks = CallbackSet(analysis_callback, summary_callback) +saveat = range(tspan..., length = 100) + +alg = KenCarp4() # use an IMEX method +sol = solve(ode, alg, abstol = 1e-7, reltol = 1e-7, + save_everystep = false, callback = callbacks, saveat = saveat) diff --git a/examples/kdv_1d/kdv_1d_basic.jl b/examples/kdv_1d/kdv_1d_basic.jl index 6b7de136..344775be 100644 --- a/examples/kdv_1d/kdv_1d_basic.jl +++ b/examples/kdv_1d/kdv_1d_basic.jl @@ -29,7 +29,7 @@ semi = Semidiscretization(mesh, equations, initial_condition, solver, boundary_conditions = boundary_conditions) tspan = (0.0, 5.0) -ode = semidiscretize(semi, tspan) +ode = semidiscretize(semi, tspan, split_ode = Val{false}()) # no IMEX for now summary_callback = SummaryCallback() analysis_callback = AnalysisCallback(semi; interval = 100, diff --git a/examples/kdv_1d/kdv_1d_fourier.jl b/examples/kdv_1d/kdv_1d_fourier.jl index 25d49015..0b1cf64a 100644 --- a/examples/kdv_1d/kdv_1d_fourier.jl +++ b/examples/kdv_1d/kdv_1d_fourier.jl @@ -25,7 +25,7 @@ semi = Semidiscretization(mesh, equations, initial_condition, solver, boundary_conditions = boundary_conditions) tspan = (0.0, 5.0) -ode = semidiscretize(semi, tspan) +ode = semidiscretize(semi, tspan, split_ode = Val{false}()) # no IMEX for now summary_callback = SummaryCallback() analysis_callback = AnalysisCallback(semi; interval = 100, diff --git a/examples/kdv_1d/kdv_1d_implicit.jl b/examples/kdv_1d/kdv_1d_implicit.jl index ee8c0fc2..9765f5be 100644 --- a/examples/kdv_1d/kdv_1d_implicit.jl +++ b/examples/kdv_1d/kdv_1d_implicit.jl @@ -29,7 +29,7 @@ semi = Semidiscretization(mesh, equations, initial_condition, solver, boundary_conditions = boundary_conditions) tspan = (0.0, 5.0) -ode = semidiscretize(semi, tspan) +ode = semidiscretize(semi, tspan, split_ode = Val{false}()) summary_callback = SummaryCallback() analysis_callback = AnalysisCallback(semi; interval = 100, diff --git a/examples/kdv_1d/kdv_1d_manufactured.jl b/examples/kdv_1d/kdv_1d_manufactured.jl index 81a8ecff..9b648594 100644 --- a/examples/kdv_1d/kdv_1d_manufactured.jl +++ b/examples/kdv_1d/kdv_1d_manufactured.jl @@ -32,7 +32,7 @@ semi = Semidiscretization(mesh, equations, initial_condition, solver, source_terms = source_terms) tspan = (0.0, 1.0) -ode = semidiscretize(semi, tspan) +ode = semidiscretize(semi, tspan, split_ode = Val{false}()) summary_callback = SummaryCallback() analysis_callback = AnalysisCallback(semi; interval = 100, @@ -46,3 +46,124 @@ saveat = range(tspan..., length = 100) alg = Rodas5() sol = solve(ode, alg, abstol = 1e-12, reltol = 1e-12, save_everystep = false, callback = callbacks, saveat = saveat) + +""" +using alg = KenCarp4() I get the following Benchmarks: +For some reason IMEX seems to perform worse - doing more steps +no IMEX: (split_ode = Val{false}()) +────────────────────────────────────────────────────────────────────────────────────── + DispersiveSWE Time Allocations + ─────────────────────── ──────────────────────── + Tot / % measured: 611ms / 34.2% 40.0MiB / 10.0% + +Section ncalls time %tot avg alloc %tot avg +────────────────────────────────────────────────────────────────────────────────────── +rhs! 6.62k 196ms 93.6% 29.6μs 3.94MiB 98.4% 624B + source terms 6.62k 108ms 51.5% 16.3μs 3.94MiB 98.4% 624B + third-order derivatives 6.62k 44.1ms 21.1% 6.66μs 0.00B 0.0% 0.00B + hyperbolic 6.62k 41.1ms 19.7% 6.22μs 0.00B 0.0% 0.00B + ~rhs!~ 6.62k 2.97ms 1.4% 450ns 1.25KiB 0.0% 0.19B +analyze solution 3 13.3ms 6.4% 4.45ms 64.8KiB 1.6% 21.6KiB +────────────────────────────────────────────────────────────────────────────────────── + +IMEX with sources in nonstiff part: +────────────────────────────────────────────────────────────────────────────────────── + DispersiveSWE Time Allocations + ─────────────────────── ──────────────────────── + Tot / % measured: 2.85s / 12.3% 98.6MiB / 5.7% + +Section ncalls time %tot avg alloc %tot avg +────────────────────────────────────────────────────────────────────────────────────── +rhs_split_nonstiff! 8.93k 169ms 48.3% 18.9μs 5.31MiB 93.8% 624B + source terms 8.93k 160ms 45.6% 17.9μs 5.31MiB 93.8% 624B + hyperbolic 8.93k 5.67ms 1.6% 635ns 0.00B 0.0% 0.00B + ~rhs_split_nonstiff!~ 8.93k 3.44ms 1.0% 386ns 976B 0.0% 0.11B +rhs_split_stiff! 45.4k 103ms 29.4% 2.27μs 672B 0.0% 0.01B + third-order derivatives 45.4k 95.9ms 27.4% 2.11μs 0.00B 0.0% 0.00B + ~rhs_split_stiff!~ 45.4k 7.07ms 2.0% 156ns 672B 0.0% 0.01B +analyze solution 15 78.2ms 22.3% 5.21ms 360KiB 6.2% 24.0KiB +────────────────────────────────────────────────────────────────────────────────────── + +IMEX with sources in stiff part: +────────────────────────────────────────────────────────────────────────────────────── + DispersiveSWE Time Allocations + ─────────────────────── ──────────────────────── + Tot / % measured: 4.89s / 28.2% 88.7MiB / 49.0% + +Section ncalls time %tot avg alloc %tot avg +────────────────────────────────────────────────────────────────────────────────────── +rhs_split_stiff! 72.2k 1.27s 92.0% 17.6μs 43.0MiB 98.9% 624B + source terms 72.2k 1.17s 85.2% 16.3μs 43.0MiB 98.9% 624B + third-order derivatives 72.2k 70.8ms 5.1% 981ns 0.00B 0.0% 0.00B + ~rhs_split_stiff!~ 72.2k 22.1ms 1.6% 306ns 976B 0.0% 0.01B +analyze solution 23 102ms 7.4% 4.42ms 501KiB 1.1% 21.8KiB +rhs_split_nonstiff! 13.3k 8.86ms 0.6% 665ns 672B 0.0% 0.05B + hyperbolic 13.3k 6.82ms 0.5% 512ns 0.00B 0.0% 0.00B + ~rhs_split_nonstiff!~ 13.3k 2.03ms 0.1% 153ns 672B 0.0% 0.05B +────────────────────────────────────────────────────────────────────────────────────── + + + + + + +No with fixed time step dt = 1e-2: + + + +no IMEX +────────────────────────────────────────────────────────────────────────────────────── + DispersiveSWE Time Allocations + ─────────────────────── ──────────────────────── + Tot / % measured: 6.33s / 57.1% 0.99GiB / 1.3% + +Section ncalls time %tot avg alloc %tot avg +────────────────────────────────────────────────────────────────────────────────────── +rhs! 22.0k 3.46s 96.0% 157μs 13.1MiB 99.8% 624B + third-order derivatives 22.0k 1.61s 44.6% 73.1μs 0.00B 0.0% 0.00B + hyperbolic 22.0k 1.49s 41.2% 67.5μs 0.00B 0.0% 0.00B + source terms 22.0k 355ms 9.8% 16.1μs 13.1MiB 99.8% 624B + ~rhs!~ 22.0k 13.5ms 0.4% 615ns 1.25KiB 0.0% 0.06B +analyze solution 1 146ms 4.0% 146ms 21.9KiB 0.2% 21.9KiB +────────────────────────────────────────────────────────────────────────────────────── + +IMEX with sources in nonstiff part: +────────────────────────────────────────────────────────────────────────────────────── + DispersiveSWE Time Allocations + ─────────────────────── ──────────────────────── + Tot / % measured: 4.34s / 39.4% 0.98GiB / 0.0% + +Section ncalls time %tot avg alloc %tot avg +────────────────────────────────────────────────────────────────────────────────────── +rhs_split_stiff! 22.2k 1.70s 99.0% 76.3μs 672B 0.2% 0.03B + third-order derivatives 22.2k 1.69s 98.7% 76.1μs 0.00B 0.0% 0.00B + ~rhs_split_stiff!~ 22.2k 5.55ms 0.3% 250ns 672B 0.2% 0.03B +rhs_split_nonstiff! 601 15.1ms 0.9% 25.0μs 367KiB 89.4% 626B + source terms 601 12.9ms 0.8% 21.5μs 366KiB 89.2% 624B + hyperbolic 601 1.41ms 0.1% 2.35μs 0.00B 0.0% 0.00B + ~rhs_split_nonstiff!~ 601 713μs 0.0% 1.19μs 976B 0.2% 1.62B +analyze solution 1 1.86ms 0.1% 1.86ms 42.7KiB 10.4% 42.7KiB +────────────────────────────────────────────────────────────────────────────────────── + +IMEX with sources in stiff part: +────────────────────────────────────────────────────────────────────────────────────── + DispersiveSWE Time Allocations + ─────────────────────── ──────────────────────── + Tot / % measured: 4.29s / 48.3% 0.99GiB / 1.3% + +Section ncalls time %tot avg alloc %tot avg +────────────────────────────────────────────────────────────────────────────────────── +rhs_split_stiff! 22.2k 2.06s 99.7% 92.9μs 13.2MiB 99.8% 624B + third-order derivatives 22.2k 1.70s 82.1% 76.4μs 0.00B 0.0% 0.00B + source terms 22.2k 356ms 17.2% 16.0μs 13.2MiB 99.8% 624B + ~rhs_split_stiff!~ 22.2k 8.90ms 0.4% 401ns 976B 0.0% 0.04B +analyze solution 1 3.62ms 0.2% 3.62ms 21.9KiB 0.2% 21.9KiB +rhs_split_nonstiff! 601 1.61ms 0.1% 2.68μs 672B 0.0% 1.12B + hyperbolic 601 1.16ms 0.1% 1.93μs 0.00B 0.0% 0.00B + ~rhs_split_nonstiff!~ 601 448μs 0.0% 746ns 672B 0.0% 1.12B +────────────────────────────────────────────────────────────────────────────────────── + + +Also for dt = 1e-1 the solution for IMEX already looks bad(ish), +will for split_ode = Val{false}() it looks still good. +""" diff --git a/examples/kdv_1d/kdv_1d_narrow_stencil.jl b/examples/kdv_1d/kdv_1d_narrow_stencil.jl index e5aae1c8..756f8482 100644 --- a/examples/kdv_1d/kdv_1d_narrow_stencil.jl +++ b/examples/kdv_1d/kdv_1d_narrow_stencil.jl @@ -22,7 +22,7 @@ semi = Semidiscretization(mesh, equations, initial_condition, solver, boundary_conditions = boundary_conditions) tspan = (0.0, 5.0) -ode = semidiscretize(semi, tspan) +ode = semidiscretize(semi, tspan, split_ode = Val{false}()) # no IMEX for now summary_callback = SummaryCallback() analysis_callback = AnalysisCallback(semi; interval = 100, diff --git a/src/DispersiveShallowWater.jl b/src/DispersiveShallowWater.jl index 215f89e7..727f0bb3 100644 --- a/src/DispersiveShallowWater.jl +++ b/src/DispersiveShallowWater.jl @@ -30,7 +30,7 @@ using RecursiveArrayTools: ArrayPartition using Reexport: @reexport using Roots: AlefeldPotraShi, find_zero -using SciMLBase: SciMLBase, DiscreteCallback, ODEProblem, ODESolution +using SciMLBase: SciMLBase, DiscreteCallback, ODEProblem, ODESolution, SplitFunction import SciMLBase: u_modified! @reexport using StaticArrays: SVector @@ -77,7 +77,7 @@ export LinearDispersionRelation, EulerEquations1D, wave_speed export prim2prim, prim2cons, cons2prim, prim2phys, waterheight_total, waterheight, velocity, momentum, discharge, - gravity, + gravity, have_stiff_terms, bathymetry, still_water_surface, energy_total, entropy, lake_at_rest_error, energy_total_modified, entropy_modified, diff --git a/src/equations/equations.jl b/src/equations/equations.jl index ab495535..d0fc6c7c 100644 --- a/src/equations/equations.jl +++ b/src/equations/equations.jl @@ -265,6 +265,21 @@ Return the gravitational acceleration ``g`` for a given set of `equations`. return equations.gravity end +""" + DispersiveShallowWater.have_stiff_terms(equations) + +Returns `Val{true}()` if the equations have stiff terms that benefit from +implicit time integration methods and `Val{false}()` otherwise (default). + +This trait is used to determine whether to create a `SplitFunction` in +[`semidiscretize`](@ref) for IMEX time integration methods. + +!!! note "Implementation details" + This function is used for internal dispatch to determine the appropriate + ODE problem formulation. +""" +have_stiff_terms(::AbstractEquations) = Val{false}() + """ energy_total(q, equations) diff --git a/src/equations/kdv_1d.jl b/src/equations/kdv_1d.jl index 194dad29..0ddfd712 100644 --- a/src/equations/kdv_1d.jl +++ b/src/equations/kdv_1d.jl @@ -53,6 +53,9 @@ function KdVEquation1D(; gravity, D = 1.0, eta0 = 0.0) KdVEquation1D(gravity, D, eta0) end +# KdV equations have stiff third-order derivative terms that benefit from IMEX methods +have_stiff_terms(::KdVEquation1D) = Val{true}() + """ initial_condition_soliton(x, t, equations::KdVEquation1D, mesh) @@ -142,6 +145,7 @@ function create_cache(mesh, equations::KdVEquation1D, return cache end +""" function rhs!(dq, q, t, mesh, equations::KdVEquation1D, initial_condition, ::BoundaryConditionPeriodic, source_terms, solver, cache) eta, = q.x @@ -193,3 +197,135 @@ function rhs!(dq, q, t, mesh, equations::KdVEquation1D, initial_condition, return nothing end + +function rhs_split_stiff!(dq, q, t, mesh, equations::KdVEquation1D, initial_condition, + ::BoundaryConditionPeriodic, source_terms, solver, cache) + eta, = q.x + deta, = dq.x + + (; c_0, DD) = cache + # In order to use automatic differentiation, we need to extract + # the storage vectors using `get_tmp` from PreallocationTools.jl + # so they can also hold dual numbers when needed. + tmp_1 = get_tmp(cache.tmp_1, eta) + tmp_2 = get_tmp(cache.tmp_2, eta) + + @trixi_timeit timer() "third-order derivatives" begin + if solver.D1 isa PeriodicUpwindOperators && isnothing(solver.D3) + # eta_xxx = Dp * Dc * Dm * eta + mul!(tmp_1, solver.D1.minus, eta) + mul!(tmp_2, solver.D1.central, tmp_1) + mul!(tmp_1, solver.D1.plus, tmp_2) + else + # eta_xxx = D3 * eta + mul!(tmp_1, solver.D3, eta) + end + + # deta = 1 / 6 sqrt(g * D) D^2 eta_xxx + + @.. deta = -1 / 6 * c_0 * DD * tmp_1 + end + + return nothing +end + +function rhs_split_nonstiff!(dq, q, t, mesh, equations::KdVEquation1D, initial_condition, + ::BoundaryConditionPeriodic, source_terms, solver, cache) + eta, = q.x + deta, = dq.x + + (; c_0, c_1) = cache + # In order to use automatic differentiation, we need to extract + # the storage vectors using `get_tmp` from PreallocationTools.jl + # so they can also hold dual numbers when needed. + tmp_1 = get_tmp(cache.tmp_1, eta) + tmp_2 = get_tmp(cache.tmp_2, eta) + + if solver.D1 isa PeriodicUpwindOperators && isnothing(solver.D3) + D1 = solver.D1.central + else + D1 = solver.D1 + end + + @trixi_timeit timer() "hyperbolic" begin + # eta2 = eta^2 + @.. tmp_1 = eta^2 + + # eta2_x = D1 * eta2 + mul!(tmp_2, D1, tmp_1) + + # eta_x = D1 * eta + mul!(tmp_1, D1, eta) + + # deta -= sqrt(g * D) * eta_x + 1 / 2 * sqrt(g / D) * (eta * eta_x + eta2_x) + @.. deta = -(c_0 * tmp_1 + c_1 * (eta * tmp_1 + tmp_2)) + end + + @trixi_timeit timer() "source terms" calc_sources!(dq, q, t, source_terms, equations, + solver) + + return nothing +end +""" + +function rhs!(dq, q, t, mesh, equations::KdVEquation1D, initial_condition, + ::BoundaryConditionPeriodic, source_terms, solver, cache, + mode::Symbol = :full) + eta, = q.x + deta, = dq.x + + (; c_0, c_1, DD) = cache + tmp_1 = get_tmp(cache.tmp_1, eta) + tmp_2 = get_tmp(cache.tmp_2, eta) + # In order to use automatic differentiation, we need to extract + # the storage vectors using `get_tmp` from PreallocationTools.jl + # so they can also hold dual numbers when needed. + if solver.D1 isa PeriodicUpwindOperators && isnothing(solver.D3) + D1 = solver.D1.central + else + D1 = solver.D1 + end + + # Initialize deta based on mode + if mode == :full || mode == :stiff + @trixi_timeit timer() "third-order derivatives" begin + if solver.D1 isa PeriodicUpwindOperators && isnothing(solver.D3) + mul!(tmp_1, solver.D1.minus, eta) + mul!(tmp_2, solver.D1.central, tmp_1) + mul!(tmp_1, solver.D1.plus, tmp_2) + else + mul!(tmp_1, solver.D3, eta) + end + + # add stiff part + # deta = 1 / 6 sqrt(g * D) D^2 eta_xxx + @.. deta = -1 / 6 * c_0 * DD * tmp_1 + end + end + + if mode == :full || mode == :nonstiff + @trixi_timeit timer() "hyperbolic" begin + # eta2 = eta^2 + @.. tmp_1 = eta^2 + + # eta2_x = D1 * eta2 + mul!(tmp_2, D1, tmp_1) + + # eta_x = D1 * eta + mul!(tmp_1, D1, eta) + + # Set or add non-stiff part + # deta -= sqrt(g * D) * eta_x + 1 / 2 * sqrt(g / D) * (eta * eta_x + eta2_x) + if mode == :nonstiff + @.. deta = -(c_0 * tmp_1 + c_1 * (eta * tmp_1 + tmp_2)) + else # mode == :full + @.. deta -= (c_0 * tmp_1 + c_1 * (eta * tmp_1 + tmp_2)) # Add to existing stiff part + end + end + + @trixi_timeit timer() "source terms" calc_sources!(dq, q, t, source_terms, + equations, solver) + end + + return nothing +end diff --git a/src/semidiscretization.jl b/src/semidiscretization.jl index 30deba8b..877e7ebc 100644 --- a/src/semidiscretization.jl +++ b/src/semidiscretization.jl @@ -186,6 +186,26 @@ function rhs!(dq, q, semi::Semidiscretization, t) return nothing end +function rhs_split_stiff!(dq, q, semi::Semidiscretization, t) + @unpack mesh, equations, initial_condition, boundary_conditions, solver, source_terms, cache = semi + + @trixi_timeit timer() "rhs_split_stiff!" rhs!(dq, q, t, mesh, equations, + initial_condition, + boundary_conditions, source_terms, solver, + cache, :stiff) + return nothing +end + +function rhs_split_nonstiff!(dq, q, semi::Semidiscretization, t) + @unpack mesh, equations, initial_condition, boundary_conditions, solver, source_terms, cache = semi + + @trixi_timeit timer() "rhs_split_nonstiff!" rhs!(dq, q, t, mesh, equations, + initial_condition, + boundary_conditions, source_terms, + solver, cache, :nonstiff) + return nothing +end + function compute_coefficients(func, t, semi::Semidiscretization) @unpack mesh, equations, solver = semi q = allocate_coefficients(mesh_equations_solver(semi)...) @@ -214,18 +234,53 @@ function check_bathymetry(equations::AbstractShallowWaterEquations, q0) end """ - semidiscretize(semi::Semidiscretization, tspan) + semidiscretize(semi::Semidiscretization, tspan; split_ode = have_stiff_terms(semi.equations)) Wrap the semidiscretization `semi` as an ODE problem in the time interval `tspan` that can be passed to `solve` from the [SciML ecosystem](https://diffeq.sciml.ai/latest/). + +If `split_ode` is `Val{false}()`, a regular `ODEFunction` is created. +If `split_ode` is `Val{true}()`, a `SplitFunction` is created for IMEX time integration if available. +By default, `split_ode` is determined by the [`DispersiveShallowWater.have_stiff_terms`](@ref) trait. """ -function semidiscretize(semi::Semidiscretization, tspan) +function semidiscretize(semi::Semidiscretization, tspan; + split_ode = have_stiff_terms(semi.equations)) q0 = compute_coefficients(semi.initial_condition, first(tspan), semi) check_bathymetry(semi.equations, q0) iip = true # is-inplace, i.e., we modify a vector when calling rhs! + return _semidiscretize_ode(split_ode, q0, tspan, semi, iip) +end + +# Type-stable dispatch based on split_ode trait +function _semidiscretize_ode(::Val{false}, q0, tspan, semi, iip) return ODEProblem{iip}(rhs!, q0, tspan, semi) end +function _semidiscretize_ode(::Val{true}, q0, tspan, semi, iip) + _check_split_rhs_implementation(semi) + return ODEProblem{iip}(SplitFunction(rhs_split_stiff!, rhs_split_nonstiff!), q0, tspan, + semi) +end + +function _check_split_rhs_implementation(semi) + @unpack mesh, equations, initial_condition, boundary_conditions, solver, source_terms, cache = semi + + equation_name = get_name(equations) + args = (nothing, nothing, nothing, mesh, equations, initial_condition, + boundary_conditions, source_terms, solver, cache) + + # Check if methods are applicable + if !applicable(rhs!, args..., :stiff) + throw(ArgumentError("Split RHS method with :stiff argument not implemented for $equation_name.")) + end + + if !applicable(rhs!, args..., :nonstiff) + throw(ArgumentError("Split RHS method with :nonstiff argument not implemented for $equation_name.")) + end + + return nothing +end + """ DispersiveShallowWater.jacobian(semi::Semidiscretization; t = 0.0, @@ -241,7 +296,8 @@ of the semidiscretization `semi` at the state `q0`. function jacobian(semi::Semidiscretization; t = 0.0, q0 = compute_coefficients(semi.initial_condition, t, semi)) - J = ForwardDiff.jacobian(similar(q0), q0) do dq, q + @unpack tmp_partitioned = semi.cache + J = ForwardDiff.jacobian(tmp_partitioned, q0) do dq, q DispersiveShallowWater.rhs!(dq, q, semi, t) end return J diff --git a/test/Project.toml b/test/Project.toml index e29234ec..e8b31a9a 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -6,6 +6,7 @@ JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" OrdinaryDiffEqLowStorageRK = "b0944070-b475-4768-8dec-fb6eb410534d" OrdinaryDiffEqRosenbrock = "43230ef6-c299-4910-a778-202eb28ce4ce" OrdinaryDiffEqTsit5 = "b1df2697-797e-41e3-8120-5422d3b24e4a" +OrdinaryDiffEqSDIRK = "2d112036-d095-4a1e-ab9a-08536f3ecdbf" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SummationByPartsOperators = "9f78cca6-572e-554e-b819-917d2f1cf240" @@ -22,6 +23,7 @@ JET = "0.9.9" OrdinaryDiffEqLowStorageRK = "1.1" OrdinaryDiffEqRosenbrock = "1.3" OrdinaryDiffEqTsit5 = "1.1" +OrdinaryDiffEqSDIRK = "1.1" Plots = "1.25" SparseArrays = "1" SummationByPartsOperators = "0.5.79" diff --git a/test/test_kdv_1d.jl b/test/test_kdv_1d.jl index e0c2e403..a627c30a 100644 --- a/test/test_kdv_1d.jl +++ b/test/test_kdv_1d.jl @@ -59,3 +59,14 @@ end @test_allocations(semi, sol, allocs=5_000) end + +@testitem "kdv_1d_IMEX" setup=[Setup, KdVEquation1D] begin + @test_trixi_include(joinpath(EXAMPLES_DIR, "kdv_1d_IMEX.jl"), + tspan=(0.0, 5.0), + l2=[0.004952174509850488], + linf=[0.003962890861977875], + cons_error=[2.220446049250313e-15], + change_waterheight=-2.220446049250313e-15) + + @test_allocations_split_ode(semi, sol, allocs=5_000) +end diff --git a/test/test_unit.jl b/test/test_unit.jl index 6d9d024a..119c7b2a 100644 --- a/test/test_unit.jl +++ b/test/test_unit.jl @@ -83,6 +83,7 @@ end solver = Solver(mesh, 4) semi_flat = Semidiscretization(mesh, equations_flat, initial_condition, solver) @test_throws ArgumentError semidiscretize(semi_flat, (0.0, 1.0)) + @test_throws ArgumentError semidiscretize(semi, (0.0, 1.0), split_ode = Val{true}()) end @testitem "Boundary conditions" setup=[Setup] begin diff --git a/test/test_util.jl b/test/test_util.jl index 5772aa02..110d04b3 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -176,6 +176,24 @@ macro test_allocations(semi, sol, allocs) end end +""" + @test_allocations_split_ode(semi, sol, allocs) + +Test that the memory allocations of `DispersiveShallowWater.rhs_split_stiff!` +and `DispersiveShallowWater.rhs_split_nonstiff!` are below `allocs` +(e.g., from type instabilities). +""" +macro test_allocations_split_ode(semi, sol, allocs) + quote + t = $sol.t[end] + q = $sol.u[end] + dq = similar(q) + a1 = @allocated DispersiveShallowWater.rhs_split_stiff!(dq, q, $semi, t) + a2 = @allocated DispersiveShallowWater.rhs_split_nonstiff!(dq, q, $semi, t) + @test (a1 + a2) < $allocs + end +end + macro test_nowarn_mod(expr, additional_ignore_content = []) quote add_to_additional_ignore_content = [