Skip to content

Commit 230c07f

Browse files
authored
Merge pull request #378 from SciML/ap/step_polyalg
Add step! for polyalgorithms
2 parents 5292a61 + 91efd8a commit 230c07f

File tree

8 files changed

+148
-19
lines changed

8 files changed

+148
-19
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NonlinearSolve"
22
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
33
authors = ["SciML"]
4-
version = "3.7.0"
4+
version = "3.7.1"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/abstract_types.jl

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,8 +173,8 @@ not applicable. Else a boolean value is returned.
173173
"""
174174
concrete_jac(::AbstractNonlinearSolveAlgorithm) = nothing
175175

176-
function Base.show(io::IO, alg::AbstractNonlinearSolveAlgorithm{name}) where {name}
177-
__show_algorithm(io, alg, name, 0)
176+
function Base.show(io::IO, alg::AbstractNonlinearSolveAlgorithm)
177+
__show_algorithm(io, alg, get_name(alg), 0)
178178
end
179179

180180
get_name(::AbstractNonlinearSolveAlgorithm{name}) where {name} = name
@@ -207,6 +207,23 @@ Abstract Type for all NonlinearSolve.jl Caches.
207207
"""
208208
abstract type AbstractNonlinearSolveCache{iip, timeit} end
209209

210+
function Base.show(io::IO, cache::AbstractNonlinearSolveCache)
211+
__show_cache(io, cache, 0)
212+
end
213+
214+
function __show_cache(io::IO, cache::AbstractNonlinearSolveCache, indent = 0)
215+
println(io, "$(nameof(typeof(cache)))(")
216+
__show_algorithm(io, cache.alg,
217+
(" "^(indent + 4)) * "alg = " * string(get_name(cache.alg)), indent + 4)
218+
println(io, ",")
219+
println(io, (" "^(indent + 4)) * "u = ", get_u(cache), ",")
220+
println(io, (" "^(indent + 4)) * "residual = ", get_fu(cache), ",")
221+
println(io, (" "^(indent + 4)) * "inf-norm(residual) = ", norm(get_fu(cache), Inf), ",")
222+
println(io, " "^(indent + 4) * "nsteps = ", get_nsteps(cache), ",")
223+
println(io, " "^(indent + 4) * "retcode = ", cache.retcode)
224+
print(io, " "^(indent) * ")")
225+
end
226+
210227
SciMLBase.isinplace(::AbstractNonlinearSolveCache{iip}) where {iip} = iip
211228

212229
get_fu(cache::AbstractNonlinearSolveCache) = cache.fu

