Skip to content

Commit 712ce57

Browse files
committed
feat: nicer generalized printing of structs/results
1 parent 38088d9 commit 712ce57

File tree

9 files changed

+98
-233
lines changed

9 files changed

+98
-233
lines changed

lib/NonlinearSolveBase/src/abstract_types.jl

Lines changed: 40 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,7 @@ end
1010
abstract type AbstractNonlinearSolveBaseAPI end # Mostly used for pretty-printing
1111

1212
function Base.show(io::IO, ::MIME"text/plain", alg::AbstractNonlinearSolveBaseAPI)
13-
main_name = nameof(typeof(alg))
14-
modifiers = String[]
15-
for field in fieldnames(typeof(alg))
16-
val = getfield(alg, field)
17-
Utils.is_default_value(val, field, getfield(alg, field)) && continue
18-
push!(modifiers, "$(field) = $(val)")
19-
end
20-
print(io, "$(main_name)($(join(modifiers, ", ")))")
13+
print(io, Utils.clean_sprint_struct(alg))
2114
return
2215
end
2316

@@ -198,12 +191,9 @@ Abstract Type for all NonlinearSolveBase Algorithms.
198191
199192
- `concrete_jac(alg)`: whether or not the algorithm uses a concrete Jacobian. Defaults
200193
to `nothing`.
201-
- `get_name(alg)`: get the name of the algorithm.
202194
"""
203195
abstract type AbstractNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end
204196

205-
get_name(alg::AbstractNonlinearSolveAlgorithm) = Utils.safe_getproperty(alg, Val(:name))
206-
207197
"""
208198
concrete_jac(alg::AbstractNonlinearSolveAlgorithm)::Bool
209199
@@ -218,6 +208,18 @@ concrete_jac(v::Bool) = v
218208
concrete_jac(::Val{false}) = false
219209
concrete_jac(::Val{true}) = true
220210

211+
function Base.show(io::IO, ::MIME"text/plain", alg::AbstractNonlinearSolveAlgorithm)
212+
print(io, Utils.clean_sprint_struct(alg, 0))
213+
return
214+
end
215+
216+
function show_nonlinearsolve_algorithm(
217+
io::IO, alg::AbstractNonlinearSolveAlgorithm, name, indent::Int = 0
218+
)
219+
print(io, name)
220+
print(io, Utils.clean_sprint_struct(alg, indent))
221+
end
222+
221223
"""
222224
AbstractNonlinearSolveCache
223225
@@ -299,30 +301,34 @@ function Base.setindex!(cache::AbstractNonlinearSolveCache, val, sym)
299301
return SII.setu(cache, sym)(cache, val)
300302
end
301303

302-
# XXX: Implement this
303-
# function Base.show(io::IO, cache::AbstractNonlinearSolveCache)
304-
# __show_cache(io, cache, 0)
305-
# end
306-
307-
# function __show_cache(io::IO, cache::AbstractNonlinearSolveCache, indent = 0)
308-
# println(io, "$(nameof(typeof(cache)))(")
309-
# __show_algorithm(io, cache.alg,
310-
# (" "^(indent + 4)) * "alg = " * string(get_name(cache.alg)), indent + 4)
311-
312-
# ustr = sprint(show, get_u(cache); context = (:compact => true, :limit => true))
313-
# println(io, ",\n" * (" "^(indent + 4)) * "u = $(ustr),")
314-
315-
# residstr = sprint(show, get_fu(cache); context = (:compact => true, :limit => true))
316-
# println(io, (" "^(indent + 4)) * "residual = $(residstr),")
317-
318-
# normstr = sprint(
319-
# show, norm(get_fu(cache), Inf); context = (:compact => true, :limit => true))
320-
# println(io, (" "^(indent + 4)) * "inf-norm(residual) = $(normstr),")
304+
function Base.show(io::IO, ::MIME"text/plain", cache::AbstractNonlinearSolveCache)
305+
return show_nonlinearsolve_cache(io, cache)
306+
end
321307

322-
# println(io, " "^(indent + 4) * "nsteps = ", cache.stats.nsteps, ",")
323-
# println(io, " "^(indent + 4) * "retcode = ", cache.retcode)
324-
# print(io, " "^(indent) * ")")
325-
# end
308+
function show_nonlinearsolve_cache(io::IO, cache::AbstractNonlinearSolveCache, indent = 0)
309+
println(io, "$(nameof(typeof(cache)))(")
310+
show_nonlinearsolve_algorithm(
311+
io,
312+
cache.alg,
313+
(" "^(indent + 4)) * "alg = ",
314+
indent + 4
315+
)
316+
317+
ustr = sprint(show, get_u(cache); context = (:compact => true, :limit => true))
318+
println(io, ",\n" * (" "^(indent + 4)) * "u = $(ustr),")
319+
320+
residstr = sprint(show, get_fu(cache); context = (:compact => true, :limit => true))
321+
println(io, (" "^(indent + 4)) * "residual = $(residstr),")
322+
323+
normstr = sprint(
324+
show, norm(get_fu(cache), Inf); context = (:compact => true, :limit => true)
325+
)
326+
println(io, (" "^(indent + 4)) * "inf-norm(residual) = $(normstr),")
327+
328+
println(io, " "^(indent + 4) * "nsteps = ", cache.stats.nsteps, ",")
329+
println(io, " "^(indent + 4) * "retcode = ", cache.retcode)
330+
print(io, " "^(indent) * ")")
331+
end
326332

327333
"""
328334
AbstractLinearSolverCache

