Skip to content

Commit 444ae15

Browse files
Merge pull request #2698 from Ickaser/dae_typestab
Make get_differential_vars type stable
2 parents 5a35a15 + 60e6ba1 commit 444ae15

File tree

6 files changed

+50
-35
lines changed

6 files changed

+50
-35
lines changed

docs/src/massmatrixdae/BDF.md

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ CollapsedDocStrings = true
77
Multistep BDF methods, good for large stiff systems.
88

99
```julia
10+
using LinearAlgebra: Diagonal
1011
function rober(du, u, p, t)
1112
y₁, y₂, y₃ = u
1213
k₁, k₂, k₃ = p
@@ -15,9 +16,7 @@ function rober(du, u, p, t)
1516
du[3] = y₁ + y₂ + y₃ - 1
1617
nothing
1718
end
18-
M = [1.0 0 0
19-
0 1.0 0
20-
0 0 0]
19+
M = Diagonal([1.0, 1.0, 0])
2120
f = ODEFunction(rober, mass_matrix = M)
2221
prob_mm = ODEProblem(f, [1.0, 0.0, 0.0], (0.0, 1e5), (0.04, 3e7, 1e4))
2322
sol = solve(prob_mm, FBDF(), reltol = 1e-8, abstol = 1e-8)

docs/src/massmatrixdae/Rosenbrock.md

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ For larger systems look at multistep methods.
2020
## Example usage
2121

2222
```julia
23+
using LinearAlgebra: Diagonal
2324
function rober(du, u, p, t)
2425
y₁, y₂, y₃ = u
2526
k₁, k₂, k₃ = p
@@ -28,9 +29,7 @@ function rober(du, u, p, t)
2829
du[3] = y₁ + y₂ + y₃ - 1
2930
nothing
3031
end
31-
M = [1.0 0 0
32-
0 1.0 0
33-
0 0 0]
32+
M = Diagonal([1.0, 1.0, 0])
3433
f = ODEFunction(rober, mass_matrix = M)
3534
prob_mm = ODEProblem(f, [1.0, 0.0, 0.0], (0.0, 1e5), (0.04, 3e7, 1e4))
3635
sol = solve(prob_mm, Rodas5(), reltol = 1e-8, abstol = 1e-8)

lib/OrdinaryDiffEqBDF/test/dae_ad_tests.jl

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
using OrdinaryDiffEqBDF, LinearAlgebra, ForwardDiff, Test
22
using OrdinaryDiffEqNonlinearSolve: BrownFullBasicInit, ShampineCollocationInit
3+
using ADTypes: AutoForwardDiff, AutoFiniteDiff
4+
5+
afd_cs3 = AutoForwardDiff(chunksize=3)
36

47
function f(out, du, u, p, t)
58
out[1] = -p[1] * u[1] + p[3] * u[2] * u[3] - du[1]
@@ -16,22 +19,30 @@ u₀ = [1.0, 0, 0]
1619
du₀ = [-0.04, 0.04, 0.0]
1720
tspan = (0.0, 100000.0)
1821
differential_vars = [true, true, false]
22+
M = Diagonal([1.0, 1.0, 0.0])
1923
prob = DAEProblem(f, du₀, u₀, tspan, p, differential_vars = differential_vars)
2024
prob_oop = DAEProblem{false}(f, du₀, u₀, tspan, p, differential_vars = differential_vars)
21-
sol1 = solve(prob, DFBDF(), dt = 1e-5, abstol = 1e-8, reltol = 1e-8)
22-
sol2 = solve(prob_oop, DFBDF(), dt = 1e-5, abstol = 1e-8, reltol = 1e-8)
25+
f_mm = ODEFunction{true, SciMLBase.AutoSpecialize}(f, mass_matrix = M)
26+
prob_mm = ODEProblem(f_mm, u₀, tspan, p)
27+
@test_broken sol1 = @inferred solve(prob, DFBDF(autodiff=afd_cs3), dt = 1e-5, abstol = 1e-8, reltol = 1e-8)
28+
@test_broken sol2 = @inferred solve(prob_oop, DFBDF(autodiff=afd_cs3), dt = 1e-5, abstol = 1e-8, reltol = 1e-8)
29+
@test_broken sol3 = @inferred solve(prob_mm, FBDF(autodiff=afd_cs3), dt = 1e-5, abstol = 1e-8, reltol = 1e-8)
2330