src/adtypes.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,13 +83,13 @@ AutoSparseForwardDiff
8383
Uses [`PolyesterForwardDiff.jl`](https://github.com/JuliaDiff/PolyesterForwardDiff.jl)
8484
to compute the jacobian. This is essentially parallelized `ForwardDiff.jl`.
8585
86-
- Supports both inplace and out-of-place functions
86+
- Supports both inplace and out-of-place functions
8787
8888
### Keyword Arguments
8989
90-
- `chunksize`: Count of dual numbers that can be propagated simultaneously. Setting
91-
this number to a high value will lead to slowdowns. Use
92-
[`NonlinearSolve.pickchunksize`](@ref) to get a proper value.
90+
- `chunksize`: Count of dual numbers that can be propagated simultaneously. Setting
91+
this number to a high value will lead to slowdowns. Use
92+
[`NonlinearSolve.pickchunksize`](@ref) to get a proper value.
9393
"""
9494
AutoPolyesterForwardDiff
9595

src/core/generic.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,13 @@ Performs one step of the nonlinear solver.
4747
"""
4848
function SciMLBase.step!(cache::AbstractNonlinearSolveCache{iip, timeit},
4949
args...; kwargs...) where {iip, timeit}
50+
not_terminated(cache) || return
5051
timeit && (time_start = time())
5152
res = @static_timeit cache.timer "solve" begin
5253
__step!(cache, args...; kwargs...)
5354
end
54-
cache.nsteps += 1
55+
56+
hasfield(typeof(cache), :nsteps) && (cache.nsteps += 1)
5557

5658
if timeit
5759
cache.total_time += time() - time_start

src/default.jl

Lines changed: 79 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -44,26 +44,56 @@ function Base.show(io::IO, alg::NonlinearSolvePolyAlgorithm{pType, N}) where {pT
4444
end
4545
end
4646

47-
@concrete mutable struct NonlinearSolvePolyAlgorithmCache{iip, N} <:
48-
AbstractNonlinearSolveCache{iip, false}
47+
@concrete mutable struct NonlinearSolvePolyAlgorithmCache{iip, N, timeit} <:
48+
AbstractNonlinearSolveCache{iip, timeit}
4949
caches
5050
alg
51+
best::Int
5152
current::Int
53+
nsteps::Int
54+
total_time::Float64
55+
maxtime
56+
retcode::ReturnCode.T
57+
force_stop::Bool
58+
maxiters::Int
59+
end
60+
61+
function Base.show(
62+
io::IO, cache::NonlinearSolvePolyAlgorithmCache{pType, N}) where {pType, N}
63+
problem_kind = ifelse(pType == :NLS, "NonlinearProblem", "NonlinearLeastSquaresProblem")
64+
println(io, "NonlinearSolvePolyAlgorithmCache for $(problem_kind) with $(N) algorithms")
65+
best_alg = ifelse(cache.best == -1, "nothing", cache.best)
66+
println(io, "Best algorithm: $(best_alg)")
67+
println(io, "Current algorithm: $(cache.current)")
68+
println(io, "nsteps: $(cache.nsteps)")
69+
println(io, "retcode: $(cache.retcode)")
70+
__show_cache(io, cache.caches[cache.current], 0)
5271
end
5372

5473
function reinit_cache!(cache::NonlinearSolvePolyAlgorithmCache, args...; kwargs...)
5574
foreach(c -> reinit_cache!(c, args...; kwargs...), cache.caches)
5675
cache.current = 1
76+
cache.nsteps = 0
77+
cache.total_time = 0.0
5778
end
5879

5980
for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProblem, :NLLS))
6081
algType = NonlinearSolvePolyAlgorithm{pType}
6182
@eval begin
62-
function SciMLBase.__init(
63-
prob::$probType, alg::$algType{N}, args...; kwargs...) where {N}
64-
return NonlinearSolvePolyAlgorithmCache{isinplace(prob), N}(
65-
map(solver -> SciMLBase.__init(prob, solver, args...; kwargs...), alg.algs),
66-
alg, 1)
83+
function SciMLBase.__init(prob::$probType, alg::$algType{N}, args...;
84+
maxtime = nothing, maxiters = 1000, kwargs...) where {N}
85+
return NonlinearSolvePolyAlgorithmCache{isinplace(prob), N, maxtime !== nothing}(
86+
map(solver -> SciMLBase.__init(prob, solver, args...; maxtime, kwargs...),
87+
alg.algs),
88+
alg,
89+
-1,
90+
1,
91+
0,
92+
0.0,
93+
maxtime,
94+
ReturnCode.Default,
95+
false,
96+
maxiters)
6797
end
6898
end
6999
end
@@ -89,7 +119,7 @@ end
89119
fu = get_fu($(cache_syms[i]))
90120
return SciMLBase.build_solution(
91121
$(sol_syms[i]).prob, cache.alg, u, fu;
92-
retcode = ReturnCode.Success, stats,
122+
retcode = $(sol_syms[i]).retcode, stats,
93123
original = $(sol_syms[i]), trace = $(sol_syms[i]).trace)
94124
end
95125
cache.current = $(i + 1)
@@ -103,12 +133,11 @@ end
103133
end
104134
push!(calls,
105135
quote
106-
retcode = ReturnCode.MaxIters
107-
108136
fus = tuple($(Tuple(resids)...))
109137
minfu, idx = __findmin(cache.caches[1].internalnorm, fus)
110138
stats = cache.caches[idx].stats
111-
u = cache.caches[idx].u
139+
u = get_u(cache.caches[idx])
140+
retcode = cache.caches[idx].retcode
112141

113142
return SciMLBase.build_solution(cache.caches[idx].prob, cache.alg, u, fus[idx];
114143
retcode, stats, cache.caches[idx].trace)
@@ -117,6 +146,45 @@ end
117146
return Expr(:block, calls...)
118147
end
119148

149+
@generated function __step!(
150+
cache::NonlinearSolvePolyAlgorithmCache{iip, N}, args...; kwargs...) where {iip, N}
151+
calls = []
152+
cache_syms = [gensym("cache") for i in 1:N]
153+
for i in 1:N
154+
push!(calls,
155+
quote
156+
$(cache_syms[i]) = cache.caches[$(i)]
157+
if $(i) == cache.current
158+
__step!($(cache_syms[i]), args...; kwargs...)
159+
$(cache_syms[i]).nsteps += 1
160+
if !not_terminated($(cache_syms[i]))
161+
if SciMLBase.successful_retcode($(cache_syms[i]).retcode)
162+
cache.best = $(i)
163+
cache.force_stop = true
164+
cache.retcode = $(cache_syms[i]).retcode
165+
else
166+
cache.current = $(i + 1)
167+
end
168+
end
169+
return
170+
end
171+
end)
172+
end
173+
174+
push!(calls,
175+
quote
176+
if !(1 cache.current length(cache.caches))
177+
minfu, idx = __findmin(first(cache.caches).internalnorm, cache.caches)
178+
cache.best = idx
179+
cache.retcode = cache.caches[cache.best].retcode
180+
cache.force_stop = true
181+
return
182+
end
183+
end)
184+
185+
return Expr(:block, calls...)
186+
end
187+
120188
for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProblem, :NLLS))
121189
algType = NonlinearSolvePolyAlgorithm{pType}
122190
@eval begin

