Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 19 additions & 17 deletions src/mprk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ end
p_prototype(u, f) = zeros(eltype(u), length(u), length(u))
p_prototype(u, f::ConservativePDSFunction) = zero(f.p_prototype)
p_prototype(u, f::PDSFunction) = zero(f.p_prototype)
d_prototype(u, f) = zeros(eltype(u), length(u))
d_prototype(u, f::PDSFunction) = zero(f.d_prototype)

#####################################################################
# out-of-place for dense and static arrays
Expand Down Expand Up @@ -218,9 +220,9 @@ end
integrator.u = u
end

struct MPECache{PType, uType, tabType, F} <: OrdinaryDiffEqMutableCache
struct MPECache{PType, DType, uType, tabType, F} <: OrdinaryDiffEqMutableCache
P::PType
D::uType
D::DType
σ::uType
tab::tabType
linsolve_rhs::uType # stores rhs of linear system
Expand Down Expand Up @@ -253,14 +255,15 @@ function alg_cache(alg::MPE, u, rate_prototype, ::Type{uEltypeNoUnits},

MPEConservativeCache(P, σ, tab, linsolve)
elseif f isa PDSFunction
D = d_prototype(u, f)
linsolve_rhs = zero(u)
# We use P to store the evaluation of the PDS
# as well as to store the system matrix of the linear system
linprob = LinearProblem(P, _vec(linsolve_rhs))
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
assumptions = LinearSolve.OperatorAssumptions(true))

MPECache(P, zero(u), σ, tab, linsolve_rhs, linsolve)
MPECache(P, D, σ, tab, linsolve_rhs, linsolve)
else
throw(ArgumentError("MPE can only be applied to production-destruction systems"))
end
Expand Down Expand Up @@ -513,13 +516,13 @@ end
integrator.u = u
end

struct MPRK22Cache{uType, PType, tabType, F} <:
struct MPRK22Cache{uType, PType, DType, tabType, F} <:
OrdinaryDiffEqMutableCache
tmp::uType
P::PType
P2::PType
D::uType
D2::uType
D::DType
D2::DType
σ::uType
tab::tabType
linsolve::F
Expand Down Expand Up @@ -562,14 +565,13 @@ function alg_cache(alg::MPRK22, u, rate_prototype, ::Type{uEltypeNoUnits},
tab, #MPRK22ConstantCache
linsolve)
elseif f isa PDSFunction
D = d_prototype(u, f)
D2 = d_prototype(u, f)
linprob = LinearProblem(P2, _vec(tmp))
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
assumptions = LinearSolve.OperatorAssumptions(true))

