Skip to content

Commit 92b7e71

Browse files
ranochaSKopecz
andauthored
add possibility to call a direct implementation of the ODE RHS (#117)
* improve docstrings * add possibility to call a direct implementation of the ODE RHS * add tests * format * Apply suggestions from code review Co-authored-by: Stefan Kopecz <[email protected]> * more tests as suggested --------- Co-authored-by: Stefan Kopecz <[email protected]>
1 parent b91a7ba commit 92b7e71

File tree

3 files changed

+211
-38
lines changed

3 files changed

+211
-38
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "PositiveIntegrators"
22
uuid = "d1b20bf0-b083-4985-a874-dc5121669aa5"
33
authors = ["Stefan Kopecz, Hendrik Ranocha, and contributors"]
4-
version = "0.2.2"
4+
version = "0.2.3"
55

66
[deps]
77
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"

src/proddest.jl

Lines changed: 104 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@ abstract type AbstractPDSProblem end
33

44
"""
55
PDSProblem(P, D, u0, tspan, p = NullParameters();
6-
p_prototype = nothing,
7-
analytic = nothing)
6+
p_prototype = nothing,
7+
analytic = nothing,
8+
std_rhs = nothing)
89
910
A structure describing a system of ordinary differential equations in form of a production-destruction system (PDS).
10-
`P` denotes the production matrix.
11-
The diagonal of `P` contains production terms without destruction counterparts.
12-
`D` is the vector of destruction terms without production counterparts.
11+
`P` denotes the function defining the production matrix ``P``.
12+
The diagonal of ``P`` contains production terms without destruction counterparts.
13+
`D` is the function defining the vector of destruction terms ``D`` without production counterparts.
1314
`u0` is the vector of initial conditions and `tspan` the time span
1415
`(t_initial, t_final)` of the problem. The optional argument `p` can be used
1516
to pass additional parameters to the functions `P` and `D`.
@@ -20,10 +21,16 @@ The functions `P` and `D` can be used either in the out-of-place form with signa
2021
### Keyword arguments: ###
2122
2223
- `p_prototype`: If `P` is given in in-place form, `p_prototype` or copies thereof are used to store evaluations of `P`.
23-
If `p_prototype` is not specified explicitly and `P` is in-place, then `p_prototype` will be internally
24+
If `p_prototype` is not specified explicitly and `P` is in-place, then `p_prototype` will be internally
2425
set to `zeros(eltype(u0), (length(u0), length(u0)))`.
2526
- `analytic`: The analytic solution of a PDS must be given in the form `f(u0,p,t)`.
26-
Specifying the analytic solution can be useful for plotting and convergence tests.
27+
Specifying the analytic solution can be useful for plotting and convergence tests.
28+
- `std_rhs`: The standard ODE right-hand side evaluation function callable
29+
as `du = std_rhs(u, p, t)` for the out-of-place form and
30+
as `std_rhs(du, u, p, t)` for the in-place form. Solvers that do not rely on
31+
the production-destruction representation of the ODE, will use this function
32+
instead to compute the solution. If not specified,
33+
a default implementation calling `P` and `D` is used.
2734
2835
## References
2936
@@ -36,12 +43,13 @@ The functions `P` and `D` can be used either in the out-of-place form with signa
3643
struct PDSProblem{iip} <: AbstractPDSProblem end
3744

3845
# New ODE function PDSFunction
39-
struct PDSFunction{iip, specialize, P, D, PrototypeP, PrototypeD, Ta} <:
46+
struct PDSFunction{iip, specialize, P, D, PrototypeP, PrototypeD, StdRHS, Ta} <:
4047
AbstractODEFunction{iip}
4148
p::P
4249
d::D
4350
p_prototype::PrototypeP
4451
d_prototype::PrototypeD
52+
std_rhs::StdRHS
4553
analytic::Ta
4654
end
4755

@@ -82,18 +90,19 @@ end
8290
function PDSProblem{iip}(P, D, u0, tspan, p = NullParameters();
8391
p_prototype = nothing,
8492
analytic = nothing,
93+
std_rhs = nothing,
8594
kwargs...) where {iip}
8695

8796
# p_prototype is used to store evaluations of P, if P is in-place.
8897
if isnothing(p_prototype) && iip
8998
p_prototype = zeros(eltype(u0), (length(u0), length(u0)))
9099
end
91-
# If a PDSFunction is to be evaluated and D is in-place, then d_prototype is used to store
100+
# If a PDSFunction is to be evaluated and D is in-place, then d_prototype is used to store
92101
# evaluations of D.
93102
d_prototype = similar(u0)
94103

95-
PD = PDSFunction{iip}(P, D; p_prototype = p_prototype, d_prototype = d_prototype,
96-
analytic = analytic)
104+
PD = PDSFunction{iip}(P, D; p_prototype, d_prototype,
105+
analytic, std_rhs)
97106
PDSProblem{iip}(PD, u0, tspan, p; kwargs...)
98107
end
99108

@@ -112,21 +121,45 @@ end
112121
# Most specific constructor for PDSFunction
113122
function PDSFunction{iip, FullSpecialize}(P, D;
114123
p_prototype = nothing,
115-
d_prototype = nothing,
116-
analytic = nothing) where {iip}
124+
d_prototype,
125+
analytic = nothing,
126+
std_rhs = nothing) where {iip}
127+
if std_rhs === nothing
128+
std_rhs = PDSStdRHS(P, D, p_prototype, d_prototype)
129+
end
117130
PDSFunction{iip, FullSpecialize, typeof(P), typeof(D), typeof(p_prototype),
118131
typeof(d_prototype),
119-
typeof(analytic)}(P, D, p_prototype, d_prototype, analytic)
132+
typeof(std_rhs), typeof(analytic)}(P, D, p_prototype, d_prototype, std_rhs,
133+
analytic)
120134
end
121135

122-
# Evaluation of a PDSFunction (out-of-place)
136+
# Evaluation of a PDSFunction
123137
function (PD::PDSFunction)(u, p, t)
124-
diag(PD.p(u, p, t)) + vec(sum(PD.p(u, p, t), dims = 2)) -
125-
vec(sum(PD.p(u, p, t), dims = 1)) - vec(PD.d(u, p, t))
138+
return PD.std_rhs(u, p, t)
126139
end
127140

128-
# Evaluation of a PDSFunction (in-place)
129141
function (PD::PDSFunction)(du, u, p, t)
142+
return PD.std_rhs(du, u, p, t)
143+
end
144+
145+
# Default implementation of the standard right-hand side evaluation function
146+
struct PDSStdRHS{P, D, PrototypeP, PrototypeD} <: Function
147+
p::P
148+
d::D
149+
p_prototype::PrototypeP
150+
d_prototype::PrototypeD
151+
end
152+
153+
# Evaluation of a PDSStdRHS (out-of-place)
154+
function (PD::PDSStdRHS)(u, p, t)
155+
P = PD.p(u, p, t)
156+
D = PD.d(u, p, t)
157+
diag(P) + vec(sum(P, dims = 2)) -
158+
vec(sum(P, dims = 1)) - vec(D)
159+
end
160+
161+
# Evaluation of a PDSStdRHS (in-place)
162+
function (PD::PDSStdRHS)(du, u, p, t)
130163
PD.p(PD.p_prototype, u, p, t)
131164

132165
if PD.p_prototype isa AbstractSparseMatrix
@@ -157,24 +190,32 @@ end
157190
"""
158191
ConservativePDSProblem(P, u0, tspan, p = NullParameters();
159192
p_prototype = nothing,
160-
analytic = nothing)
193+
analytic = nothing,
194+
std_rhs = nothing)
161195
162196
A structure describing a conservative system of ordinary differential equation in form of a production-destruction system (PDS).
163-
`P` denotes the production matrix.
197+
`P` denotes the function defining the production matrix ``P``.
198+
The diagonal of ``P`` contains production terms without destruction counterparts.
164199
`u0` is the vector of initial conditions and `tspan` the time span
165200
`(t_initial, t_final)` of the problem. The optional argument `p` can be used
166-
to pass additional parameters to the function P.
201+
to pass additional parameters to the function `P`.
167202
168203
The function `P` can be given either in the out-of-place form with signature
169204
`production_terms = P(u, p, t)` or the in-place form `P(production_terms, u, p, t)`.
170205
171206
### Keyword arguments: ###
172207
173208
- `p_prototype`: If `P` is given in in-place form, `p_prototype` or copies thereof are used to store evaluations of `P`.
174-
If `p_prototype` is not specified explicitly and `P` is in-place, then `p_prototype` will be internally
209+
If `p_prototype` is not specified explicitly and `P` is in-place, then `p_prototype` will be internally
175210
set to `zeros(eltype(u0), (length(u0), length(u0)))`.
176211
- `analytic`: The analytic solution of a PDS must be given in the form `f(u0,p,t)`.
177-
Specifying the analytic solution can be useful for plotting and convergence tests.
212+
Specifying the analytic solution can be useful for plotting and convergence tests.
213+
- `std_rhs`: The standard ODE right-hand side evaluation function callable
214+
as `du = std_rhs(u, p, t)` for the out-of-place form and
215+
as `std_rhs(du, u, p, t)` for the in-place form. Solvers that do not rely on
216+
the production-destruction representation of the ODE, will use this function
217+
instead to compute the solution. If not specified,
218+
a default implementation calling `P` is used.
178219
179220
## References
180221
@@ -187,12 +228,12 @@ The function `P` can be given either in the out-of-place form with signature
187228
struct ConservativePDSProblem{iip} <: AbstractPDSProblem end
188229

