Skip to content

Commit b44112d

Browse files
committed
refactor: Don't use duplicate solve
1 parent 61e97a8 commit b44112d

File tree

1 file changed

+0
-158
lines changed

1 file changed

+0
-158
lines changed

lib/NonlinearSolveBase/src/polyalg.jl

Lines changed: 0 additions & 158 deletions
Original file line numberDiff line numberDiff line change
@@ -121,78 +121,6 @@ function SciMLBase.__init(
121121
)
122122
end
123123

124-
@generated function CommonSolve.solve!(cache::NonlinearSolvePolyAlgorithmCache{Val{N}}) where {N}
125-
calls = [quote
126-
1 cache.current $(N) || error("Current choices shouldn't get here!")
127-
end]
128-
129-
cache_syms = [gensym("cache") for i in 1:N]
130-
sol_syms = [gensym("sol") for i in 1:N]
131-
u_result_syms = [gensym("u_result") for i in 1:N]
132-
133-
for i in 1:N
134-
push!(calls,
135-
quote
136-
$(cache_syms[i]) = cache.caches[$(i)]
137-
if $(i) == cache.current
138-
cache.alias_u0 && copyto!(cache.u0_aliased, cache.u0)
139-
$(sol_syms[i]) = CommonSolve.solve!($(cache_syms[i]))
140-
if SciMLBase.successful_retcode($(sol_syms[i]))
141-
stats = $(sol_syms[i]).stats
142-
if cache.alias_u0
143-
copyto!(cache.u0, $(sol_syms[i]).u)
144-
$(u_result_syms[i]) = cache.u0
145-
else
146-
$(u_result_syms[i]) = $(sol_syms[i]).u
147-
end
148-
fu = NonlinearSolveBase.get_fu($(cache_syms[i]))
149-
return build_solution_less_specialize(
150-
cache.prob, cache.alg, $(u_result_syms[i]), fu;
151-
retcode = $(sol_syms[i]).retcode, stats,
152-
original = $(sol_syms[i]), trace = $(sol_syms[i]).trace
153-
)
154-
elseif cache.alias_u0
155-
# For safety we need to maintain a copy of the solution
156-
$(u_result_syms[i]) = copy($(sol_syms[i]).u)
157-
end
158-
cache.current = $(i + 1)
159-
end
160-
end)
161-
end
162-
163-
resids = map(Base.Fix2(Symbol, :resid), cache_syms)
164-
for (sym, resid) in zip(cache_syms, resids)
165-
push!(calls, :($(resid) = @isdefined($(sym)) ? $(sym).resid : nothing))
166-
end
167-
push!(calls, quote
168-
fus = tuple($(Tuple(resids)...))
169-
minfu, idx = findmin_caches(cache.prob, fus)
170-
end)
171-
for i in 1:N
172-
push!(calls,
173-
quote
174-
if idx == $(i)
175-
u = cache.alias_u0 ? $(u_result_syms[i]) :
176-
NonlinearSolveBase.get_u(cache.caches[$(i)])
177-
end
178-
end)
179-
end
180-
push!(calls,
181-
quote
182-
retcode = cache.caches[idx].retcode
183-
if cache.alias_u0
184-
copyto!(cache.u0, u)
185-
u = cache.u0
186-
end
187-
return build_solution_less_specialize(
188-
cache.prob, cache.alg, u, fus[idx];
189-
retcode, cache.stats, cache.caches[idx].trace
190-
)
191-
end)
192-
193-
return Expr(:block, calls...)
194-
end
195-
196124
@generated function InternalAPI.step!(
197125
cache::NonlinearSolvePolyAlgorithmCache{Val{N}}, args...; kwargs...
198126
) where {N}
@@ -232,92 +160,6 @@ end
232160
return Expr(:block, calls...)
233161
end
234162

235-
@generated function SciMLBase.__solve(
236-
prob::AbstractNonlinearProblem, alg::NonlinearSolvePolyAlgorithm{Val{N}}, args...;
237-
stats = NLStats(0, 0, 0, 0, 0), alias_u0 = false, verbose = true, kwargs...
238-
) where {N}
239-
sol_syms = [gensym("sol") for _ in 1:N]
240-
prob_syms = [gensym("prob") for _ in 1:N]
241-
u_result_syms = [gensym("u_result") for _ in 1:N]
242-
calls = [quote
243-
current = alg.start_index
244-
if alias_u0 && !ArrayInterface.ismutable(prob.u0)
245-
verbose && @warn "`alias_u0` has been set to `true`, but `u0` is \
246-
immutable (checked using `ArrayInterface.ismutable`)."
247-
alias_u0 = false # If immutable don't care about aliasing
248-
end
249-
u0 = prob.u0
250-
u0_aliased = alias_u0 ? zero(u0) : u0
251-
end]
252-
for i in 1:N
253-
cur_sol = sol_syms[i]
254-
push!(calls,
255-
quote
256-
if current == $(i)
257-
if alias_u0
258-
copyto!(u0_aliased, u0)
259-
$(prob_syms[i]) = SciMLBase.remake(prob; u0 = u0_aliased)
260-
else
261-
$(prob_syms[i]) = prob
262-
end
263-
$(cur_sol) = SciMLBase.__solve(
264-
$(prob_syms[i]), alg.algs[$(i)], args...;
265-
stats, alias_u0, verbose, kwargs...
266-
)
267-
if SciMLBase.successful_retcode($(cur_sol))
268-
if alias_u0
269-
copyto!(u0, $(cur_sol).u)
270-
$(u_result_syms[i]) = u0
271-
else
272-
$(u_result_syms[i]) = $(cur_sol).u
273-
end
274-
return build_solution_less_specialize(
275-
prob, alg, $(u_result_syms[i]), $(cur_sol).resid;
276-
$(cur_sol).retcode, $(cur_sol).stats,
277-
$(cur_sol).trace, original = $(cur_sol)
278-
)
279-
elseif alias_u0
280-
# For safety we need to maintain a copy of the solution
281-
$(u_result_syms[i]) = copy($(cur_sol).u)
282-
end
283-
current = $(i + 1)
284-
end
285-
end)
286-
end
287-
288-
resids = map(Base.Fix2(Symbol, :resid), sol_syms)
289-
for (sym, resid) in zip(sol_syms, resids)
290-
push!(calls, :($(resid) = @isdefined($(sym)) ? $(sym).resid : nothing))
291-
end
292-
293-
push!(calls, quote
294-
resids = tuple($(Tuple(resids)...))
295-
minfu, idx = findmin_resids(prob, resids)
296-
end)
297-
298-
for i in 1:N
299-
push!(calls,
300-
quote
301-
if idx == $(i)
302-
if alias_u0
303-
copyto!(u0, $(u_result_syms[i]))
304-
$(u_result_syms[i]) = u0
305-
else
306-
$(u_result_syms[i]) = $(sol_syms[i]).u
307-
end
308-
return build_solution_less_specialize(
309-
prob, alg, $(u_result_syms[i]), $(sol_syms[i]).resid;
310-
$(sol_syms[i]).retcode, $(sol_syms[i]).stats,
311-
$(sol_syms[i]).trace, original = $(sol_syms[i])
312-
)
313-
end
314-
end)
315-
end
316-
push!(calls, :(error("Current choices shouldn't get here!")))
317-
318-
return Expr(:block, calls...)
319-
end
320-
321163
# Original is often determined on runtime information especially for PolyAlgorithms so it
322164
# is best to never specialize on that
323165
function build_solution_less_specialize(

0 commit comments

Comments
 (0)