Skip to content

Commit 84a08dc

Browse files
committed
Add step! for polyalgorithms
1 parent bf072d2 commit 84a08dc

File tree

7 files changed

+108
-19
lines changed

7 files changed

+108
-19
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NonlinearSolve"
22
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
33
authors = ["SciML"]
4-
version = "3.6.0"
4+
version = "3.6.1"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/adtypes.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,13 +83,13 @@ AutoSparseForwardDiff
8383
Uses [`PolyesterForwardDiff.jl`](https://github.com/JuliaDiff/PolyesterForwardDiff.jl)
8484
to compute the jacobian. This is essentially parallelized `ForwardDiff.jl`.
8585
86-
- Supports both inplace and out-of-place functions
86+
- Supports both inplace and out-of-place functions
8787
8888
### Keyword Arguments
8989
90-
- `chunksize`: Count of dual numbers that can be propagated simultaneously. Setting
91-
this number to a high value will lead to slowdowns. Use
92-
[`NonlinearSolve.pickchunksize`](@ref) to get a proper value.
90+
- `chunksize`: Count of dual numbers that can be propagated simultaneously. Setting
91+
this number to a high value will lead to slowdowns. Use
92+
[`NonlinearSolve.pickchunksize`](@ref) to get a proper value.
9393
"""
9494
AutoPolyesterForwardDiff
9595

src/core/generic.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ function SciMLBase.step!(cache::AbstractNonlinearSolveCache{iip, timeit},
5151
res = @static_timeit cache.timer "solve" begin
5252
__step!(cache, args...; kwargs...)
5353
end
54-
cache.nsteps += 1
54+
55+
hasfield(typeof(cache), :nsteps) && (cache.nsteps += 1)
5556

5657
if timeit
5758
cache.total_time += time() - time_start

src/default.jl

Lines changed: 59 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -44,26 +44,35 @@ function Base.show(io::IO, alg::NonlinearSolvePolyAlgorithm{pType, N}) where {pT
4444
end
4545
end
4646

47-
@concrete mutable struct NonlinearSolvePolyAlgorithmCache{iip, N} <:
48-
AbstractNonlinearSolveCache{iip, false}
47+
@concrete mutable struct NonlinearSolvePolyAlgorithmCache{iip, N, timeit} <:
48+
AbstractNonlinearSolveCache{iip, timeit}
4949
caches
5050
alg
51+
best::Int
5152
current::Int
53+
nsteps::Int
54+
total_time::Float64
55+
maxtime
56+
retcode::ReturnCode.T
57+
force_stop::Bool
5258
end
5359

5460
function reinit_cache!(cache::NonlinearSolvePolyAlgorithmCache, args...; kwargs...)
5561
foreach(c -> reinit_cache!(c, args...; kwargs...), cache.caches)
5662
cache.current = 1
63+
cache.nsteps = 0
64+
cache.total_time = 0.0
5765
end
5866

5967
for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProblem, :NLLS))
6068
algType = NonlinearSolvePolyAlgorithm{pType}
6169
@eval begin
62-
function SciMLBase.__init(
63-
prob::$probType, alg::$algType{N}, args...; kwargs...) where {N}
64-
return NonlinearSolvePolyAlgorithmCache{isinplace(prob), N}(
65-
map(solver -> SciMLBase.__init(prob, solver, args...; kwargs...), alg.algs),
66-
alg, 1)
70+
function SciMLBase.__init(prob::$probType, alg::$algType{N}, args...;
71+
maxtime = nothing, kwargs...) where {N}
72+
return NonlinearSolvePolyAlgorithmCache{isinplace(prob), N, maxtime !== nothing}(
73+
map(solver -> SciMLBase.__init(prob, solver, args...; maxtime, kwargs...),
74+
alg.algs), alg, -1, 1, 0, 0.0, maxtime,
75+
ReturnCode.Default, false)
6776
end
6877
end
6978
end
@@ -87,9 +96,8 @@ end
8796
stats = $(sol_syms[i]).stats
8897
u = $(sol_syms[i]).u
8998
fu = get_fu($(cache_syms[i]))
90-
return SciMLBase.build_solution(
91-
$(sol_syms[i]).prob, cache.alg, u, fu;
92-
retcode = ReturnCode.Success, stats,
99+
return SciMLBase.build_solution($(sol_syms[i]).prob, cache.alg, u,
100+
fu; retcode = $(sol_syms[i]).retcode, stats,
93101
original = $(sol_syms[i]), trace = $(sol_syms[i]).trace)
94102
end
95103
cache.current = $(i + 1)
@@ -103,12 +111,11 @@ end
103111
end
104112
push!(calls,
105113
quote
106-
retcode = ReturnCode.MaxIters
107-
108114
fus = tuple($(Tuple(resids)...))
109115
minfu, idx = __findmin(cache.caches[1].internalnorm, fus)
110116
stats = cache.caches[idx].stats
111-
u = cache.caches[idx].u
117+
u = get_u(cache.caches[idx])
118+
retcode = cache.caches[idx].retcode
112119

113120
return SciMLBase.build_solution(cache.caches[idx].prob, cache.alg, u, fus[idx];
114121
retcode, stats, cache.caches[idx].trace)
@@ -117,6 +124,45 @@ end
117124
return Expr(:block, calls...)
118125
end
119126

127+
@generated function __step!(
128+
cache::NonlinearSolvePolyAlgorithmCache{iip, N}, args...; kwargs...) where {iip, N}
129+
calls = []
130+
cache_syms = [gensym("cache") for i in 1:N]
131+
for i in 1:N
132+
push!(calls,
133+
quote
134+
$(cache_syms[i]) = cache.caches[$(i)]
135+
if $(i) == cache.current
136+
__step!($(cache_syms[i]), args...; kwargs...)
137+
if !not_terminated($(cache_syms[i]))
138+
if SciMLBase.successful_retcode($(cache_syms[i]).retcode)
139+
cache.best = $(i)
140+
cache.force_stop = true
141+
cache.retcode = $(cache_syms[i]).retcode
142+
else
143+
cache.current = $(i + 1)
144+
end
145+
end
146+
return
147+
end
148+
end)
149+
end
150+
151+
push!(calls,
152+
quote
153+
if !(1 cache.current length(cache.caches))
154+
minfu, idx = __findmin(first(cache.caches).internalnorm, cache.caches)
155+
cache.best = idx
156+
cache.retcode = cache.caches[cache.best].retcode
157+
cache.force_stop = true
158+
return
159+
end
160+
end
161+
)
162+
163+
return Expr(:block, calls...)
164+
end
165+
120166
for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProblem, :NLLS))
121167
algType = NonlinearSolvePolyAlgorithm{pType}
122168
@eval begin

src/utils.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,9 @@ LazyArrays.applied_axes(::typeof(__zero), x) = axes(x)
9494
@inline __is_complex(::Type{Complex}) = true
9595
@inline __is_complex(::Type{T}) where {T} = false
9696

97+
function __findmin_caches(f, caches)
98+
return __findmin(f get_fu, caches)
99+
end
97100
function __findmin(f, x)
98101
return findmin(x) do xᵢ
99102
fx = f(xᵢ)

test/misc/polyalg_tests.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,31 @@
2828
cache = init(probN, custom_polyalg; abstol = 1e-9)
2929
solver = solve!(cache)
3030
@test SciMLBase.successful_retcode(solver)
31+
32+
# Test the step interface
33+
cache = init(probN; abstol = 1e-9)
34+
for i in 1:10000
35+
step!(cache)
36+
cache.force_stop && break
37+
end
38+
@test SciMLBase.successful_retcode(cache.retcode)
39+
cache = init(probN, RobustMultiNewton(); abstol = 1e-9)
40+
for i in 1:10000
41+
step!(cache)
42+
cache.force_stop && break
43+
end
44+
@test SciMLBase.successful_retcode(cache.retcode)
45+
cache = init(probN, FastShortcutNonlinearPolyalg(); abstol = 1e-9)
46+
for i in 1:10000
47+
step!(cache)
48+
cache.force_stop && break
49+
end
50+
@test SciMLBase.successful_retcode(cache.retcode)
51+
cache = init(probN, custom_polyalg; abstol = 1e-9)
52+
for i in 1:10000
53+
step!(cache)
54+
cache.force_stop && break
55+
end
3156
end
3257

3358
@testitem "Testing #153 Singular Exception" begin

testing.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
using NonlinearSolve
2+
3+
f(u, p) = u .* u .- 2
4+
5+
u0 = [1.0, 1.0]
6+
7+
prob = NonlinearProblem(f, u0)
8+
9+
nlcache = init(prob);
10+
11+
for i in 1:10
12+
step!(nlcache)
13+
@show nlcache.retcode
14+
end

0 commit comments

Comments
 (0)