2431
# These tests flex differentiation of the solver and through the initialization
2532
# To only test the solver part and isolate potential issues, set the initialization to consistent
26-
@testset "Inplace: $(isinplace(_prob)), DAEProblem: $(_prob isa DAEProblem), BrownBasic: $(initalg isa BrownFullBasicInit), Autodiff: $autodiff" for _prob in [
27-
prob, prob_oop],
28-
initalg in [BrownFullBasicInit(), ShampineCollocationInit()], autodiff in [true, false]
33+
@testset "Inplace: $(isinplace(_prob)), DAEProblem: $(_prob isa DAEProblem), BrownBasic: $(initalg isa BrownFullBasicInit), Autodiff: $autodiff" for _prob in [
34+
prob, prob_oop, prob_mm],
35+
initalg in [BrownFullBasicInit(), ShampineCollocationInit()], autodiff in [afd_cs3, AutoFiniteDiff()]
2936

30-
alg = DFBDF(; autodiff)
37+
alg = (_prob isa DAEProblem) ? DFBDF(; autodiff) : FBDF(; autodiff)
3138
function f(p)
3239
sol = solve(remake(_prob, p = p), alg, abstol = 1e-14,
3340
reltol = 1e-14, initializealg = initalg)
3441
sum(sol)
3542
end
36-
@test ForwardDiff.gradient(f, [0.04, 3e7, 1e4])[0, 0, 0] atol=1e-8
43+
if _prob isa DAEProblem
44+
@test ForwardDiff.gradient(f, [0.04, 3e7, 1e4])[0, 0, 0] atol=1e-8
45+
else
46+
@test_broken ForwardDiff.gradient(f, [0.04, 3e7, 1e4])[0, 0, 0] atol=1e-8
47+
end
3748
end

