Skip to content

Commit e15d6e4

Browse files
committed
Fix multiterms PECE methods
1 parent 9801d8e commit e15d6e4

15 files changed

+99
-37
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
.vscode/
22
./idea
3+
.DS_Store
34

45
# Files generated by invoking Julia with --code-coverage
56
*.jl.cov

enzyme_kinetics.svg

Lines changed: 54 additions & 0 deletions
Loading

ext/FractionalDiffEqFdeSolverExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,9 @@ function SciMLBase.__solve(prob::FODEProblem, alg::FdeSolverPECE; dt = 0.0, abst
2323
end
2424
end
2525
par = p isa SciMLBase.NullParameters ? nothing : p
26+
length(u0) == 1 && (u0 = first(u0))
2627
t, y = FDEsolver(newf, tSpan, u0, order, par, JF = prob.f.jac, h = dt, tol = abstol)
27-
u = eachrow(y)
28+
u = collect(Vector{eltype(y)}, eachrow(y))
2829

2930
return DiffEqBase.build_solution(prob, alg, t, u)
3031
end

src/delay/pece.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ function SciMLBase.__init(prob::FDDEProblem, alg::DelayPECE; dt = 0.0, kwargs...
5858
N = length(t0:dt:(tfinal + dt))
5959
yp = _generate_similar_array(u0, N, u0)
6060
y = _generate_similar_array(u0, N - 1, u0)
61-
y[1] = _generate_similar_array(u0, 1, h(p, 0))
61+
y[1] = (length(u0) == 1) ? _generate_similar_array(u0, 1, h(p, 0)) : u0#_generate_similar_array(u0, 1, h(p, 0))
6262

6363
return DelayPECECache{iip, T}(prob, alg, mesh, u0, order[1], τ, p, y, yp, dt, kwargs)
6464
end
@@ -87,7 +87,7 @@ function SciMLBase.solve!(cache::DelayPECECache{iip, T}) where {iip, T}
8787
(; prob, alg, mesh, u0, order, p, dt) = cache
8888
maxn = length(mesh)
8989
l = length(u0)
90-
initial = _generate_similar_array(u0, 1, prob.h(p, 0))
90+
initial = (length(u0) == 1) ? _generate_similar_array(u0, 1, prob.h(p, 0)) : u0
9191
for n in 1:(maxn - 1)
9292
order = OrderWrapper(order, mesh[n + 1])
9393
cache.yp[n + 1] = zeros(T, l)

src/fode/bdf.jl

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,12 @@ Base.eltype(::BDFCache{iip, T}) where {iip, T} = T
3939

4040
function SciMLBase.__init(prob::FODEProblem, alg::BDF; dt = 0.0, reltol = 1e-6,
4141
abstol = 1e-6, maxiters = 1000, kwargs...)
42-
prob = _is_need_convert!(prob)
42+
prob, iip = _is_need_convert!(prob)
4343
dt 0 ? throw(ArgumentError("dt must be positive")) : nothing
4444
(; f, order, u0, tspan, p) = prob
4545
t0 = tspan[1]
4646
tfinal = tspan[2]
4747
T = eltype(u0)
48-
iip = isinplace(prob)
4948

5049
all(x -> x == order[1], order) ? nothing :
5150
throw(ArgumentError("BDF method is only for commensurate order FODE"))
@@ -65,7 +64,7 @@ function SciMLBase.__init(prob::FODEProblem, alg::BDF; dt = 0.0, reltol = 1e-6,
6564
NNr = 2^(Q + 1) * r
6665

6766
# Preallocation of some variables
68-
y = [Vector{T}(undef, problem_size) for _ in 1:(N + 1)]
67+
y = [u0 for _ in 1:(N + 1)]
6968
fy = similar(y)
7069
zn = zeros(T, problem_size, NNr + 1)
7170

@@ -98,7 +97,11 @@ function SciMLBase.__init(prob::FODEProblem, alg::BDF; dt = 0.0, reltol = 1e-6,
9897
mesh = t0 .+ collect(0:N) * dt
9998
y[1] .= high_order_prob ? u0[1, :] : u0
10099
temp = high_order_prob ? similar(u0[1, :]) : similar(u0)
101-
prob.f(temp, u0, p, t0)
100+
if iip
101+
prob.f(temp, u0, p, t0)
102+
else
103+
temp .= prob.f(u0, p, t0)
104+
end
102105
fy[1] = temp
103106

104107
return BDFCache{iip, T}(prob, alg, mesh, u0, alpha, halpha, y, fy, zn, jac, prob.p,
@@ -250,7 +253,11 @@ function BDF_first_approximations(cache::BDFCache{iip, T}) where {iip, T}
250253
F0 = similar(Y0)
251254
B0 = similar(Y0)
252255
for j in 1:s
253-
prob.f(F0.u[j], cache.y[1], p, mesh[j + 1])
256+
if iip
257+
prob.f(F0.u[j], cache.y[1], p, mesh[j + 1])
258+
else
259+
F0.u[j] .= prob.f(cache.y[1], p, mesh[j + 1])
260+
end
254261
St = ABM_starting_term(cache, mesh[j + 1])
255262
B0.u[j] = St + halpha * (omega[j + 1] + w[1, j + 1]) * cache.fy[1]
256263
end
@@ -269,7 +276,11 @@ function BDF_first_approximations(cache::BDFCache{iip, T}) where {iip, T}
269276
JF = zeros(T, s * problem_size, s * problem_size)
270277
J_temp = Matrix{T}(undef, problem_size, problem_size)
271278
for j in 1:s
272-
jac(J_temp, cache.y[1], p, mesh[j + 1])
279+
if iip
280+
jac(J_temp, cache.y[1], p, mesh[j + 1])
281+
else
282+
J_temp .= jac(cache.y[1], p, mesh[j + 1])
283+
end
273284
JF[((j - 1) * problem_size + 1):(j * problem_size), ((j - 1) * problem_size + 1):(j * problem_size)] .= J_temp
274285
end
275286
stop = false
@@ -381,15 +392,15 @@ function jacobian_of_fdefun(f, t, y, p)
381392
end
382393

383394
function _is_need_convert!(prob::FODEProblem)
384-
length(prob.u0) == 1 ? _convert_single_term_to_vectorized_prob!(prob) : prob
395+
length(prob.u0) == 1 ? (_convert_single_term_to_vectorized_prob!(prob), true) : (prob, SciMLBase.isinplace(prob))
385396
end
386397

387398
function _convert_single_term_to_vectorized_prob!(prob::FODEProblem)
388399
if SciMLBase.isinplace(prob)
389400
if isa(prob.u0, AbstractArray)
390401
new_prob = remake(prob; order = [prob.order])
391402
else
392-
new_prob = remake(prob; u0 = [prob.u0], order = [prob.order])
403+
new_prob = remake(prob; u0 = prob.u0, order = [prob.order])
393404
end
394405
return new_prob
395406
else

src/fode/explicit_pi.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ Base.eltype(::PIEXCache{iip, T}) where {iip, T} = T
3434

3535
function SciMLBase.__init(prob::FODEProblem, alg::PIEX; dt = 0.0, abstol = 1e-6, kwargs...)
3636
dt 0 ? throw(ArgumentError("dt must be positive")) : nothing
37-
prob = _is_need_convert!(prob)
37+
prob, iip = _is_need_convert!(prob)
3838
(; f, order, u0, tspan, p) = prob
3939
t0 = tspan[1]
4040
tfinal = tspan[2]

src/fode/implicit_pi_rectangle.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,11 @@ Base.eltype(::PIRectCache{iip, T}) where {iip, T} = T
3636
function SciMLBase.__init(
3737
prob::FODEProblem, alg::PIRect; dt = 0.0, abstol = 1e-6, maxiters = 1000, kwargs...)
3838
dt 0 ? throw(ArgumentError("dt must be positive")) : nothing
39-
prob = _is_need_convert!(prob)
39+
prob, iip = _is_need_convert!(prob)
4040
(; f, order, u0, tspan, p) = prob
4141
t0 = tspan[1]
4242
tfinal = tspan[2]
4343
T = eltype(u0)
44-
iip = isinplace(prob)
4544

4645
alpha_length = length(order)
4746
order = (alpha_length == 1) ? order : order[:]

src/fode/implicit_pi_trapzoid.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,11 @@ Base.eltype(::PITrapCache{iip, T}) where {iip, T} = T
3737
function SciMLBase.__init(
3838
prob::FODEProblem, alg::PITrap; dt = 0.0, abstol = 1e-6, maxiters = 1000, kwargs...)
3939
dt 0 ? throw(ArgumentError("dt must be positive")) : nothing
40-
prob = _is_need_convert!(prob)
40+
prob, iip = _is_need_convert!(prob)
4141
(; f, order, u0, tspan, p) = prob
4242
t0 = tspan[1]
4343
tfinal = tspan[2]
4444
T = eltype(u0)
45-
iip = isinplace(prob)
4645

4746
alpha_length = length(order)
4847
order = (alpha_length == 1) ? order : order[:]

src/fode/newton_gregory.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,11 @@ Base.eltype(::NewtonGregoryCache{iip, T}) where {iip, T} = T
4040
function SciMLBase.__init(prob::FODEProblem, alg::NewtonGregory; dt = 0.0,
4141
reltol = 1e-6, abstol = 1e-6, maxiters = 1000, kwargs...)
4242
dt 0 ? throw(ArgumentError("dt must be positive")) : nothing
43-
prob = _is_need_convert!(prob)
43+
prob, iip = _is_need_convert!(prob)
4444
(; f, order, u0, tspan, p) = prob
4545
t0 = tspan[1]
4646
tfinal = tspan[2]
4747
T = eltype(u0)
48-
iip = isinplace(prob)
4948
all(x -> x == order[1], order) ? nothing :
5049
throw(ArgumentError("BDF method is only for commensurate order FODE"))
5150
alpha = order[1] # commensurate ordre FODE

src/fode/nonlinearalg.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,11 @@ Base.eltype(::NonlinearAlgCache{iip, T}) where {iip, T} = T
2323
function SciMLBase.__init(
2424
prob::FODEProblem, alg::NonLinearAlg; dt = 0.0, L0 = 1e10, kwargs...)
2525
dt 0 ? throw(ArgumentError("dt must be positive")) : nothing
26-
prob = _is_need_convert!(prob)
26+
prob, iip = _is_need_convert!(prob)
2727
(; f, order, u0, tspan, p) = prob
2828
t0 = tspan[1]
2929
tfinal = tspan[2]
3030
T = eltype(u0)
31-
iip = isinplace(prob)
3231
mesh = collect(t0:dt:tfinal)
3332
problem_size = length(u0)
3433
N = round(Int, (tfinal - t0) / dt) + 1

0 commit comments

Comments
 (0)