Skip to content

Commit 3724fbc

Browse files
committed
Make it generated
1 parent d40b901 commit 3724fbc

File tree

2 files changed

+50
-47
lines changed

2 files changed

+50
-47
lines changed

src/default.jl

Lines changed: 48 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -185,11 +185,14 @@ end
185185
end)
186186
end
187187

188-
resids = map(x -> "$x.resid", sol_syms)
188+
resids = map(x -> Symbol("$(x)_resid"), sol_syms)
189+
for (sym, resid) in zip(sol_syms, resids)
190+
push!(calls, :($(resid) = $(sym).resid))
191+
end
189192

190193
push!(calls,
191194
quote
192-
resids = $(Tuple(resids))
195+
resids = tuple($(Tuple(resids)...))
193196
minfu, idx = findmin(DEFAULT_NORM, resids)
194197
end)
195198

@@ -198,8 +201,7 @@ end
198201
quote
199202
if idx == $i
200203
return SciMLBase.build_solution(prob, alg, $(sol_syms[i]).u,
201-
$(sol_syms[i]).resid; $(sol_syms[i]).retcode, $(sol_syms[i]).stats,
202-
original = $(sol_syms[i]))
204+
$(sol_syms[i]).resid; $(sol_syms[i]).retcode, $(sol_syms[i]).stats)
203205
end
204206
end)
205207
end
@@ -210,53 +212,54 @@ end
210212

211213
## General shared polyalg functions
212214

213-
function perform_step!(cache::Union{RobustMultiNewtonCache,
214-
FastShortcutNonlinearPolyalgCache})
215-
current = cache.current
216-
1 current length(cache.caches) || error("Current choices shouldn't get here!")
217-
218-
current_cache = cache.caches[current]
219-
while not_terminated(current_cache)
220-
perform_step!(current_cache)
215+
@generated function SciMLBase.solve!(cache::Union{RobustMultiNewtonCache{iip, N},
216+
FastShortcutNonlinearPolyalgCache{iip, N}}) where {iip, N}
217+
calls = [
218+
quote
219+
1 cache.current length(cache.caches) ||
220+
error("Current choices shouldn't get here!")
221+
end,
222+
]
223+
224+
cache_syms = [gensym("cache") for i in 1:N]
225+
sol_syms = [gensym("sol") for i in 1:N]
226+
for i in 1:N
227+
push!(calls,
228+
quote
229+
$(cache_syms[i]) = cache.caches[$(i)]
230+
if $(i) == cache.current
231+
$(sol_syms[i]) = SciMLBase.solve!($(cache_syms[i]))
232+
if SciMLBase.successful_retcode($(sol_syms[i]))
233+
stats = $(sol_syms[i]).stats
234+
u = $(sol_syms[i]).u
235+
fu = get_fu($(cache_syms[i]))
236+
return SciMLBase.build_solution($(sol_syms[i]).prob, cache.alg, u,
237+
fu; retcode = ReturnCode.Success, stats,
238+
original = $(sol_syms[i]))
239+
end
240+
cache.current = $(i + 1)
241+
end
242+
end)
221243
end
222244

223-
return nothing
224-
end
225-
226-
function SciMLBase.solve!(cache::Union{RobustMultiNewtonCache,
227-
FastShortcutNonlinearPolyalgCache})
228-
current = cache.current
229-
1 current length(cache.caches) || error("Current choices shouldn't get here!")
230-
231-
current_cache = cache.caches[current]
232-
while current length(cache.caches) # && !all(terminated[current:end])
233-
sol_tmp = solve!(current_cache)
234-
SciMLBase.successful_retcode(sol_tmp) && break
235-
current += 1
236-
cache.current = current
237-
current_cache = cache.caches[current]
245+
resids = map(x -> Symbol("$(x)_resid"), cache_syms)
246+
for (sym, resid) in zip(cache_syms, resids)
247+
push!(calls, :($(resid) = get_fu($(sym))))
238248
end
249+
push!(calls,
250+
quote
251+
retcode = ReturnCode.MaxIters
239252

240-
if current length(cache.caches)
241-
retcode = ReturnCode.Success
242-
243-
stats = cache.caches[current].stats
244-
u = cache.caches[current].u
245-
fu = get_fu(cache.caches[current])
246-
247-
return SciMLBase.build_solution(cache.caches[1].prob, cache.alg, u, fu;
248-
retcode, stats)
249-
else
250-
retcode = ReturnCode.MaxIters
253+
fus = tuple($(Tuple(resids)...))
254+
minfu, idx = findmin(cache.caches[1].internalnorm, fus)
255+
stats = cache.caches[idx].stats
256+
u = cache.caches[idx].u
251257

252-
fus = get_fu.(cache.caches)
253-
minfu, idx = findmin(cache.caches[1].internalnorm, fus)
254-
stats = cache.caches[idx].stats
255-
u = cache.caches[idx].u
258+
return SciMLBase.build_solution(cache.caches[idx].prob, cache.alg, u,
259+
fus[idx]; retcode, stats)
260+
end)
256261

257-
return SciMLBase.build_solution(cache.caches[idx].prob, cache.alg, u, fus[idx];
258-
retcode, stats)
259-
end
262+
return Expr(:block, calls...)
260263
end
261264

262265
function SciMLBase.reinit!(cache::Union{RobustMultiNewtonCache,

src/dfsane.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ end
131131
function perform_step!(cache::DFSaneCache{true})
132132
@unpack alg, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, τₘₐₓ, nₑₓₚ, M = cache
133133

134-
f = iip ? (dx, x) -> cache.prob.f(dx, x, cache.p) : (x) -> cache.prob.f(x, cache.p)
134+
f = (dx, x) -> cache.prob.f(dx, x, cache.p)
135135

136136
T = eltype(cache.uₙ)
137137
n = cache.stats.nsteps
@@ -208,7 +208,7 @@ end
208208
function perform_step!(cache::DFSaneCache{false})
209209
@unpack alg, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, τₘₐₓ, nₑₓₚ, M = cache
210210

211-
f = iip ? (dx, x) -> cache.prob.f(dx, x, cache.p) : (x) -> cache.prob.f(x, cache.p)
211+
f = x -> cache.prob.f(x, cache.p)
212212

213213
T = eltype(cache.uₙ)
214214
n = cache.stats.nsteps

0 commit comments

Comments
 (0)