Skip to content

Commit 4312611

Browse files
Merge pull request #447 from AayushSabharwal/as/integrator-setindex
feat: support `setindex!` for nonlinear integrators
2 parents a74c321 + 5eba972 commit 4312611

File tree

3 files changed

+24
-35
lines changed

3 files changed

+24
-35
lines changed

src/NonlinearSolve.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ using PrecompileTools: @recompile_invalidations, @compile_workload, @setup_workl
5353
using StaticArraysCore: StaticArray, SVector, SArray, MArray, Size, SMatrix
5454
using SymbolicIndexingInterface: SymbolicIndexingInterface, ParameterIndexingProxy,
5555
symbolic_container, parameter_values, state_values,
56-
getu
56+
getu, setu
5757
end
5858

5959
@reexport using SciMLBase, SimpleNonlinearSolve

src/abstract_types.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,10 @@ function Base.getindex(cache::AbstractNonlinearSolveCache, sym)
220220
return getu(cache, sym)(cache)
221221
end
222222

223+
function Base.setindex!(cache::AbstractNonlinearSolveCache, val, sym)
224+
return setu(cache, sym)(cache, val)
225+
end
226+
223227
function Base.show(io::IO, cache::AbstractNonlinearSolveCache)
224228
__show_cache(io, cache, 0)
225229
end

test/downstream/mtk_cache_indexing_tests.jl

Lines changed: 19 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -10,39 +10,24 @@
1010
# Creates an integrator.
1111
nlprob = NonlinearProblem(nlsys, [X => 1.0], [p => 2.0, d => 3.0])
1212

13-
@testset "GeneralizedFirstOrderAlgorithmCache" begin
14-
nint = init(nlprob, NewtonRaphson())
15-
@test nint isa NonlinearSolve.GeneralizedFirstOrderAlgorithmCache
16-
17-
@test nint[X] == 1.0
18-
@test nint[nlsys.X] == 1.0
19-
@test nint[:X] == 1.0
20-
@test nint.ps[p] == 2.0
21-
@test nint.ps[nlsys.p] == 2.0
22-
@test nint.ps[:p] == 2.0
23-
end
24-
25-
@testset "NonlinearSolvePolyAlgorithmCache" begin
26-
nint = init(nlprob, FastShortcutNonlinearPolyalg())
27-
@test nint isa NonlinearSolve.NonlinearSolvePolyAlgorithmCache
28-
29-
@test nint[X] == 1.0
30-
@test nint[nlsys.X] == 1.0
31-
@test nint[:X] == 1.0
32-
@test nint.ps[p] == 2.0
33-
@test nint.ps[nlsys.p] == 2.0
34-
@test nint.ps[:p] == 2.0
35-
end
36-
37-
@testset "NonlinearSolveNoInitCache" begin
38-
nint = init(nlprob, SimpleNewtonRaphson())
39-
@test nint isa NonlinearSolve.NonlinearSolveNoInitCache
40-
41-
@test nint[X] == 1.0
42-
@test nint[nlsys.X] == 1.0
43-
@test nint[:X] == 1.0
44-
@test nint.ps[p] == 2.0
45-
@test nint.ps[nlsys.p] == 2.0
46-
@test nint.ps[:p] == 2.0
13+
@testset "$integtype" for (alg, integtype) in [
14+
(NewtonRaphson(), NonlinearSolve.GeneralizedFirstOrderAlgorithmCache),
15+
(FastShortcutNonlinearPolyalg(), NonlinearSolve.NonlinearSolvePolyAlgorithmCache),
16+
(SimpleNewtonRaphson(), NonlinearSolve.NonlinearSolveNoInitCache),
17+
]
18+
nint = init(nlprob, alg)
19+
@test nint isa integtype
20+
21+
for (i, sym) in enumerate([X, nlsys.X, :X])
22+
# test both getindex and setindex!
23+
nint[sym] = 1.5i
24+
@test nint[sym] == 1.5i
25+
end
26+
27+
for (i, sym) in enumerate([p, nlsys.p, :p])
28+
# test both getindex and setindex!
29+
nint.ps[sym] = 2.5i
30+
@test nint.ps[sym] == 2.5i
31+
end
4732
end
4833
end

0 commit comments

Comments
 (0)