diff --git a/Project.toml b/Project.toml index 8b3f09565..541a5d3c3 100644 --- a/Project.toml +++ b/Project.toml @@ -19,6 +19,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" Moshi = "2e0e35c7-a2e4-4343-998d-7ef72827ed2d" +PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" @@ -75,6 +76,7 @@ Makie = "0.20, 0.21, 0.22, 0.23, 0.24" Markdown = "1.10" Moshi = "0.3" PartialFunctions = "1.1" +PreallocationTools = "0.4.31" PrecompileTools = "1.2" Preferences = "1.3" Printf = "1.10" diff --git a/src/SciMLBase.jl b/src/SciMLBase.jl index aef3272db..54beb0640 100644 --- a/src/SciMLBase.jl +++ b/src/SciMLBase.jl @@ -14,6 +14,7 @@ using Distributed using Markdown using Printf import Preferences +using PreallocationTools: get_tmp, DiffCache import Logging, ArrayInterface import IteratorInterfaceExtensions diff --git a/src/problems/ode_problems.jl b/src/problems/ode_problems.jl index 2aafa4172..da1576d17 100644 --- a/src/problems/ode_problems.jl +++ b/src/problems/ode_problems.jl @@ -474,7 +474,7 @@ end function SplitODEProblem{iip}(f::SplitFunction, u0, tspan, p = NullParameters(); kwargs...) where {iip} if f._func_cache === nothing && iip - _func_cache = similar(u0) + _func_cache = DiffCache(u0) f = remake(f; _func_cache) end ODEProblem(f, u0, tspan, p, SplitODEProblem{iip}(); kwargs...) diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index e1f2aada1..0a73d6bd0 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -2618,9 +2618,10 @@ end (f::SplitFunction)(u, p, t) = f.f1(u, p, t) + f.f2(u, p, t) function (f::SplitFunction)(du, u, p, t) - f.f1(f._func_cache, u, p, t) + tmp = get_tmp(f._func_cache, du) + f.f1(tmp, u, p, t) f.f2(du, u, p, t) - du .+= f._func_cache + du .+= tmp end (f::DiscreteFunction)(args...) = f.f(args...) @@ -2668,9 +2669,10 @@ end (f::SplitSDEFunction)(u, p, t) = f.f1(u, p, t) + f.f2(u, p, t) function (f::SplitSDEFunction)(du, u, p, t) - f.f1(f._func_cache, u, p, t) + tmp = get_tmp(f._func_cache, du) + f.f1(tmp, u, p, t) f.f2(du, u, p, t) - du .+= f._func_cache + du .+= tmp end (f::RODEFunction)(args...) = f.f(args...) diff --git a/test/downstream/splitodeproblem_cache.jl b/test/downstream/splitodeproblem_cache.jl new file mode 100644 index 000000000..130db8c39 --- /dev/null +++ b/test/downstream/splitodeproblem_cache.jl @@ -0,0 +1,33 @@ +using OrdinaryDiffEq, Test + +# https://github.com/SciML/OrdinaryDiffEq.jl/issues/2719 + +# set up functions +function f1!(du, u , p, t) + du .= -u.^2 + return nothing +end + +function f2!(du, u , p, t) + du .= 2u + return nothing +end + +function f!(du, u, p, t) + du .= -u.^2 .+ 2u + return nothing +end + +#create problems +u0 = ones(2) +tspan = (0.0, 1.0) +prob = ODEProblem(f!, u0, tspan) +f_split! = SplitFunction(f1!, f2!) +prob_split = SplitODEProblem(f_split!, u0, tspan) + +#solve +sol = solve(prob, Rodas5P()) +sol_split = solve(prob_split, Rodas5P()) + +#tests +@test sol_split == sol diff --git a/test/runtests.jl b/test/runtests.jl index 3d0983005..6170019a3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -119,6 +119,9 @@ end @time @safetestset "Table Traits" begin include("downstream/traits.jl") end + @time @safetestset "SplitODEProblem cache" begin + include("downstream/splitodeproblem_cache.jl") + end end if !is_APPVEYOR && GROUP == "SymbolicIndexingInterface"