lib/NonlinearSolveBase/src/descent/geodesic_acceleration.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,12 @@ get_linear_solver(alg::GeodesicAcceleration) = get_linear_solver(alg.descent)
4949
last_step_accepted::Bool
5050
end
5151

52-
# XXX: Implement
53-
# function __reinit_internal!(
54-
# cache::GeodesicAccelerationCache, args...; p = cache.p, kwargs...)
55-
# cache.p = p
56-
# cache.last_step_accepted = false
57-
# end
52+
function InternalAPI.reinit!(cache::GeodesicAccelerationCache; p = cache.p, kwargs...)
53+
cache.p = p
54+
cache.last_step_accepted = false
55+
end
5856

57+
# XXX: Implement
5958
# @internal_caches GeodesicAccelerationCache :descent_cache
6059

6160
function get_velocity(cache::GeodesicAccelerationCache)

lib/NonlinearSolveBase/src/utils.jl

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ unwrap_val(::Val{x}) where {x} = unwrap_val(x)
115115

116116
is_default_value(::Any, ::Symbol, ::Nothing) = true
117117
is_default_value(::Any, ::Symbol, ::Missing) = true
118+
is_default_value(::Any, ::Symbol, val::Int) = val == typemax(typeof(val))
118119
is_default_value(::Any, ::Symbol, ::Any) = false
119120

120121
maybe_symmetric(x) = Symmetric(x)
@@ -178,7 +179,6 @@ function condition_number(J::AbstractVector)
178179
end
179180
condition_number(::Any) = -1
180181

181-
# XXX: Move to NonlinearSolveQuasiNewton
182182
# compute `pinv` if `inv` won't work
183183
maybe_pinv!!_workspace(A) = nothing, A
184184

@@ -239,4 +239,47 @@ function make_identity!!(A::AbstractMatrix{T}, α) where {T}
239239
return A
240240
end
241241

