@@ -39,13 +39,12 @@ Base.eltype(::BDFCache{iip, T}) where {iip, T} = T
3939
4040function SciMLBase. __init (prob:: FODEProblem , alg:: BDF ; dt = 0.0 , reltol = 1e-6 ,
4141 abstol = 1e-6 , maxiters = 1000 , kwargs... )
42- prob = _is_need_convert! (prob)
42+ prob, iip = _is_need_convert! (prob)
4343 dt ≤ 0 ? throw (ArgumentError (" dt must be positive" )) : nothing
4444 (; f, order, u0, tspan, p) = prob
4545 t0 = tspan[1 ]
4646 tfinal = tspan[2 ]
4747 T = eltype (u0)
48- iip = isinplace (prob)
4948
5049 all (x -> x == order[1 ], order) ? nothing :
5150 throw (ArgumentError (" BDF method is only for commensurate order FODE" ))
@@ -65,7 +64,7 @@ function SciMLBase.__init(prob::FODEProblem, alg::BDF; dt = 0.0, reltol = 1e-6,
6564 NNr = 2 ^ (Q + 1 ) * r
6665
6766 # Preallocation of some variables
68- y = [Vector {T} (undef, problem_size) for _ in 1 : (N + 1 )]
67+ y = [u0 for _ in 1 : (N + 1 )]
6968 fy = similar (y)
7069 zn = zeros (T, problem_size, NNr + 1 )
7170
@@ -98,7 +97,11 @@ function SciMLBase.__init(prob::FODEProblem, alg::BDF; dt = 0.0, reltol = 1e-6,
9897 mesh = t0 .+ collect (0 : N) * dt
9998 y[1 ] .= high_order_prob ? u0[1 , :] : u0
10099 temp = high_order_prob ? similar (u0[1 , :]) : similar (u0)
101- prob. f (temp, u0, p, t0)
100+ if iip
101+ prob. f (temp, u0, p, t0)
102+ else
103+ temp .= prob. f (u0, p, t0)
104+ end
102105 fy[1 ] = temp
103106
104107 return BDFCache {iip, T} (prob, alg, mesh, u0, alpha, halpha, y, fy, zn, jac, prob. p,
@@ -250,7 +253,11 @@ function BDF_first_approximations(cache::BDFCache{iip, T}) where {iip, T}
250253 F0 = similar (Y0)
251254 B0 = similar (Y0)
252255 for j in 1 : s
253- prob. f (F0. u[j], cache. y[1 ], p, mesh[j + 1 ])
256+ if iip
257+ prob. f (F0. u[j], cache. y[1 ], p, mesh[j + 1 ])
258+ else
259+ F0. u[j] .= prob. f (cache. y[1 ], p, mesh[j + 1 ])
260+ end
254261 St = ABM_starting_term (cache, mesh[j + 1 ])
255262 B0. u[j] = St + halpha * (omega[j + 1 ] + w[1 , j + 1 ]) * cache. fy[1 ]
256263 end
@@ -269,7 +276,11 @@ function BDF_first_approximations(cache::BDFCache{iip, T}) where {iip, T}
269276 JF = zeros (T, s * problem_size, s * problem_size)
270277 J_temp = Matrix {T} (undef, problem_size, problem_size)
271278 for j in 1 : s
272- jac (J_temp, cache. y[1 ], p, mesh[j + 1 ])
279+ if iip
280+ jac (J_temp, cache. y[1 ], p, mesh[j + 1 ])
281+ else
282+ J_temp .= jac (cache. y[1 ], p, mesh[j + 1 ])
283+ end
273284 JF[((j - 1 ) * problem_size + 1 ): (j * problem_size), ((j - 1 ) * problem_size + 1 ): (j * problem_size)] .= J_temp
274285 end
275286 stop = false
@@ -381,15 +392,15 @@ function jacobian_of_fdefun(f, t, y, p)
381392end
382393
383394function _is_need_convert! (prob:: FODEProblem )
384- length (prob. u0) == 1 ? _convert_single_term_to_vectorized_prob! (prob) : prob
395+ length (prob. u0) == 1 ? ( _convert_single_term_to_vectorized_prob! (prob), true ) : ( prob, SciMLBase . isinplace (prob))
385396end
386397
387398function _convert_single_term_to_vectorized_prob! (prob:: FODEProblem )
388399 if SciMLBase. isinplace (prob)
389400 if isa (prob. u0, AbstractArray)
390401 new_prob = remake (prob; order = [prob. order])
391402 else
392- new_prob = remake (prob; u0 = [ prob. u0] , order = [prob. order])
403+ new_prob = remake (prob; u0 = prob. u0, order = [prob. order])
393404 end
394405 return new_prob
395406 else
0 commit comments