src/utils.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,9 @@ LazyArrays.applied_axes(::typeof(__zero), x) = axes(x)
9494
@inline __is_complex(::Type{Complex}) = true
9595
@inline __is_complex(::Type{T}) where {T} = false
9696

97+
function __findmin_caches(f, caches)
98+
return __findmin(f get_fu, caches)
99+
end
97100
function __findmin(f, x)
98101
return findmin(x) do xᵢ
99102
fx = f(xᵢ)

test/misc/polyalg_tests.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,31 @@
2828
cache = init(probN, custom_polyalg; abstol = 1e-9)
2929
solver = solve!(cache)
3030
@test SciMLBase.successful_retcode(solver)
31+
32+
# Test the step interface
33+
cache = init(probN; abstol = 1e-9)
34+
for i in 1:10000
35+
step!(cache)
36+
cache.force_stop && break
37+
end
38+
@test SciMLBase.successful_retcode(cache.retcode)
39+
cache = init(probN, RobustMultiNewton(); abstol = 1e-9)
40+
for i in 1:10000
41+
step!(cache)
42+
cache.force_stop && break
43+
end
44+
@test SciMLBase.successful_retcode(cache.retcode)
45+
cache = init(probN, FastShortcutNonlinearPolyalg(); abstol = 1e-9)
46+
for i in 1:10000
47+
step!(cache)
48+
cache.force_stop && break
49+
end
50+
@test SciMLBase.successful_retcode(cache.retcode)
51+
cache = init(probN, custom_polyalg; abstol = 1e-9)
52+
for i in 1:10000
53+
step!(cache)
54+
cache.force_stop && break
55+
end
3156
end
3257

3358
@testitem "Testing #153 Singular Exception" begin

testing.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
using NonlinearSolve
2+
3+
f(u, p) = u .* u .- 2
4+
5+
u0 = [1.0, 1.0]
6+
7+
prob = NonlinearProblem(f, u0)
8+
9+
nlcache = init(prob);
10+
11+
for i in 1:10
12+
step!(nlcache)
13+
@show nlcache.retcode
14+
end

0 commit comments

Comments
 (0)