Skip to content

Commit c7409c8

Browse files
committed
Add pretty printing for the cache
1 parent 84a08dc commit c7409c8

File tree

3 files changed

+42
-10
lines changed

3 files changed

+42
-10
lines changed

src/abstract_types.jl

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,8 +173,8 @@ not applicable. Else a boolean value is returned.
173173
"""
174174
concrete_jac(::AbstractNonlinearSolveAlgorithm) = nothing
175175

176-
function Base.show(io::IO, alg::AbstractNonlinearSolveAlgorithm{name}) where {name}
177-
__show_algorithm(io, alg, name, 0)
176+
function Base.show(io::IO, alg::AbstractNonlinearSolveAlgorithm)
177+
__show_algorithm(io, alg, get_name(alg), 0)
178178
end
179179

180180
get_name(::AbstractNonlinearSolveAlgorithm{name}) where {name} = name
@@ -207,6 +207,24 @@ Abstract Type for all NonlinearSolve.jl Caches.
207207
"""
208208
abstract type AbstractNonlinearSolveCache{iip, timeit} end
209209

210+
function Base.show(io::IO, cache::AbstractNonlinearSolveCache)
211+
__show_cache(io, cache, 0)
212+
end
213+
214+
function __show_cache(io::IO, cache::AbstractNonlinearSolveCache, indent = 0)
215+
println(io, "$(nameof(typeof(cache)))(")
216+
__show_algorithm(io, cache.alg,
217+
(" "^(indent + 4)) * "alg = " * string(get_name(cache.alg)), indent +
218+
4)
219+
println(io, ",")
220+
println(io, (" "^(indent + 4)) * "u = ", get_u(cache), ",")
221+
println(io, (" "^(indent + 4)) * "residual = ", get_fu(cache), ",")
222+
println(io, (" "^(indent + 4)) * "inf-norm(residual) = ", norm(get_fu(cache), Inf), ",")
223+
println(io, " "^(indent + 4) * "nsteps = ", get_nsteps(cache), ",")
224+
println(io, " "^(indent + 4) * "retcode = ", cache.retcode)
225+
print(io, " "^(indent) * ")")
226+
end
227+
210228
SciMLBase.isinplace(::AbstractNonlinearSolveCache{iip}) where {iip} = iip
211229

212230
get_fu(cache::AbstractNonlinearSolveCache) = cache.fu

src/core/generic.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,9 @@ Performs one step of the nonlinear solver.
4545
respectively. For algorithms that don't use jacobian information, this keyword is
4646
ignored with a one-time warning.
4747
"""
48-
function SciMLBase.step!(cache::AbstractNonlinearSolveCache{iip, timeit},
49-
args...; kwargs...) where {iip, timeit}
48+
function SciMLBase.step!(cache::AbstractNonlinearSolveCache{iip, timeit}, args...;
49+
kwargs...) where {iip, timeit}
50+
not_terminated(cache) || return
5051
timeit && (time_start = time())
5152
res = @static_timeit cache.timer "solve" begin
5253
__step!(cache, args...; kwargs...)

src/default.jl

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,19 @@ end
5555
maxtime
5656
retcode::ReturnCode.T
5757
force_stop::Bool
58+
maxiters::Int
59+
end
60+
61+
function Base.show(io::IO,
62+
cache::NonlinearSolvePolyAlgorithmCache{pType, N}) where {pType, N}
63+
problem_kind = ifelse(pType == :NLS, "NonlinearProblem", "NonlinearLeastSquaresProblem")
64+
println(io, "NonlinearSolvePolyAlgorithmCache for $(problem_kind) with $(N) algorithms")
65+
best_alg = ifelse(cache.best == -1, "nothing", cache.best)
66+
println(io, "Best algorithm: $(best_alg)")
67+
println(io, "Current algorithm: $(cache.current)")
68+
println(io, "nsteps: $(cache.nsteps)")
69+
println(io, "retcode: $(cache.retcode)")
70+
__show_cache(io, cache.caches[cache.current], 0)
5871
end
5972

6073
function reinit_cache!(cache::NonlinearSolvePolyAlgorithmCache, args...; kwargs...)
@@ -68,11 +81,11 @@ for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProb
6881
algType = NonlinearSolvePolyAlgorithm{pType}
6982
@eval begin
7083
function SciMLBase.__init(prob::$probType, alg::$algType{N}, args...;
71-
maxtime = nothing, kwargs...) where {N}
84+
maxtime = nothing, maxiters = 1000, kwargs...) where {N}
7285
return NonlinearSolvePolyAlgorithmCache{isinplace(prob), N, maxtime !== nothing}(
7386
map(solver -> SciMLBase.__init(prob, solver, args...; maxtime, kwargs...),
7487
alg.algs), alg, -1, 1, 0, 0.0, maxtime,
75-
ReturnCode.Default, false)
88+
ReturnCode.Default, false, maxiters)
7689
end
7790
end
7891
end
@@ -124,8 +137,8 @@ end
124137
return Expr(:block, calls...)
125138
end
126139

127-
@generated function __step!(
128-
cache::NonlinearSolvePolyAlgorithmCache{iip, N}, args...; kwargs...) where {iip, N}
140+
@generated function __step!(cache::NonlinearSolvePolyAlgorithmCache{iip, N}, args...;
141+
kwargs...) where {iip, N}
129142
calls = []
130143
cache_syms = [gensym("cache") for i in 1:N]
131144
for i in 1:N
@@ -134,6 +147,7 @@ end
134147
$(cache_syms[i]) = cache.caches[$(i)]
135148
if $(i) == cache.current
136149
__step!($(cache_syms[i]), args...; kwargs...)
150+
$(cache_syms[i]).nsteps += 1
137151
if !not_terminated($(cache_syms[i]))
138152
if SciMLBase.successful_retcode($(cache_syms[i]).retcode)
139153
cache.best = $(i)
@@ -157,8 +171,7 @@ end
157171
cache.force_stop = true
158172
return
159173
end
160-
end
161-
)
174+
end)
162175

163176
return Expr(:block, calls...)
164177
end

0 commit comments

Comments
 (0)