From 59fdd2bcd4035ac499c6a5b42ad600051330a1a7 Mon Sep 17 00:00:00 2001 From: thomvet Date: Fri, 15 Aug 2025 09:20:36 +0200 Subject: [PATCH 1/3] make SplitODEProblem use PreallocationTools.DiffCache for caching --- Project.toml | 2 ++ src/SciMLBase.jl | 1 + src/problems/ode_problems.jl | 2 +- src/scimlfunctions.jl | 10 ++++--- test/downstream/splitodeproblem_cache.jl | 33 ++++++++++++++++++++++++ test/runtests.jl | 3 +++ 6 files changed, 46 insertions(+), 5 deletions(-) create mode 100644 test/downstream/splitodeproblem_cache.jl diff --git a/Project.toml b/Project.toml index 8b3f09565..541a5d3c3 100644 --- a/Project.toml +++ b/Project.toml @@ -19,6 +19,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" Moshi = "2e0e35c7-a2e4-4343-998d-7ef72827ed2d" +PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" @@ -75,6 +76,7 @@ Makie = "0.20, 0.21, 0.22, 0.23, 0.24" Markdown = "1.10" Moshi = "0.3" PartialFunctions = "1.1" +PreallocationTools = "0.4.31" PrecompileTools = "1.2" Preferences = "1.3" Printf = "1.10" diff --git a/src/SciMLBase.jl b/src/SciMLBase.jl index aef3272db..54beb0640 100644 --- a/src/SciMLBase.jl +++ b/src/SciMLBase.jl @@ -14,6 +14,7 @@ using Distributed using Markdown using Printf import Preferences +using PreallocationTools: get_tmp, DiffCache import Logging, ArrayInterface import IteratorInterfaceExtensions diff --git a/src/problems/ode_problems.jl b/src/problems/ode_problems.jl index 2aafa4172..7bd025ed8 100644 --- a/src/problems/ode_problems.jl +++ b/src/problems/ode_problems.jl @@ -474,7 +474,7 @@ end function SplitODEProblem{iip}(f::SplitFunction, u0, tspan, p = NullParameters(); kwargs...) where {iip} if f._func_cache === nothing && iip - _func_cache = similar(u0) + _func_cache = DiffCache(u0) f = remake(f; _func_cache) end ODEProblem(f, u0, tspan, p, SplitODEProblem{iip}(); kwargs...) diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index e1f2aada1..0a73d6bd0 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -2618,9 +2618,10 @@ end (f::SplitFunction)(u, p, t) = f.f1(u, p, t) + f.f2(u, p, t) function (f::SplitFunction)(du, u, p, t) - f.f1(f._func_cache, u, p, t) + tmp = get_tmp(f._func_cache, du) + f.f1(tmp, u, p, t) f.f2(du, u, p, t) - du .+= f._func_cache + du .+= tmp end (f::DiscreteFunction)(args...) = f.f(args...) @@ -2668,9 +2669,10 @@ end (f::SplitSDEFunction)(u, p, t) = f.f1(u, p, t) + f.f2(u, p, t) function (f::SplitSDEFunction)(du, u, p, t) - f.f1(f._func_cache, u, p, t) + tmp = get_tmp(f._func_cache, du) + f.f1(tmp, u, p, t) f.f2(du, u, p, t) - du .+= f._func_cache + du .+= tmp end (f::RODEFunction)(args...) = f.f(args...) diff --git a/test/downstream/splitodeproblem_cache.jl b/test/downstream/splitodeproblem_cache.jl new file mode 100644 index 000000000..546e99e15 --- /dev/null +++ b/test/downstream/splitodeproblem_cache.jl @@ -0,0 +1,33 @@ +using OrdinaryDiffEq, Test + +# https://github.com/SciML/OrdinaryDiffEq.jl/issues/2719 + +# set up functions +function f1!(du, u , p, t) + du .= -u.^2 + return nothing +end + +function f2!(du, u , p, t) + du .= 2u + return nothing +end + +function f!(du, u, p, t) + du .= -u.^2 .+ 2u + return nothing +end + +#create problems +u0 = ones(2) +tspan = (0.0, 1.0) +prob = ODEProblem(f!, u0, tspan) +f_split! = SplitFunction(f1!, f2!) +prob_split = SplitODEProblem(f_split!, u0, tspan) + +#solve +sol = solve(prob, Rodas5P()) +sol_split = solve(prob_split, Rodas5P()) + +#tests +@test sol_split == sol \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 3d0983005..6170019a3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -119,6 +119,9 @@ end @time @safetestset "Table Traits" begin include("downstream/traits.jl") end + @time @safetestset "SplitODEProblem cache" begin + include("downstream/splitodeproblem_cache.jl") + end end if !is_APPVEYOR && GROUP == "SymbolicIndexingInterface" From d7a8e15662950f44234cf7a0508e56d55e8382d3 Mon Sep 17 00:00:00 2001 From: thomvet Date: Fri, 15 Aug 2025 09:21:54 +0200 Subject: [PATCH 2/3] space --- src/problems/ode_problems.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/problems/ode_problems.jl b/src/problems/ode_problems.jl index 7bd025ed8..da1576d17 100644 --- a/src/problems/ode_problems.jl +++ b/src/problems/ode_problems.jl @@ -474,7 +474,7 @@ end function SplitODEProblem{iip}(f::SplitFunction, u0, tspan, p = NullParameters(); kwargs...) where {iip} if f._func_cache === nothing && iip - _func_cache = DiffCache(u0) + _func_cache = DiffCache(u0) f = remake(f; _func_cache) end ODEProblem(f, u0, tspan, p, SplitODEProblem{iip}(); kwargs...) From 6ad5e762b72a0de1bb071e5df6fdc0fa539c6376 Mon Sep 17 00:00:00 2001 From: thomvet Date: Fri, 15 Aug 2025 09:22:16 +0200 Subject: [PATCH 3/3] NL --- test/downstream/splitodeproblem_cache.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/downstream/splitodeproblem_cache.jl b/test/downstream/splitodeproblem_cache.jl index 546e99e15..130db8c39 100644 --- a/test/downstream/splitodeproblem_cache.jl +++ b/test/downstream/splitodeproblem_cache.jl @@ -30,4 +30,4 @@ sol = solve(prob, Rodas5P()) sol_split = solve(prob_split, Rodas5P()) #tests -@test sol_split == sol \ No newline at end of file +@test sol_split == sol