|
1 | | -using OrdinaryDiffEq, Sundials, SciMLBase, Test |
| 1 | +using ModelingToolkit, NonlinearSolve, OrdinaryDiffEq, Sundials, SciMLBase, Test |
| 2 | +using SymbolicIndexingInterface |
| 3 | +using ModelingToolkit: t_nounits as t, D_nounits as D |
2 | 4 |
|
3 | 5 | @testset "CheckInit" begin |
4 | 6 | abstol = 1e-10 |
5 | | - @testset "Sundials + ODEProblem" begin |
6 | | - function rhs(u, p, t) |
7 | | - return [u[1] * t, u[1]^2 - u[2]^2] |
8 | | - end |
9 | | - function rhs!(du, u, p, t) |
10 | | - du[1] = u[1] * t |
11 | | - du[2] = u[1]^2 - u[2]^2 |
12 | | - end |
13 | | - |
14 | | - oopfn = ODEFunction{false}(rhs, mass_matrix = [1 0; 0 0]) |
15 | | - iipfn = ODEFunction{true}(rhs!, mass_matrix = [1 0; 0 0]) |
16 | | - |
17 | | - @testset "Inplace = $(SciMLBase.isinplace(f))" for f in [oopfn, iipfn] |
18 | | - prob = ODEProblem(f, [1.0, 1.0], (0.0, 1.0)) |
19 | | - integ = init(prob, Sundials.ARKODE()) |
20 | | - u0, _, success = SciMLBase.get_initial_values( |
21 | | - prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)); abstol) |
22 | | - @test success |
23 | | - @test u0 == prob.u0 |
24 | | - |
25 | | - integ.u[2] = 2.0 |
26 | | - @test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values( |
27 | | - prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)); abstol) |
28 | | - end |
29 | | - end |
30 | | - |
31 | 7 | @testset "Sundials + DAEProblem" begin |
32 | 8 | function daerhs(du, u, p, t) |
33 | 9 | return [du[1] - u[1] * t - p, u[1]^2 - u[2]^2] |
@@ -59,3 +35,26 @@ using OrdinaryDiffEq, Sundials, SciMLBase, Test |
59 | 35 | end |
60 | 36 | end |
61 | 37 | end |
| 38 | + |
| 39 | +@testset "OverrideInit with MTK" begin |
| 40 | + abstol = 1e-10 |
| 41 | + reltol = 1e-8 |
| 42 | + |
| 43 | + @variables x(t) [guess = 1.0] y(t) [guess = 1.0] |
| 44 | + @parameters p=missing [guess = 1.0] q=missing [guess = 1.0] |
| 45 | + @mtkbuild sys = ODESystem([D(x) ~ p * y + q * t, D(y) ~ 5x + q], t; |
| 46 | + initialization_eqs = [p^2 + q^2 ~ 3, x^3 + y^3 ~ 5]) |
| 47 | + prob = ODEProblem( |
| 48 | + sys, [x => 1.0], (0.0, 1.0), [p => 1.0]; initializealg = SciMLBase.NoInit()) |
| 49 | + |
| 50 | + @test prob.f.initialization_data isa SciMLBase.OverrideInitData |
| 51 | + integ = init(prob, Tsit5()) |
| 52 | + u0, pobj, success = SciMLBase.get_initial_values( |
| 53 | + prob, integ, prob.f, SciMLBase.OverrideInit(), Val(true); |
| 54 | + nlsolve_alg = NewtonRaphson(), abstol, reltol) |
| 55 | + |
| 56 | + @test getu(sys, x)(u0) ≈ 1.0 |
| 57 | + @test getu(sys, y)(u0) ≈ cbrt(4) |
| 58 | + @test getp(sys, p)(pobj) ≈ 1.0 |
| 59 | + @test getp(sys, q)(pobj) ≈ sqrt(2) |
| 60 | +end |
0 commit comments