Skip to content

Commit b0a4ba1

Browse files
feat: implement initialization for polyalg cache
1 parent 0ac88e0 commit b0a4ba1

File tree

1 file changed

+14
-4
lines changed

1 file changed

+14
-4
lines changed

lib/NonlinearSolveBase/src/polyalg.jl

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,14 @@ end
5959
u0
6060
u0_aliased
6161
alias_u0::Bool
62+
63+
initializealg
64+
end
65+
66+
function update_parameter_object!(cache::NonlinearSolvePolyAlgorithmCache, p)
67+
foreach(cache.caches) do subcache
68+
update_parameter_object!(subcache, p)
69+
end
6270
end
6371

6472
function NonlinearSolveBase.get_abstol(cache::NonlinearSolvePolyAlgorithmCache)
@@ -104,7 +112,7 @@ end
104112
function SciMLBase.__init(
105113
prob::AbstractNonlinearProblem, alg::NonlinearSolvePolyAlgorithm, args...;
106114
stats = NLStats(0, 0, 0, 0, 0), maxtime = nothing, maxiters = 1000,
107-
internalnorm = L2_NORM, alias_u0 = false, verbose = true, kwargs...
115+
internalnorm = L2_NORM, alias_u0 = false, verbose = true, initializealg = NonlinearSolveDefaultInit(), kwargs...
108116
)
109117
if alias_u0 && !ArrayInterface.ismutable(prob.u0)
110118
verbose && @warn "`alias_u0` has been set to `true`, but `u0` is \
@@ -116,18 +124,20 @@ function SciMLBase.__init(
116124
u0_aliased = alias_u0 ? copy(u0) : u0
117125
alias_u0 && (prob = SciMLBase.remake(prob; u0 = u0_aliased))
118126

119-
return NonlinearSolvePolyAlgorithmCache(
127+
cache = NonlinearSolvePolyAlgorithmCache(
120128
alg.static_length, prob,
121129
map(alg.algs) do solver
122130
SciMLBase.__init(
123131
prob, solver, args...;
124-
stats, maxtime, internalnorm, alias_u0, verbose, kwargs...
132+
stats, maxtime, internalnorm, alias_u0, verbose, initializealg = SciMLBase.NoInit(), kwargs...
125133
)
126134
end,
127135
alg, -1, alg.start_index, 0, stats, 0.0, maxtime,
128136
ReturnCode.Default, false, maxiters, internalnorm,
129-
u0, u0_aliased, alias_u0
137+
u0, u0_aliased, alias_u0, initializealg
130138
)
139+
initialize_cache!(cache)
140+
return cache
131141
end
132142

133143
@generated function InternalAPI.step!(

0 commit comments

Comments
 (0)