|
1 | 1 | using OrdinaryDiffEqRosenbrock, LinearAlgebra, ForwardDiff, Test |
2 | 2 | using OrdinaryDiffEqNonlinearSolve: BrownFullBasicInit, ShampineCollocationInit |
| 3 | +using ADTypes: AutoForwardDiff, AutoFiniteDiff |
3 | 4 |
|
4 | 5 | function rober(du, u, p, t) |
5 | 6 | y₁, y₂, y₃ = u |
|
19 | 20 | M = [1.0 0 0 |
20 | 21 | 0 1.0 0 |
21 | 22 | 0 0 0] |
22 | | -roberf = ODEFunction(rober, mass_matrix = M) |
23 | | -roberf_oop = ODEFunction{false}(rober, mass_matrix = M) |
| 23 | +# M = Diagonal([1.0, 1.0, 0.0]) |
| 24 | +roberf = ODEFunction{true, SciMLBase.AutoSpecialize}(rober, mass_matrix = M) |
| 25 | +roberf_oop = ODEFunction{false, SciMLBase.AutoSpecialize}(rober, mass_matrix = M) |
24 | 26 | prob_mm = ODEProblem(roberf, [1.0, 0.0, 0.2], (0.0, 1e5), (0.04, 3e7, 1e4)) |
25 | 27 | prob_mm_oop = ODEProblem(roberf_oop, [1.0, 0.0, 0.2], (0.0, 1e5), (0.04, 3e7, 1e4)) |
26 | | -sol = solve(prob_mm, Rodas5P(), reltol = 1e-8, abstol = 1e-8) |
27 | | -sol = solve(prob_mm_oop, Rodas5P(), reltol = 1e-8, abstol = 1e-8) |
| 28 | +sol = @inferred solve(prob_mm, Rodas5P(), reltol = 1e-8, abstol = 1e-8) |
| 29 | +sol = @inferred solve(prob_mm_oop, Rodas5P(), reltol = 1e-8, abstol = 1e-8) |
28 | 30 |
|
29 | 31 | # These tests flex differentiation of the solver and through the initialization |
30 | 32 | # To only test the solver part and isolate potential issues, set the initialization to consistent |
31 | 33 | @testset "Inplace: $(isinplace(_prob)), DAEProblem: $(_prob isa DAEProblem), BrownBasic: $(initalg isa BrownFullBasicInit), Autodiff: $autodiff" for _prob in [ |
32 | 34 | prob_mm, prob_mm_oop], |
33 | | - initalg in [BrownFullBasicInit(), ShampineCollocationInit()], autodiff in [true, false] |
| 35 | + initalg in [BrownFullBasicInit(), ShampineCollocationInit()], autodiff in [AutoForwardDiff(chunksize=3), AutoFiniteDiff()] |
34 | 36 |
|
35 | 37 | alg = Rodas5P(; autodiff) |
36 | 38 | function f(p) |
37 | | - sol = solve(remake(_prob, p = p), alg, abstol = 1e-14, |
| 39 | + sol = @inferred solve(remake(_prob, p = p), alg, abstol = 1e-14, |
38 | 40 | reltol = 1e-14, initializealg = initalg) |
39 | 41 | sum(sol) |
40 | 42 | end |
|
0 commit comments