189230
# New ODE function ConservativePDSFunction
190-
struct ConservativePDSFunction{iip, specialize, P, PrototypeP, TMP, Ta} <:
231+
struct ConservativePDSFunction{iip, specialize, P, PrototypeP, StdRHS, Ta} <:
191232
AbstractODEFunction{iip}
192-
p::P
193-
p_prototype::PrototypeP
194-
tmp::TMP
195-
analytic::Ta
233+
p::P # production terms
234+
p_prototype::PrototypeP # prototype for production terms
235+
std_rhs::StdRHS # standard right-hand side evaluation function
236+
analytic::Ta # analytic solution (or nothing)
196237
end
197238

198239
# define behavior of ConservativePDSFunction for non-existing fields
@@ -226,14 +267,15 @@ end
226267
function ConservativePDSProblem{iip}(P, u0, tspan, p = NullParameters();
227268
p_prototype = nothing,
228269
analytic = nothing,
270+
std_rhs = nothing,
229271
kwargs...) where {iip}
230272

231273
# p_prototype is used to store evaluations of P, if P is in-place.
232274
if isnothing(p_prototype) && iip
233275
p_prototype = zeros(eltype(u0), (length(u0), length(u0)))
234276
end
235277

236-
PD = ConservativePDSFunction{iip}(P; p_prototype = p_prototype, analytic = analytic)
278+
PD = ConservativePDSFunction{iip}(P; p_prototype, analytic, std_rhs)
237279
ConservativePDSProblem{iip}(PD, u0, tspan, p; kwargs...)
238280
end
239281

