Skip to content

Commit d7f6976

Browse files
committed
Add some inference checks to DAE tests, make type stable
1 parent f561d23 commit d7f6976

File tree

3 files changed

+21
-19
lines changed

3 files changed

+21
-19
lines changed

lib/OrdinaryDiffEqBDF/test/dae_ad_tests.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ tspan = (0.0, 100000.0)
1818
differential_vars = [true, true, false]
1919
prob = DAEProblem(f, du₀, u₀, tspan, p, differential_vars = differential_vars)
2020
prob_oop = DAEProblem{false}(f, du₀, u₀, tspan, p, differential_vars = differential_vars)
21-
sol1 = solve(prob, DFBDF(), dt = 1e-5, abstol = 1e-8, reltol = 1e-8)
22-
sol2 = solve(prob_oop, DFBDF(), dt = 1e-5, abstol = 1e-8, reltol = 1e-8)
21+
sol1 = @inferred solve(prob, DFBDF(), dt = 1e-5, abstol = 1e-8, reltol = 1e-8)
22+
sol2 = @inferred solve(prob_oop, DFBDF(), dt = 1e-5, abstol = 1e-8, reltol = 1e-8)
2323

2424
# These tests flex differentiation of the solver and through the initialization
2525
# To only test the solver part and isolate potential issues, set the initialization to consistent
@@ -29,7 +29,7 @@ sol2 = solve(prob_oop, DFBDF(), dt = 1e-5, abstol = 1e-8, reltol = 1e-8)
2929

3030
alg = DFBDF(; autodiff)
3131
function f(p)
32-
sol = solve(remake(_prob, p = p), alg, abstol = 1e-14,
32+
@inferred sol = solve(remake(_prob, p = p), alg, abstol = 1e-14,
3333
reltol = 1e-14, initializealg = initalg)
3434
sum(sol)
3535
end

lib/OrdinaryDiffEqRosenbrock/test/dae_rosenbrock_ad_tests.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using OrdinaryDiffEqRosenbrock, LinearAlgebra, ForwardDiff, Test
22
using OrdinaryDiffEqNonlinearSolve: BrownFullBasicInit, ShampineCollocationInit
3+
using ADTypes: AutoForwardDiff, AutoFiniteDiff
34

45
function rober(du, u, p, t)
56
y₁, y₂, y₃ = u
@@ -19,22 +20,23 @@ end
1920
M = [1.0 0 0
2021
0 1.0 0
2122
0 0 0]
22-
roberf = ODEFunction(rober, mass_matrix = M)
23-
roberf_oop = ODEFunction{false}(rober, mass_matrix = M)
23+
# M = Diagonal([1.0, 1.0, 0.0])
24+
roberf = ODEFunction{true, SciMLBase.AutoSpecialize}(rober, mass_matrix = M)
25+
roberf_oop = ODEFunction{false, SciMLBase.AutoSpecialize}(rober, mass_matrix = M)
2426
prob_mm = ODEProblem(roberf, [1.0, 0.0, 0.2], (0.0, 1e5), (0.04, 3e7, 1e4))
2527
prob_mm_oop = ODEProblem(roberf_oop, [1.0, 0.0, 0.2], (0.0, 1e5), (0.04, 3e7, 1e4))
26-
sol = solve(prob_mm, Rodas5P(), reltol = 1e-8, abstol = 1e-8)
27-
sol = solve(prob_mm_oop, Rodas5P(), reltol = 1e-8, abstol = 1e-8)
28+
sol = @inferred solve(prob_mm, Rodas5P(), reltol = 1e-8, abstol = 1e-8)
29+
sol = @inferred solve(prob_mm_oop, Rodas5P(), reltol = 1e-8, abstol = 1e-8)
2830

2931
# These tests flex differentiation of the solver and through the initialization
3032
# To only test the solver part and isolate potential issues, set the initialization to consistent
3133
@testset "Inplace: $(isinplace(_prob)), DAEProblem: $(_prob isa DAEProblem), BrownBasic: $(initalg isa BrownFullBasicInit), Autodiff: $autodiff" for _prob in [
3234
prob_mm, prob_mm_oop],
33-
initalg in [BrownFullBasicInit(), ShampineCollocationInit()], autodiff in [true, false]
35+
initalg in [BrownFullBasicInit(), ShampineCollocationInit()], autodiff in [AutoForwardDiff(chunksize=3), AutoFiniteDiff()]
3436

3537
alg = Rodas5P(; autodiff)
3638
function f(p)
37-
sol = solve(remake(_prob, p = p), alg, abstol = 1e-14,
39+
sol = @inferred solve(remake(_prob, p = p), alg, abstol = 1e-14,
3840
reltol = 1e-14, initializealg = initalg)
3941
sum(sol)
4042
end

test/interface/mass_matrix_tests.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using OrdinaryDiffEq, Test, LinearAlgebra, Statistics
22
using OrdinaryDiffEqCore
33
using OrdinaryDiffEqNonlinearSolve: NLFunctional, NLAnderson, NLNewton
4+
using LinearAlgebra: Diagonal
45

56
# create mass matrix problems
67
function make_mm_probs(mm_A, ::Val{iip}) where {iip}
@@ -194,11 +195,10 @@ end
194195
u0 = [0.0, 1.0]
195196
tspan = (0.0, 1.0)
196197

197-
M = fill(0.0, 2, 2)
198-
M[1, 1] = 1.0
198+
M = Diagonal([1.0, 0.0])
199199

200200
m_ode_prob = ODEProblem(ODEFunction(f!; mass_matrix = M), u0, tspan)
201-
@test_nowarn sol = solve(m_ode_prob, Rosenbrock23())
201+
@test_nowarn sol = @inferred solve(m_ode_prob, Rosenbrock23())
202202

203203
M = [0.637947 0.637947
204204
0.637947 0.637947]
@@ -323,14 +323,14 @@ function dynamics(u, p, t)
323323
end
324324

325325
x0 = zeros(n, n)
326-
M = zeros(n * n) |> Diagonal |> Matrix
326+
M = zeros(n * n) |> Diagonal
327327
M[1, 1] = true # zero mass matrix breaks rosenbrock
328-
f = ODEFunction(dynamics!, mass_matrix = M)
328+
f = ODEFunction{true, SciMLBase.AutoSpecialize}(dynamics!, mass_matrix = M)
329329
tspan = (0, 10.0)
330330
prob = ODEProblem(f, x0, tspan)
331-
foop = ODEFunction(dynamics, mass_matrix = M)
331+
foop = ODEFunction{false, SciMLBase.AutoSpecialize}(dynamics, mass_matrix = M)
332332
proboop = ODEProblem(f, x0, tspan)
333-
sol = solve(prob, Rosenbrock23())
334-
sol = solve(prob, Rodas4(), initializealg = ShampineCollocationInit())
335-
sol = solve(proboop, Rodas5())
336-
sol = solve(proboop, Rodas4(), initializealg = ShampineCollocationInit())
333+
sol = @inferred solve(prob, Rosenbrock23())
334+
sol = @inferred solve(prob, Rodas4(), initializealg = ShampineCollocationInit())
335+
sol = @inferred solve(proboop, Rodas5())
336+
sol = @inferred solve(proboop, Rodas4(), initializealg = ShampineCollocationInit())

0 commit comments

Comments
 (0)