Skip to content

Commit bcb98c5

Browse files
committed
Switch to QR Pivoted on linear solve failure
1 parent f2f5035 commit bcb98c5

File tree

3 files changed

+54
-10
lines changed

3 files changed

+54
-10
lines changed

src/descent/dogleg.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,8 @@ function __internal_solve!(cache::DoglegCache{INV, NF}, J, fu, u, idx::Val{N} =
8181
want to use a Trust Region."
8282
δu = get_du(cache, idx)
8383
T = promote_type(eltype(u), eltype(fu))
84-
δu_newton, _, _ = __internal_solve!(
85-
cache.newton_cache, J, fu, u, idx; skip_solve, kwargs...)
84+
δu_newton = __internal_solve!(
85+
cache.newton_cache, J, fu, u, idx; skip_solve, kwargs...).δu
8686

8787
# Newton's Step within the trust region
8888
if cache.internalnorm(δu_newton) ≤ trust_region
@@ -102,8 +102,8 @@ function __internal_solve!(cache::DoglegCache{INV, NF}, J, fu, u, idx::Val{N} =
102102
@bb cache.δu_cache_mul = JᵀJ × vec(δu_cauchy)
103103
δuJᵀJδu = __dot(δu_cauchy, cache.δu_cache_mul)
104104
else
105-
δu_cauchy, _, _ = __internal_solve!(
106-
cache.cauchy_cache, J, fu, u, idx; skip_solve, kwargs...)
105+
δu_cauchy = __internal_solve!(
106+
cache.cauchy_cache, J, fu, u, idx; skip_solve, kwargs...).δu
107107
J_ = INV ? inv(J) : J
108108
l_grad = cache.internalnorm(δu_cauchy)
109109
@bb cache.JᵀJ_cache = J × vec(δu_cauchy) # TODO: Rename

src/descent/geodesic_acceleration.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,8 @@ function __internal_solve!(cache::GeodesicAccelerationCache, J, fu, u, idx::Val{
106106
skip_solve::Bool = false, kwargs...) where {N}
107107
a, v, δu = get_acceleration(cache, idx), get_velocity(cache, idx), get_du(cache, idx)
108108
skip_solve && return DescentResult(; δu, extras = (; a, v))
109-
v, _, _ = __internal_solve!(
110-
cache.descent_cache, J, fu, u, Val(2N - 1); skip_solve, kwargs...)
109+
v = __internal_solve!(
110+
cache.descent_cache, J, fu, u, Val(2N - 1); skip_solve, kwargs...).δu
111111

112112
@bb @. cache.u_cache = u + cache.h * v
113113
cache.fu_cache = evaluate_f!!(cache.f, cache.fu_cache, cache.u_cache, cache.p)
@@ -116,8 +116,8 @@ function __internal_solve!(cache::GeodesicAccelerationCache, J, fu, u, idx::Val{
116116
Jv = _restructure(cache.fu_cache, cache.Jv)
117117
@bb @. cache.fu_cache = (2 / cache.h) * ((cache.fu_cache - fu) / cache.h - Jv)
118118

119-
a, _, _ = __internal_solve!(cache.descent_cache, J, cache.fu_cache, u, Val(2N);
120-
skip_solve, kwargs..., reuse_A_if_factorization = true)
119+
a = __internal_solve!(cache.descent_cache, J, cache.fu_cache, u, Val(2N);
120+
skip_solve, kwargs..., reuse_A_if_factorization = true).δu
121121

122122
norm_v = cache.internalnorm(v)
123123
norm_a = cache.internalnorm(a)

src/internal/linear_solve.jl

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import LinearSolve: AbstractFactorization, DefaultAlgorithmChoice, DefaultLinearSolver
22

3+
const LinearSolveFailureCode = isdefined(ReturnCode, :InternalLinearSolveFailure) ?
4+
ReturnCode.InternalLinearSolveFailure : ReturnCode.Failure
5+
36
"""
47
LinearSolverCache(alg, linsolve, A, b, u; kwargs...)
58
@@ -23,6 +26,15 @@ handled:
2326
2427
Returns the solution of the system `u` and stores the updated cache in `cache.lincache`.
2528
29+
#### Special Handling for Rank-deficient Matrix `A`
30+
31+
If we detect a failure in the linear solve (mostly due to using an algorithm that doesn't
32+
support rank-deficient matrices), we emit a warning and attempt to solve the problem using
33+
Pivoted QR factorization. This is quite efficient if there are only a few rank-deficient matrices
34+
that originate in the problem. However, if these are quite frequent for the main nonlinear
35+
system, then it is recommended to use a different linear solver that supports rank-deficient
36+
matrices.
37+
2638
#### Keyword Arguments
2739
2840
- `reuse_A_if_factorization`: If `true`, then the factorization of `A` is reused if
@@ -36,6 +48,7 @@ not mutated, we do this by copying over `A` to a preconstructed cache.
3648
@concrete mutable struct LinearSolverCache <: AbstractLinearSolverCache
3749
lincache
3850
linsolve
51+
additional_lincache::Any
3952
A
4053
b
4154
precs
@@ -71,7 +84,7 @@ function LinearSolverCache(alg, linsolve, A, b, u; kwargs...)
7184
(linsolve === nothing && A isa SMatrix) ||
7285
(A isa Diagonal) ||
7386
(linsolve isa typeof(\))
74-
return LinearSolverCache(nothing, nothing, A, b, nothing, 0, 0)
87+
return LinearSolverCache(nothing, nothing, nothing, A, b, nothing, 0, 0)
7588
end
7689
@bb u_ = copy(u_fixed)
7790
linprob = LinearProblem(A, b; u0 = u_, kwargs...)
@@ -89,7 +102,7 @@ function LinearSolverCache(alg, linsolve, A, b, u; kwargs...)
89102
# Unalias here, we will later use these as caches
90103
lincache = init(linprob, linsolve; alias_A = false, alias_b = false, Pl, Pr)
91104

92-
return LinearSolverCache(lincache, linsolve, nothing, nothing, precs, 0, 0)
105+
return LinearSolverCache(lincache, linsolve, nothing, nothing, nothing, precs, 0, 0)
93106
end
94107

95108
# Direct Linear Solve Case without Caching
@@ -108,6 +121,7 @@ function (cache::LinearSolverCache{Nothing})(;
108121
end
109122
return res
110123
end
124+
111125
# Use LinearSolve.jl
112126
function (cache::LinearSolverCache)(;
113127
A = nothing, b = nothing, linu = nothing, du = nothing, p = nothing,
@@ -139,8 +153,38 @@ function (cache::LinearSolverCache)(;
139153
cache.lincache.Pr = Pr
140154
end
141155

156+
# display(A)
157+
142158
linres = solve!(cache.lincache)
159+
# @show cache.lincache.cacheval
160+
# @show LinearAlgebra.issuccess(cache.lincache.cacheval)
143161
cache.lincache = linres.cache
162+
# Unfortunately LinearSolve.jl doesn't have the most uniform ReturnCode handling
163+
if linres.retcode === ReturnCode.Failure
164+
# TODO: We need to guard this somehow because this will surely fail if A is on GPU
165+
# TODO: or some fancy Matrix type
166+
if !(cache.linsolve isa QRFactorization{ColumnNorm})
167+
@warn "Potential Rank Deficient Matrix Detected. Attempting to solve using \
168+
Pivoted QR Factorization."
169+
@assert (A !== nothing)&&(b !== nothing) "This case is not yet supported. \
170+
Please open an issue at \
171+
https://github.com/SciML/NonlinearSolve.jl"
172+
if cache.additional_lincache === nothing # First time
173+
linprob = LinearProblem(A, b; u0 = linres.u)
174+
cache.additional_lincache = init(
175+
linprob, QRFactorization(ColumnNorm()); alias_u0 = false,
176+
alias_A = false, alias_b = false, cache.lincache.Pl, cache.lincache.Pr)
177+
else
178+
cache.additional_lincache.A = A
179+
cache.additional_lincache.b = b
180+
cache.additional_lincache.Pl = cache.lincache.Pl
181+
cache.additional_lincache.Pr = cache.lincache.Pr
182+
end
183+
linres = solve!(cache.additional_lincache)
184+
cache.additional_lincache = linres.cache
185+
return linres.u
186+
end
187+
end
144188

145189
return linres.u
146190
end

0 commit comments

Comments (0)