@@ -252,18 +294,43 @@ end
252294
# Most specific constructor for ConservativePDSFunction
253295
function ConservativePDSFunction{iip, FullSpecialize}(P;
254296
p_prototype = nothing,
255-
analytic = nothing) where {iip}
297+
analytic = nothing,
298+
std_rhs = nothing) where {iip}
299+
if std_rhs === nothing
300+
std_rhs = ConservativePDSStdRHS(P, p_prototype)
301+
end
302+
ConservativePDSFunction{iip, FullSpecialize, typeof(P), typeof(p_prototype),
303+
typeof(std_rhs), typeof(analytic)}(P, p_prototype, std_rhs,
304+
analytic)
305+
end
306+
307+
# Evaluation of a ConservativePDSFunction
308+
function (PD::ConservativePDSFunction)(u, p, t)
309+
return PD.std_rhs(u, p, t)
310+
end
311+
312+
function (PD::ConservativePDSFunction)(du, u, p, t)
313+
return PD.std_rhs(du, u, p, t)
314+
end
315+
316+
# Default implementation of the standard right-hand side evaluation function
317+
struct ConservativePDSStdRHS{P, PrototypeP, TMP} <: Function
318+
p::P
319+
p_prototype::PrototypeP
320+
tmp::TMP
321+
end
322+
323+
function ConservativePDSStdRHS(P, p_prototype)
256324
if p_prototype isa AbstractSparseMatrix
257325
tmp = zeros(eltype(p_prototype), (size(p_prototype, 1),))
258326
else
259327
tmp = nothing
260328
end
261-
ConservativePDSFunction{iip, FullSpecialize, typeof(P), typeof(p_prototype),
262-
typeof(tmp), typeof(analytic)}(P, p_prototype, tmp, analytic)
329+
ConservativePDSStdRHS(P, p_prototype, tmp)
263330
end
264331

265-
# Evaluation of a ConservativePDSFunction (out-of-place)
266-
function (PD::ConservativePDSFunction)(u, p, t)
332+
# Evaluation of a ConservativePDSStdRHS (out-of-place)
333+
function (PD::ConservativePDSStdRHS)(u, p, t)
267334
#vec(sum(PD.p(u, p, t), dims = 2)) - vec(sum(PD.p(u, p, t), dims = 1))
268335
P = PD.p(u, p, t)
269336

@@ -277,7 +344,7 @@ function (PD::ConservativePDSFunction)(u, p, t)
277344
return f
278345
end
279346

280-
function (PD::ConservativePDSFunction)(u::SVector, p, t)
347+
function (PD::ConservativePDSStdRHS)(u::SVector, p, t)
281348
P = PD.p(u, p, t)
282349

283350
f = similar(u) #constructs MVector
@@ -296,8 +363,8 @@ function (PD::ConservativePDSFunction)(u::SVector, p, t)
296363
return SVector(f)
297364
end
298365

