Skip to content

Commit 5eba972

Browse files
test: test nonlinear integrator setindex!, refactor tests
1 parent d050e89 commit 5eba972

File tree

1 file changed

+19
-34
lines changed

1 file changed

+19
-34
lines changed

test/downstream/mtk_cache_indexing_tests.jl

Lines changed: 19 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -10,39 +10,24 @@
1010
# Creates an integrator.
1111
nlprob = NonlinearProblem(nlsys, [X => 1.0], [p => 2.0, d => 3.0])
1212

13-
@testset "GeneralizedFirstOrderAlgorithmCache" begin
14-
nint = init(nlprob, NewtonRaphson())
15-
@test nint isa NonlinearSolve.GeneralizedFirstOrderAlgorithmCache
16-
17-
@test nint[X] == 1.0
18-
@test nint[nlsys.X] == 1.0
19-
@test nint[:X] == 1.0
20-
@test nint.ps[p] == 2.0
21-
@test nint.ps[nlsys.p] == 2.0
22-
@test nint.ps[:p] == 2.0
23-
end
24-
25-
@testset "NonlinearSolvePolyAlgorithmCache" begin
26-
nint = init(nlprob, FastShortcutNonlinearPolyalg())
27-
@test nint isa NonlinearSolve.NonlinearSolvePolyAlgorithmCache
28-
29-
@test nint[X] == 1.0
30-
@test nint[nlsys.X] == 1.0
31-
@test nint[:X] == 1.0
32-
@test nint.ps[p] == 2.0
33-
@test nint.ps[nlsys.p] == 2.0
34-
@test nint.ps[:p] == 2.0
35-
end
36-
37-
@testset "NonlinearSolveNoInitCache" begin
38-
nint = init(nlprob, SimpleNewtonRaphson())
39-
@test nint isa NonlinearSolve.NonlinearSolveNoInitCache
40-
41-
@test nint[X] == 1.0
42-
@test nint[nlsys.X] == 1.0
43-
@test nint[:X] == 1.0
44-
@test nint.ps[p] == 2.0
45-
@test nint.ps[nlsys.p] == 2.0
46-
@test nint.ps[:p] == 2.0
13+
@testset "$integtype" for (alg, integtype) in [
14+
(NewtonRaphson(), NonlinearSolve.GeneralizedFirstOrderAlgorithmCache),
15+
(FastShortcutNonlinearPolyalg(), NonlinearSolve.NonlinearSolvePolyAlgorithmCache),
16+
(SimpleNewtonRaphson(), NonlinearSolve.NonlinearSolveNoInitCache),
17+
]
18+
nint = init(nlprob, alg)
19+
@test nint isa integtype
20+
21+
for (i, sym) in enumerate([X, nlsys.X, :X])
22+
# test both getindex and setindex!
23+
nint[sym] = 1.5i
24+
@test nint[sym] == 1.5i
25+
end
26+
27+
for (i, sym) in enumerate([p, nlsys.p, :p])
28+
# test both getindex and setindex!
29+
nint.ps[sym] = 2.5i
30+
@test nint.ps[sym] == 2.5i
31+
end
4732
end
4833
end

0 commit comments

Comments
 (0)