242+
function clean_sprint_struct(x)
243+
x isa Symbol && return "$(Meta.quot(x))"
244+
x isa Number && return string(x)
245+
(!Base.isstructtype(typeof(x)) || x isa Val) && return string(x)
246+
247+
modifiers = String[]
248+
name = nameof(typeof(x))
249+
for field in fieldnames(typeof(x))
250+
val = getfield(x, field)
251+
if field === :name
252+
name = val
253+
continue
254+
end
255+
is_default_value(x, field, val) && continue
256+
push!(modifiers, "$(field) = $(clean_sprint_struct(val))")
257+
end
258+
259+
return "$(nameof(typeof(x)))($(join(modifiers, ", ")))"
260+
end
261+
262+
function clean_sprint_struct(x, indent::Int)
263+
x isa Symbol && return "$(Meta.quot(x))"
264+
x isa Number && return string(x)
265+
(!Base.isstructtype(typeof(x)) || x isa Val) && return string(x)
266+
267+
modifiers = String[]
268+
name = nameof(typeof(x))
269+
for field in fieldnames(typeof(x))
270+
val = getfield(x, field)
271+
if field === :name
272+
name = val
273+
continue
274+
end
275+
is_default_value(x, field, val) && continue
276+
push!(modifiers, "$(field) = $(clean_sprint_struct(val, indent + 4))")
277+
end
278+
spacing = " "^indent * " "
279+
spacing_last = " "^indent
280+
281+
length(modifiers) == 0 && return "$(nameof(typeof(x)))()"
282+
return "$(name)(\n$(spacing)$(join(modifiers, ",\n$(spacing)"))\n$(spacing_last))"
283+
end
284+
242285
end

