diff --git a/Project.toml b/Project.toml index 92b029a3a..d47da8473 100644 --- a/Project.toml +++ b/Project.toml @@ -19,6 +19,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" Moshi = "2e0e35c7-a2e4-4343-998d-7ef72827ed2d" +PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" @@ -75,6 +76,7 @@ Makie = "0.20, 0.21, 0.22, 0.23, 0.24" Markdown = "1.10" Moshi = "0.3" PartialFunctions = "1.1" +PreallocationTools = "0.4.29" PrecompileTools = "1.2" Preferences = "1.3" Printf = "1.10" diff --git a/src/SciMLBase.jl b/src/SciMLBase.jl index 49d9e246d..30b11c0bd 100644 --- a/src/SciMLBase.jl +++ b/src/SciMLBase.jl @@ -14,6 +14,7 @@ using Distributed using Markdown using Printf import Preferences +using PreallocationTools import Logging, ArrayInterface import IteratorInterfaceExtensions diff --git a/src/problems/ode_problems.jl b/src/problems/ode_problems.jl index 2aafa4172..0ab64fa1e 100644 --- a/src/problems/ode_problems.jl +++ b/src/problems/ode_problems.jl @@ -474,7 +474,7 @@ end function SplitODEProblem{iip}(f::SplitFunction, u0, tspan, p = NullParameters(); kwargs...) where {iip} if f._func_cache === nothing && iip - _func_cache = similar(u0) + _func_cache = PreallocationTools.LazyBufferCache() f = remake(f; _func_cache) end ODEProblem(f, u0, tspan, p, SplitODEProblem{iip}(); kwargs...) diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index 0d6c26a02..eae48eeda 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -2618,9 +2618,10 @@ end (f::SplitFunction)(u, p, t) = f.f1(u, p, t) + f.f2(u, p, t) function (f::SplitFunction)(du, u, p, t) - f.f1(f._func_cache, u, p, t) + tmp = PreallocationTools.get_tmp(f._func_cache, u) + f.f1(tmp, u, p, t) f.f2(du, u, p, t) - du .+= f._func_cache + du .+= tmp end (f::DiscreteFunction)(args...) = f.f(args...) @@ -2666,11 +2667,11 @@ end (f::SDDEFunction)(args...) = f.f(args...) (f::SplitSDEFunction)(u, p, t) = f.f1(u, p, t) + f.f2(u, p, t) - function (f::SplitSDEFunction)(du, u, p, t) - f.f1(f._func_cache, u, p, t) + tmp = PreallocationTools.get_tmp(f._func_cache) + f.f1(tmp, u, p, t) f.f2(du, u, p, t) - du .+= f._func_cache + du .+= tmp end (f::RODEFunction)(args...) = f.f(args...)