299-
# Evaluation of a ConservativePDSFunction (in-place)
300-
function (PD::ConservativePDSFunction)(du, u, p, t)
366+
# Evaluation of a ConservativePDSStdRHS (in-place)
367+
function (PD::ConservativePDSStdRHS)(du, u, p, t)
301368
PD.p(PD.p_prototype, u, p, t)
302369
sum_terms!(du, PD.tmp, PD.p_prototype)
303370
return nothing

test/runtests.jl

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,112 @@ end
216216
@test isnothing(check_no_stale_explicit_imports(PositiveIntegrators))
217217
end
218218

219+
@testset "ODE RHS" begin
220+
let counter_p = Ref(1), counter_d = Ref(1), counter_rhs = Ref(1)
221+
# out-of-place
222+
prod1 = (u, p, t) -> begin
223+
counter_p[] += 1
224+
return [0 u[2]; u[1] 0]
225+
end
226+
dest1 = (u, p, t) -> begin
227+
counter_d[] += 1
228+
return zero(u)
229+
end
230+
rhs1 = (u, p, t) -> begin
231+
counter_rhs[] += 1
232+
return [-u[1] + u[2], u[1] - u[2]]
233+
end
234+
u0 = [1.0, 0.0]
235+
tspan = (0.0, 1.0)
236+
prob_default = PDSProblem(prod1, dest1, u0, tspan)
237+
prob_special = PDSProblem(prod1, dest1, u0, tspan; std_rhs = rhs1)
238+
239+
counter_p[] = 0
240+
counter_d[] = 0
241+
counter_rhs[] = 0
242+
@inferred prob_default.f(u0, nothing, 0.0)
243+
@test counter_p[] == 1
244+
@test counter_d[] == 1
245+
@test counter_rhs[] == 0
246+
247+
counter_p[] = 0
248+
counter_d[] = 0
249+
counter_rhs[] = 0
250+
@inferred prob_special.f(u0, nothing, 0.0)
251+
@test counter_p[] == 0
252+
@test counter_d[] == 0
253+
@test counter_rhs[] == 1
254+
255+
# in-place
256+
prod1! = (P, u, p, t) -> begin
257+
counter_p[] += 1
258+
P[1, 1] = 0
259+
P[1, 2] = u[2]
260+
P[2, 1] = u[1]
261+
P[2, 2] = 0
262+
return nothing
263+
end
264+
dest1! = (D, u, p, t) -> begin
265+
counter_d[] += 1
266+
fill!(D, 0)
267+
return nothing
268+
end
269+
rhs1! = (du, u, p, t) -> begin
270+
counter_rhs[] += 1
271+
du[1] = -u[1] + u[2]
272+
du[2] = u[1] - u[2]
273+
return nothing
274+
end
275+
u0 = [1.0, 0.0]
276+
tspan = (0.0, 1.0)
277+
prob_default = PDSProblem(prod1!, dest1!, u0, tspan)
278+
prob_special = PDSProblem(prod1!, dest1!, u0, tspan; std_rhs = rhs1!)
279+
280+
du = similar(u0)
281+
counter_p[] = 0
282+
counter_d[] = 0
283+
counter_rhs[] = 0
284+
@inferred prob_default.f(du, u0, nothing, 0.0)
285+
@test counter_p[] == 1
286+
@test counter_d[] == 1
287+
@test counter_rhs[] == 0
288+
289+
counter_p[] = 0
290+
counter_d[] = 0
291+
counter_rhs[] = 0
292+
@inferred prob_special.f(du, u0, nothing, 0.0)
293+
@test counter_p[] == 0
294+
@test counter_d[] == 0
295+
@test counter_rhs[] == 1
296+
297+
counter_p[] = 0
298+
counter_d[] = 0
299+
counter_rhs[] = 0
300+
@inferred solve(prob_default, MPE(); dt = 0.1)
301+
@test 10 <= counter_p[] <= 11
302+
@test 10 <= counter_d[] <= 11
303+
@test counter_d[] == counter_p[]
304+
@test counter_rhs[] == 0
305+
306+
counter_p[] = 0
307+
counter_d[] = 0
308+
counter_rhs[] = 0
309+
@inferred solve(prob_default, Euler(); dt = 0.1)
310+
@test 10 <= counter_p[] <= 11
311+
@test 10 <= counter_d[] <= 11
312+
@test counter_d[] == counter_p[]
313+
@test counter_rhs[] == 0
314+
315+
counter_p[] = 0
316+
counter_d[] = 0
317+
counter_rhs[] = 0
318+
@inferred solve(prob_special, Euler(); dt = 0.1)
319+
@test counter_p[] == 0
320+
@test counter_d[] == 0
321+
@test 10 <= counter_rhs[] <= 11
322+
end
323+
end
324+
219325
@testset "ConservativePDSFunction" begin
220326
prod_1! = (P, u, p, t) -> begin
221327
fill!(P, zero(eltype(P)))

0 commit comments

Comments
 (0)