lib/NonlinearSolveQuasiNewton/src/initialization.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -151,12 +151,13 @@ NonlinearSolveBase.jacobian_initialized_preinverted(::BroydenLowRankInitializati
151151

152152
function InternalAPI.init(
153153
prob::AbstractNonlinearProblem, alg::BroydenLowRankInitialization,
154-
solver, f::F, fu, u, p; internalnorm::IN = L2_NORM, kwargs...
154+
solver, f::F, fu, u, p;
155+
internalnorm::IN = L2_NORM, maxiters = 1000, kwargs...
155156
) where {F, IN}
156157
if u isa Number # Use the standard broyden
157158
return InternalAPI.init(
158159
prob, IdentityInitialization(true, FullStructure()),
159-
solver, f, fu, u, p; internalnorm, kwargs...
160+
solver, f, fu, u, p; internalnorm, maxiters, kwargs...
160161
)
161162
end
162163
# Pay to cost of slightly more allocations to prevent type-instability for StaticArrays
@@ -212,7 +213,7 @@ Base.adjoint(op::BroydenLowRankJacobian{<:Real}) = transpose(op)
212213

213214
# Storing the transpose to ensure contiguous memory on splicing
214215
function BroydenLowRankJacobian(
215-
fu::StaticArray, u::StaticArray; alpha = true, threshold::Val = Val(10)
216+
fu::StaticArray, u::StaticArray; alpha = true, threshold::Val = Val(10)
216217
)
217218
T = promote_type(eltype(u), eltype(fu))
218219
U = MArray{Tuple{prod(Size(fu)), Utils.unwrap_val(threshold)}, T}(undef)
@@ -265,8 +266,8 @@ function LinearAlgebra.mul!(y::AbstractVector, x::AbstractVector, J::BroydenLowR
265266
end
266267

267268
function LinearAlgebra.mul!(
268-
J::BroydenLowRankJacobian, u::AbstractArray, vᵀ::LinearAlgebra.AdjOrTransAbsVec,
269-
α::Bool, β::Bool
269+
J::BroydenLowRankJacobian, u::AbstractArray, vᵀ::LinearAlgebra.AdjOrTransAbsVec,
270+
α::Bool, β::Bool
270271
)
271272
@assert α & β
272273
idx_update = mod1(J.idx + 1, size(J.U, 2))

lib/NonlinearSolveQuasiNewton/src/reset_conditions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ function InternalAPI.init(
103103
end
104104

105105
@concrete struct IllConditionedJacobianResetCache <: AbstractResetConditionCache
106-
condition_number_threshold <: Number
106+
condition_number_threshold
107107
end
108108

109109
# NOTE: we don't need a reinit! since we establish the threshold based on the eltype

lib/NonlinearSolveQuasiNewton/src/solve.jl

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -37,22 +37,6 @@ examples include [`Broyden`](@ref)'s Method.
3737
name::Symbol
3838
end
3939

40-
# XXX: Implement
41-
# function __show_algorithm(io::IO, alg::QuasiNewtonAlgorithm, name, indent)
42-
# modifiers = String[]
43-
# __is_present(alg.linesearch) && push!(modifiers, "linesearch = $(alg.linesearch)")
44-
# __is_present(alg.trustregion) && push!(modifiers, "trustregion = $(alg.trustregion)")
45-
# push!(modifiers, "descent = $(alg.descent)")
46-
# push!(modifiers, "update_rule = $(alg.update_rule)")
47-
# push!(modifiers, "reinit_rule = $(alg.reinit_rule)")
48-
# push!(modifiers, "max_resets = $(alg.max_resets)")
49-
# push!(modifiers, "initialization = $(alg.initialization)")
50-
# store_inverse_jacobian(alg.update_rule) && push!(modifiers, "inverse_jacobian = true")
51-
# spacing = " "^indent * " "
52-
# spacing_last = " "^indent
53-
# print(io, "$(name)(\n$(spacing)$(join(modifiers, ",\n$(spacing)"))\n$(spacing_last))")
54-
# end
55-
5640
function QuasiNewtonAlgorithm(;
5741
linesearch = missing, trustregion = missing, descent, update_rule, reinit_rule,
5842
initialization, max_resets::Int = typemax(Int), name::Symbol = :unknown,

lib/NonlinearSolveSpectralMethods/src/solve.jl

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -32,18 +32,6 @@ function GeneralizedDFSane(;
3232
return GeneralizedDFSane(linesearch, sigma_min, sigma_max, sigma_1, name)
3333
end
3434

35-
# XXX: Add
36-
# function __show_algorithm(io::IO, alg::GeneralizedDFSane, name, indent)
37-
# modifiers = String[]
38-
# __is_present(alg.linesearch) && push!(modifiers, "linesearch = $(alg.linesearch)")
39-
# push!(modifiers, "σ_min = $(alg.σ_min)")
40-
# push!(modifiers, "σ_max = $(alg.σ_max)")
41-
# push!(modifiers, "σ_1 = $(alg.σ_1)")
42-
# spacing = " "^indent * " "
43-
# spacing_last = " "^indent
44-
# print(io, "$(name)(\n$(spacing)$(join(modifiers, ",\n$(spacing)"))\n$(spacing_last))")
45-
# end
46-
4735
@concrete mutable struct GeneralizedDFSaneCache <: AbstractNonlinearSolveCache
4836
# Basic Requirements
4937
fu

src/NonlinearSolve.jl

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ using NonlinearSolveBase: NonlinearSolveBase,
3232
reset_timer!, @static_timeit,
3333
init_nonlinearsolve_trace, update_trace!, reset!
3434

35+
using NonlinearSolveQuasiNewton: Broyden, Klement
36+
3537
# XXX: Remove
3638
import NonlinearSolveBase: InternalAPI, concrete_jac, supports_line_search,
3739
supports_trust_region, last_step_accepted, get_linear_solver,
@@ -149,12 +151,6 @@ include("internal/forward_diff.jl") # we need to define after the algorithms
149151

150152
@compile_workload begin
151153
@sync begin
152-
for T in (Float32, Float64), (fn, u0) in nlfuncs
153-
Threads.@spawn NonlinearProblem(fn, T.(u0), T(2))
154-
end
155-
for (fn, u0) in nlfuncs
156-
Threads.@spawn NonlinearLeastSquaresProblem(fn, u0, 2.0)
157-
end
158154
for prob in probs_nls, alg in nls_algs
159155
Threads.@spawn solve(prob, alg; abstol = 1e-2, verbose = false)
160156
end

0 commit comments

Comments
 (0)