Skip to content

Commit 2e9256a

Browse files
Merge pull request #1106 from thomvet/fix-ordinarydiffeq-2719-new
Fix ordinarydiffeq 2719 new - Use PreallocationTools.DiffCache for caching in SplitODEProblem
2 parents b209f72 + 6ad5e76 commit 2e9256a

File tree

6 files changed

+46
-5
lines changed

6 files changed

+46
-5
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1919
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
2020
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
2121
Moshi = "2e0e35c7-a2e4-4343-998d-7ef72827ed2d"
22+
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
2223
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
2324
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
2425
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
@@ -75,6 +76,7 @@ Makie = "0.20, 0.21, 0.22, 0.23, 0.24"
7576
Markdown = "1.10"
7677
Moshi = "0.3"
7778
PartialFunctions = "1.1"
79+
PreallocationTools = "0.4.31"
7880
PrecompileTools = "1.2"
7981
Preferences = "1.3"
8082
Printf = "1.10"

src/SciMLBase.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ using Distributed
1414
using Markdown
1515
using Printf
1616
import Preferences
17+
using PreallocationTools: get_tmp, DiffCache
1718

1819
import Logging, ArrayInterface
1920
import IteratorInterfaceExtensions

src/problems/ode_problems.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -474,7 +474,7 @@ end
474474
function SplitODEProblem{iip}(f::SplitFunction, u0, tspan, p = NullParameters();
475475
kwargs...) where {iip}
476476
if f._func_cache === nothing && iip
477-
_func_cache = similar(u0)
477+
_func_cache = DiffCache(u0)
478478
f = remake(f; _func_cache)
479479
end
480480
ODEProblem(f, u0, tspan, p, SplitODEProblem{iip}(); kwargs...)

src/scimlfunctions.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2618,9 +2618,10 @@ end
26182618

26192619
(f::SplitFunction)(u, p, t) = f.f1(u, p, t) + f.f2(u, p, t)
26202620
function (f::SplitFunction)(du, u, p, t)
2621-
f.f1(f._func_cache, u, p, t)
2621+
tmp = get_tmp(f._func_cache, du)
2622+
f.f1(tmp, u, p, t)
26222623
f.f2(du, u, p, t)
2623-
du .+= f._func_cache
2624+
du .+= tmp
26242625
end
26252626

26262627
(f::DiscreteFunction)(args...) = f.f(args...)
@@ -2668,9 +2669,10 @@ end
26682669
(f::SplitSDEFunction)(u, p, t) = f.f1(u, p, t) + f.f2(u, p, t)
26692670

26702671
function (f::SplitSDEFunction)(du, u, p, t)
2671-
f.f1(f._func_cache, u, p, t)
2672+
tmp = get_tmp(f._func_cache, du)
2673+
f.f1(tmp, u, p, t)
26722674
f.f2(du, u, p, t)
2673-
du .+= f._func_cache
2675+
du .+= tmp
26742676
end
26752677

26762678
(f::RODEFunction)(args...) = f.f(args...)
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
using OrdinaryDiffEq, Test
2+
3+
# https://github.com/SciML/OrdinaryDiffEq.jl/issues/2719
4+
5+
# set up functions
6+
function f1!(du, u , p, t)
7+
du .= -u.^2
8+
return nothing
9+
end
10+
11+
function f2!(du, u , p, t)
12+
du .= 2u
13+
return nothing
14+
end
15+
16+
function f!(du, u, p, t)
17+
du .= -u.^2 .+ 2u
18+
return nothing
19+
end
20+
21+
#create problems
22+
u0 = ones(2)
23+
tspan = (0.0, 1.0)
24+
prob = ODEProblem(f!, u0, tspan)
25+
f_split! = SplitFunction(f1!, f2!)
26+
prob_split = SplitODEProblem(f_split!, u0, tspan)
27+
28+
#solve
29+
sol = solve(prob, Rodas5P())
30+
sol_split = solve(prob_split, Rodas5P())
31+
32+
#tests
33+
@test sol_split == sol

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,9 @@ end
119119
@time @safetestset "Table Traits" begin
120120
include("downstream/traits.jl")
121121
end
122+
@time @safetestset "SplitODEProblem cache" begin
123+
include("downstream/splitodeproblem_cache.jl")
124+
end
122125
end
123126

124127
if !is_APPVEYOR && GROUP == "SymbolicIndexingInterface"

0 commit comments

Comments
 (0)