diff --git a/src/mprk.jl b/src/mprk.jl index 9434cec6..b9855cdc 100644 --- a/src/mprk.jl +++ b/src/mprk.jl @@ -9,6 +9,8 @@ end p_prototype(u, f) = zeros(eltype(u), length(u), length(u)) p_prototype(u, f::ConservativePDSFunction) = zero(f.p_prototype) p_prototype(u, f::PDSFunction) = zero(f.p_prototype) +d_prototype(u, f) = zeros(eltype(u), length(u)) +d_prototype(u, f::PDSFunction) = zero(f.d_prototype) ##################################################################### # out-of-place for dense and static arrays @@ -218,9 +220,9 @@ end integrator.u = u end -struct MPECache{PType, uType, tabType, F} <: OrdinaryDiffEqMutableCache +struct MPECache{PType, DType, uType, tabType, F} <: OrdinaryDiffEqMutableCache P::PType - D::uType + D::DType σ::uType tab::tabType linsolve_rhs::uType # stores rhs of linear system @@ -253,6 +255,7 @@ function alg_cache(alg::MPE, u, rate_prototype, ::Type{uEltypeNoUnits}, MPEConservativeCache(P, σ, tab, linsolve) elseif f isa PDSFunction + D = d_prototype(u, f) linsolve_rhs = zero(u) # We use P to store the evaluation of the PDS # as well as to store the system matrix of the linear system @@ -260,7 +263,7 @@ function alg_cache(alg::MPE, u, rate_prototype, ::Type{uEltypeNoUnits}, linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true, assumptions = LinearSolve.OperatorAssumptions(true)) - MPECache(P, zero(u), σ, tab, linsolve_rhs, linsolve) + MPECache(P, D, σ, tab, linsolve_rhs, linsolve) else throw(ArgumentError("MPE can only be applied to production-destruction systems")) end @@ -513,13 +516,13 @@ end integrator.u = u end -struct MPRK22Cache{uType, PType, tabType, F} <: +struct MPRK22Cache{uType, PType, DType, tabType, F} <: OrdinaryDiffEqMutableCache tmp::uType P::PType P2::PType - D::uType - D2::uType + D::DType + D2::DType σ::uType tab::tabType linsolve::F @@ -562,14 +565,13 @@ function alg_cache(alg::MPRK22, u, rate_prototype, ::Type{uEltypeNoUnits}, tab, #MPRK22ConstantCache linsolve) elseif f isa PDSFunction + D = d_prototype(u, f) + D2 = d_prototype(u, f) linprob = LinearProblem(P2, _vec(tmp)) linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true, assumptions = LinearSolve.OperatorAssumptions(true)) - MPRK22Cache(tmp, P, P2, - zero(u), # D - zero(u), # D2 - σ, + MPRK22Cache(tmp, P, P2, D, D2, σ, tab, #MPRK22ConstantCache linsolve) else @@ -1054,15 +1056,15 @@ end integrator.u = u end -struct MPRK43Cache{uType, PType, tabType, F} <: OrdinaryDiffEqMutableCache +struct MPRK43Cache{uType, PType, DType, tabType, F} <: OrdinaryDiffEqMutableCache tmp::uType tmp2::uType P::PType P2::PType P3::PType - D::uType - D2::uType - D3::uType + D::DType + D2::DType + D3::DType σ::uType tab::tabType linsolve::F @@ -1107,9 +1109,9 @@ function alg_cache(alg::Union{MPRK43I, MPRK43II}, u, rate_prototype, ::Type{uElt assumptions = LinearSolve.OperatorAssumptions(true)) MPRK43ConservativeCache(tmp, tmp2, P, P2, P3, σ, tab, linsolve) elseif f isa PDSFunction - D = zero(u) - D2 = zero(u) - D3 = zero(u) + D = d_prototype(u, f) + D2 = d_prototype(u, f) + D3 = d_prototype(u, f) linprob = LinearProblem(P3, _vec(tmp)) linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true, diff --git a/src/proddest.jl b/src/proddest.jl index 01bb76c3..611c0706 100644 --- a/src/proddest.jl +++ b/src/proddest.jl @@ -20,10 +20,10 @@ The functions `P` and `D` can be used either in the out-of-place form with signa ### Keyword arguments: ### -- `p_prototype`: If `P` is given in in-place form, `p_prototype` is used to store evaluations of `P`. +- `p_prototype`: If `P` is given in in-place form, `p_prototype` or copies thereof are used to store evaluations of `P`. If `p_prototype` is not specified explicitly and `P` is in-place, then `p_prototype` will be internally set to `zeros(eltype(u0), (length(u0), length(u0)))`. -- `d_prototype`: If `D` is given in in-place form, `d_prototype` is used to store evaluations of `D`. +- `d_prototype`: If `D` is given in in-place form, `d_prototype` or copies thereof are used to store evaluations of `D`. If `d_prototype` is not specified explicitly and `D` is in-place, then `d_prototype` will be internally set to `zeros(eltype(u0), (length(u0),))`. @@ -177,7 +177,7 @@ The function `P` can be given either in the out-of-place form with signature ### Keyword arguments: ### -- `p_prototype`: If `P` is given in in-place form, `p_prototype` is used to store evaluations of `P`. +- `p_prototype`: If `P` is given in in-place form, `p_prototype` or copies thereof are used to store evaluations of `P`. If `p_prototype` is not specified explicitly and `P` is in-place, then `p_prototype` will be internally set to `zeros(eltype(u0), (length(u0), length(u0)))`. - `analytic`: The analytic solution of a PDS must be given in the form `f(u0,p,t)`. diff --git a/src/sspmprk.jl b/src/sspmprk.jl index bd871d33..8ae7d0ae 100644 --- a/src/sspmprk.jl +++ b/src/sspmprk.jl @@ -201,13 +201,13 @@ end integrator.u = u end -struct SSPMPRK22Cache{uType, PType, tabType, F} <: +struct SSPMPRK22Cache{uType, PType, DType, tabType, F} <: OrdinaryDiffEqMutableCache tmp::uType P::PType P2::PType - D::uType - D2::uType + D::DType + D2::DType σ::uType tab::tabType linsolve::F @@ -251,10 +251,9 @@ function alg_cache(alg::SSPMPRK22, u, rate_prototype, ::Type{uEltypeNoUnits}, linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true, assumptions = LinearSolve.OperatorAssumptions(true)) - SSPMPRK22Cache(tmp, P, P2, - zero(u), # D - zero(u), # D2 - σ, + D = d_prototype(u, f) + D2 = d_prototype(u, f) + SSPMPRK22Cache(tmp, P, P2, D, D2, σ, tab, #MPRK22ConstantCache linsolve) else @@ -671,15 +670,15 @@ end integrator.u = u end -struct SSPMPRK43Cache{uType, PType, tabType, F} <: OrdinaryDiffEqMutableCache +struct SSPMPRK43Cache{uType, PType, DType, tabType, F} <: OrdinaryDiffEqMutableCache tmp::uType tmp2::uType P::PType P2::PType P3::PType - D::uType - D2::uType - D3::uType + D::DType + D2::DType + D3::DType σ::uType ρ::uType tab::tabType @@ -724,9 +723,9 @@ function alg_cache(alg::SSPMPRK43, u, rate_prototype, ::Type{uEltypeNoUnits}, assumptions = LinearSolve.OperatorAssumptions(true)) SSPMPRK43ConservativeCache(tmp, tmp2, P, P2, P3, σ, ρ, tab, linsolve) elseif f isa PDSFunction - D = zero(u) - D2 = zero(u) - D3 = zero(u) + D = d_prototype(u, f) + D2 = d_prototype(u, f) + D3 = d_prototype(u, f) linprob = LinearProblem(P3, _vec(tmp)) linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true, diff --git a/test/runtests.jl b/test/runtests.jl index 3b14f06a..7694063d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1263,6 +1263,84 @@ end end end + # Here we check that the types of p_prototype and d_prototype actually + # define the types of the Ps and Ds inside the algorithm caches. + # We test sparse, tridiagonal and dense matrices as well as sparse and + # dense vectors + @testset "Prototype type check" begin + #prod and dest functions + prod_inner! = (P, u, p, t) -> begin + fill!(P, zero(eltype(P))) + for i in 1:(length(u) - 1) + P[i, i + 1] = i * u[i] + end + return nothing + end + prod_sparse! = (P, u, p, t) -> begin + @test P isa SparseMatrixCSC + prod_inner!(P, u, p, t) + return nothing + end + prod_tridiagonal! = (P, u, p, t) -> begin + @test P isa Tridiagonal + prod_inner!(P, u, p, t) + return nothing + end + prod_dense! = (P, u, p, t) -> begin + @test P isa Matrix + prod_inner!(P, u, p, t) + return nothing + end + dest_sparse! = (D, u, p, t) -> begin + @test D isa SparseVector + fill!(D, zero(eltype(D))) + end + dest_dense! = (D, u, p, t) -> begin + @test D isa Vector + fill!(D, zero(eltype(D))) + end + #prototypes + P_tridiagonal = Tridiagonal([0.1, 0.2, 0.3], + [0.0, 0.0, 0.0, 0.0], + [0.4, 0.5, 0.6]) + P_dense = Matrix(P_tridiagonal) + P_sparse = sparse(P_tridiagonal) + D_sparse = spzeros(4) + D_dense = Vector(D_sparse) + # problem definition + u0 = [1.0, 1.5, 2.0, 2.5] + tspan = (0.0, 1.0) + dt = 0.5 + ## conservative PDS + prob_default = ConservativePDSProblem(prod_dense!, u0, tspan) + prob_tridiagonal = ConservativePDSProblem(prod_tridiagonal!, u0, tspan; + p_prototype = P_tridiagonal) + prob_dense = ConservativePDSProblem(prod_dense!, u0, tspan; + p_prototype = P_dense) + prob_sparse = ConservativePDSProblem(prod_sparse!, u0, tspan; + p_prototype = P_sparse) + ## nonconservative PDS + prob_default2 = PDSProblem(prod_dense!, dest_dense!, u0, tspan) + prob_tridiagonal2 = PDSProblem(prod_tridiagonal!, dest_dense!, u0, tspan; + p_prototype = P_tridiagonal) + prob_dense2 = PDSProblem(prod_dense!, dest_dense!, u0, tspan; + p_prototype = P_dense, + d_prototype = D_dense) + prob_sparse2 = PDSProblem(prod_sparse!, dest_sparse!, u0, tspan; + p_prototype = P_sparse, + d_prototype = D_sparse) + for alg in (MPE(), MPRK22(0.5), MPRK22(1.0), MPRK43I(1.0, 0.5), + MPRK43I(0.5, 0.75), + MPRK43II(2.0 / 3.0), MPRK43II(0.5), SSPMPRK22(0.5, 1.0), + SSPMPRK43()) + for prob in (prob_default, prob_tridiagonal, prob_dense, prob_sparse, + prob_default2, + prob_tridiagonal2, prob_dense2, prob_sparse2) + solve(prob, alg; dt, adaptive = false) + end + end + end + # Here we check the convergence order of pth-order schemes for which # no interpolation of order p is available @testset "Convergence tests (conservative)" begin