lib/OrdinaryDiffEqCore/src/misc_utils.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,18 +114,21 @@ are differential variables. Returns `DifferentialVarsUndefined` if it cannot
114114
be determined (i.e. the mass matrix is not diagonal).
115115
"""
116116
function get_differential_vars(f, u)
117-
differential_vars = nothing
118117
if hasproperty(f, :mass_matrix)
119118
mm = f.mass_matrix
120119
mm = mm isa MatrixOperator ? mm.A : mm
121120

122-
if mm isa UniformScaling || all(!iszero, mm)
121+
if mm isa UniformScaling
123122
return nothing
123+
elseif all(!iszero, mm)
124+
return trues(size(mm, 1))
124125
elseif !(mm isa SciMLOperators.AbstractSciMLOperator) && isdiag(mm)
125-
differential_vars = reshape(diag(mm) .!= 0, size(u))
126+
return reshape(diag(mm) .!= 0, size(u))
126127
else
127128
return DifferentialVarsUndefined()
128129
end
130+
else
131+
return nothing
129132
end
130133
end
131134

lib/OrdinaryDiffEqRosenbrock/test/dae_rosenbrock_ad_tests.jl

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
using OrdinaryDiffEqRosenbrock, LinearAlgebra, ForwardDiff, Test
22
using OrdinaryDiffEqNonlinearSolve: BrownFullBasicInit, ShampineCollocationInit
3+
using ADTypes: AutoForwardDiff, AutoFiniteDiff
34

5+
afd_cs3 = AutoForwardDiff(chunksize=3)
46
function rober(du, u, p, t)
57
y₁, y₂, y₃ = u
68
k₁, k₂, k₃ = p
@@ -16,25 +18,24 @@ function rober(u, p, t)
1618
k₁ * y₁ - k₃ * y₂ * y₃ - k₂ * y₂^2,
1719
y₁ + y₂ + y₃ - 1]
1820
end
19-
M = [1.0 0 0
20-
0 1.0 0
21-
0 0 0]
22-
roberf = ODEFunction(rober, mass_matrix = M)
23-
roberf_oop = ODEFunction{false}(rober, mass_matrix = M)
21+
M = Diagonal([1.0, 1.0, 0.0])
22+
roberf = ODEFunction{true, SciMLBase.AutoSpecialize}(rober, mass_matrix = M)
23+
roberf_oop = ODEFunction{false, SciMLBase.AutoSpecialize}(rober, mass_matrix = M)
2424
prob_mm = ODEProblem(roberf, [1.0, 0.0, 0.2], (0.0, 1e5), (0.04, 3e7, 1e4))
2525
prob_mm_oop = ODEProblem(roberf_oop, [1.0, 0.0, 0.2], (0.0, 1e5), (0.04, 3e7, 1e4))
26-
sol = solve(prob_mm, Rodas5P(), reltol = 1e-8, abstol = 1e-8)
27-
sol = solve(prob_mm_oop, Rodas5P(), reltol = 1e-8, abstol = 1e-8)
26+
# Both should be inferrable so long as AutoSpecialize is used...
27+
@test_broken sol = @inferred solve(prob_mm, Rodas5P(), reltol = 1e-8, abstol = 1e-8)
28+
sol = @inferred solve(prob_mm_oop, Rodas5P(), reltol = 1e-8, abstol = 1e-8)
2829

2930
# These tests flex differentiation of the solver and through the initialization
3031
# To only test the solver part and isolate potential issues, set the initialization to consistent
3132
@testset "Inplace: $(isinplace(_prob)), DAEProblem: $(_prob isa DAEProblem), BrownBasic: $(initalg isa BrownFullBasicInit), Autodiff: $autodiff" for _prob in [
3233
prob_mm, prob_mm_oop],
33-
initalg in [BrownFullBasicInit(), ShampineCollocationInit()], autodiff in [true, false]
34+
initalg in [BrownFullBasicInit(), ShampineCollocationInit()], autodiff in [AutoForwardDiff(chunksize=3), AutoFiniteDiff()]
3435

3536
alg = Rodas5P(; autodiff)
3637
function f(p)
37-
sol = solve(remake(_prob, p = p), alg, abstol = 1e-14,
38+
sol = @inferred solve(remake(_prob, p = p), alg, abstol = 1e-14,
3839
reltol = 1e-14, initializealg = initalg)
3940
sum(sol)
4041
end

test/interface/mass_matrix_tests.jl

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
using OrdinaryDiffEq, Test, LinearAlgebra, Statistics
22
using OrdinaryDiffEqCore
33
using OrdinaryDiffEqNonlinearSolve: NLFunctional, NLAnderson, NLNewton
4+
using LinearAlgebra: Diagonal
5+
using ADTypes: AutoForwardDiff
46

57
# create mass matrix problems
68
function make_mm_probs(mm_A, ::Val{iip}) where {iip}
@@ -194,11 +196,10 @@ end
194196
u0 = [0.0, 1.0]
195197
tspan = (0.0, 1.0)
196198

197-
M = fill(0.0, 2, 2)
198-
M[1, 1] = 1.0
199+
M = Diagonal([1.0, 0.0])
199200

200201
m_ode_prob = ODEProblem(ODEFunction(f!; mass_matrix = M), u0, tspan)
201-
@test_nowarn sol = solve(m_ode_prob, Rosenbrock23())
202+
@test_nowarn sol = @inferred solve(m_ode_prob, Rosenbrock23(autodiff=AutoForwardDiff(chunksize=2)))
202203

203204
M = [0.637947 0.637947
204205
0.637947 0.637947]
@@ -323,14 +324,15 @@ function dynamics(u, p, t)
323324
end
324325

325326
x0 = zeros(n, n)
326-
M = zeros(n * n) |> Diagonal |> Matrix
327+
M = zeros(n * n) |> Diagonal
327328
M[1, 1] = true # zero mass matrix breaks rosenbrock
328-
f = ODEFunction(dynamics!, mass_matrix = M)
329+
f = ODEFunction{true, SciMLBase.AutoSpecialize}(dynamics!, mass_matrix = M)
329330
tspan = (0, 10.0)
331+
adalg = AutoForwardDiff(chunksize=n)
330332
prob = ODEProblem(f, x0, tspan)
331-
foop = ODEFunction(dynamics, mass_matrix = M)
333+
foop = ODEFunction{false, SciMLBase.AutoSpecialize}(dynamics, mass_matrix = M)
332334
proboop = ODEProblem(f, x0, tspan)
333-
sol = solve(prob, Rosenbrock23())
334-
sol = solve(prob, Rodas4(), initializealg = ShampineCollocationInit())
335-
sol = solve(proboop, Rodas5())
336-
sol = solve(proboop, Rodas4(), initializealg = ShampineCollocationInit())
335+
@test_broken sol = @inferred solve(prob, Rosenbrock23(autodiff=adalg))
336+
@test_broken sol = @inferred solve(prob, Rodas4(autodiff=adalg), initializealg = ShampineCollocationInit())
337+
@test_broken sol = @inferred solve(proboop, Rodas5())
338+
@test_broken sol = @inferred solve(proboop, Rodas4(), initializealg = ShampineCollocationInit())

0 commit comments

Comments
 (0)