MPRK22Cache(tmp, P, P2,
zero(u), # D
zero(u), # D2
σ,
MPRK22Cache(tmp, P, P2, D, D2, σ,
tab, #MPRK22ConstantCache
linsolve)
else
Expand Down Expand Up @@ -1054,15 +1056,15 @@ end
integrator.u = u
end

struct MPRK43Cache{uType, PType, tabType, F} <: OrdinaryDiffEqMutableCache
struct MPRK43Cache{uType, PType, DType, tabType, F} <: OrdinaryDiffEqMutableCache
tmp::uType
tmp2::uType
P::PType
P2::PType
P3::PType
D::uType
D2::uType
D3::uType
D::DType
D2::DType
D3::DType
σ::uType
tab::tabType
linsolve::F
Expand Down Expand Up @@ -1107,9 +1109,9 @@ function alg_cache(alg::Union{MPRK43I, MPRK43II}, u, rate_prototype, ::Type{uElt
assumptions = LinearSolve.OperatorAssumptions(true))
MPRK43ConservativeCache(tmp, tmp2, P, P2, P3, σ, tab, linsolve)
elseif f isa PDSFunction
D = zero(u)
D2 = zero(u)
D3 = zero(u)
D = d_prototype(u, f)
D2 = d_prototype(u, f)
D3 = d_prototype(u, f)

linprob = LinearProblem(P3, _vec(tmp))
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
Expand Down
6 changes: 3 additions & 3 deletions src/proddest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ The functions `P` and `D` can be used either in the out-of-place form with signa

### Keyword arguments: ###

- `p_prototype`: If `P` is given in in-place form, `p_prototype` is used to store evaluations of `P`.
- `p_prototype`: If `P` is given in in-place form, `p_prototype` or copies thereof are used to store evaluations of `P`.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should definitely merge these documentation clarifications!

If `p_prototype` is not specified explicitly and `P` is in-place, then `p_prototype` will be internally
set to `zeros(eltype(u0), (length(u0), length(u0)))`.
- `d_prototype`: If `D` is given in in-place form, `d_prototype` is used to store evaluations of `D`.
- `d_prototype`: If `D` is given in in-place form, `d_prototype` or copies thereof are used to store evaluations of `D`.
If `d_prototype` is not specified explicitly and `D` is in-place, then `d_prototype` will be internally
set to `zeros(eltype(u0), (length(u0),))`.

Expand Down Expand Up @@ -177,7 +177,7 @@ The function `P` can be given either in the out-of-place form with signature

### Keyword arguments: ###

- `p_prototype`: If `P` is given in in-place form, `p_prototype` is used to store evaluations of `P`.
- `p_prototype`: If `P` is given in in-place form, `p_prototype` or copies thereof are used to store evaluations of `P`.
If `p_prototype` is not specified explicitly and `P` is in-place, then `p_prototype` will be internally
set to `zeros(eltype(u0), (length(u0), length(u0)))`.
- `analytic`: The analytic solution of a PDS must be given in the form `f(u0,p,t)`.
Expand Down
27 changes: 13 additions & 14 deletions src/sspmprk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -201,13 +201,13 @@ end
integrator.u = u
end

struct SSPMPRK22Cache{uType, PType, tabType, F} <:
struct SSPMPRK22Cache{uType, PType, DType, tabType, F} <:
OrdinaryDiffEqMutableCache
tmp::uType
P::PType
P2::PType
D::uType
D2::uType
D::DType
D2::DType
σ::uType
tab::tabType
linsolve::F
Expand Down Expand Up @@ -251,10 +251,9 @@ function alg_cache(alg::SSPMPRK22, u, rate_prototype, ::Type{uEltypeNoUnits},
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
assumptions = LinearSolve.OperatorAssumptions(true))

SSPMPRK22Cache(tmp, P, P2,
zero(u), # D
zero(u), # D2
σ,
D = d_prototype(u, f)
D2 = d_prototype(u, f)
SSPMPRK22Cache(tmp, P, P2, D, D2, σ,
tab, #MPRK22ConstantCache
linsolve)
else
Expand Down Expand Up @@ -671,15 +670,15 @@ end
integrator.u = u
end

struct SSPMPRK43Cache{uType, PType, tabType, F} <: OrdinaryDiffEqMutableCache
struct SSPMPRK43Cache{uType, PType, DType, tabType, F} <: OrdinaryDiffEqMutableCache
tmp::uType
tmp2::uType
P::PType
P2::PType
P3::PType
D::uType
D2::uType
D3::uType
D::DType
D2::DType
D3::DType
σ::uType
ρ::uType
tab::tabType
Expand Down Expand Up @@ -724,9 +723,9 @@ function alg_cache(alg::SSPMPRK43, u, rate_prototype, ::Type{uEltypeNoUnits},
assumptions = LinearSolve.OperatorAssumptions(true))
SSPMPRK43ConservativeCache(tmp, tmp2, P, P2, P3, σ, ρ, tab, linsolve)
elseif f isa PDSFunction
D = zero(u)
D2 = zero(u)
D3 = zero(u)
D = d_prototype(u, f)
D2 = d_prototype(u, f)
D3 = d_prototype(u, f)

linprob = LinearProblem(P3, _vec(tmp))
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
Expand Down
78 changes: 78 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1263,6 +1263,84 @@ end
end
end

# Here we check that the types of p_prototype and d_prototype actually
Copy link
Collaborator Author

@SKopecz SKopecz Jul 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we should also keep this test (without the d_prototype parts) ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes 👍

# define the types of the Ps and Ds inside the algorithm caches.
# We test sparse, tridiagonal and dense matrices as well as sparse and
# dense vectors
@testset "Prototype type check" begin
#prod and dest functions
prod_inner! = (P, u, p, t) -> begin
fill!(P, zero(eltype(P)))
for i in 1:(length(u) - 1)
P[i, i + 1] = i * u[i]
end
return nothing
end
prod_sparse! = (P, u, p, t) -> begin
@test P isa SparseMatrixCSC
prod_inner!(P, u, p, t)
return nothing
end
prod_tridiagonal! = (P, u, p, t) -> begin
@test P isa Tridiagonal
prod_inner!(P, u, p, t)
return nothing
end
prod_dense! = (P, u, p, t) -> begin
@test P isa Matrix
prod_inner!(P, u, p, t)
return nothing
end
dest_sparse! = (D, u, p, t) -> begin
@test D isa SparseVector
fill!(D, zero(eltype(D)))
end
dest_dense! = (D, u, p, t) -> begin
@test D isa Vector
fill!(D, zero(eltype(D)))
end
#prototypes
P_tridiagonal = Tridiagonal([0.1, 0.2, 0.3],
[0.0, 0.0, 0.0, 0.0],
[0.4, 0.5, 0.6])
P_dense = Matrix(P_tridiagonal)
P_sparse = sparse(P_tridiagonal)
D_sparse = spzeros(4)
D_dense = Vector(D_sparse)
# problem definition
u0 = [1.0, 1.5, 2.0, 2.5]
tspan = (0.0, 1.0)
dt = 0.5
## conservative PDS
prob_default = ConservativePDSProblem(prod_dense!, u0, tspan)
prob_tridiagonal = ConservativePDSProblem(prod_tridiagonal!, u0, tspan;
p_prototype = P_tridiagonal)
prob_dense = ConservativePDSProblem(prod_dense!, u0, tspan;
p_prototype = P_dense)
prob_sparse = ConservativePDSProblem(prod_sparse!, u0, tspan;
p_prototype = P_sparse)
## nonconservative PDS
prob_default2 = PDSProblem(prod_dense!, dest_dense!, u0, tspan)
prob_tridiagonal2 = PDSProblem(prod_tridiagonal!, dest_dense!, u0, tspan;
p_prototype = P_tridiagonal)
prob_dense2 = PDSProblem(prod_dense!, dest_dense!, u0, tspan;
p_prototype = P_dense,
d_prototype = D_dense)
prob_sparse2 = PDSProblem(prod_sparse!, dest_sparse!, u0, tspan;
p_prototype = P_sparse,
d_prototype = D_sparse)
for alg in (MPE(), MPRK22(0.5), MPRK22(1.0), MPRK43I(1.0, 0.5),
MPRK43I(0.5, 0.75),
MPRK43II(2.0 / 3.0), MPRK43II(0.5), SSPMPRK22(0.5, 1.0),
SSPMPRK43())
for prob in (prob_default, prob_tridiagonal, prob_dense, prob_sparse,
prob_default2,
prob_tridiagonal2, prob_dense2, prob_sparse2)
solve(prob, alg; dt, adaptive = false)
end
end
end

# Here we check the convergence order of pth-order schemes for which
# no interpolation of order p is available
@testset "Convergence tests (conservative)" begin
Expand Down