|
| 1 | +module NonlinearSolveBaseLinearSolveExt |
| 2 | + |
| 3 | +using ArrayInterface: ArrayInterface |
| 4 | +using CommonSolve: CommonSolve, init, solve! |
| 5 | +using LinearAlgebra: ColumnNorm |
| 6 | +using LinearSolve: LinearSolve, QRFactorization |
| 7 | +using NonlinearSolveBase: NonlinearSolveBase, LinearSolveJLCache, LinearSolveResult, Utils |
| 8 | +using SciMLBase: ReturnCode, LinearProblem |
| 9 | + |
| 10 | +function (cache::LinearSolveJLCache)(; |
| 11 | + A = nothing, b = nothing, linu = nothing, du = nothing, p = nothing, |
| 12 | + cachedata = nothing, reuse_A_if_factorization = false, verbose = true, kwargs...) |
| 13 | + cache.stats.nsolve += 1 |
| 14 | + |
| 15 | + update_A!(cache, A, reuse_A_if_factorization) |
| 16 | + b !== nothing && setproperty!(cache.lincache, :b, b) |
| 17 | + linu !== nothing && NonlinearSolveBase.set_lincache_u!(cache, linu) |
| 18 | + |
| 19 | + Plprev = cache.lincache.Pl |
| 20 | + Prprev = cache.lincache.Pr |
| 21 | + |
| 22 | + if cache.precs === nothing |
| 23 | + Pl, Pr = nothing, nothing |
| 24 | + else |
| 25 | + Pl, Pr = cache.precs(cache.lincache.A, du, linu, p, nothing, |
| 26 | + A !== nothing, Plprev, Prprev, cachedata) |
| 27 | + end |
| 28 | + |
| 29 | + if Pl !== nothing || Pr !== nothing |
| 30 | + Pl, Pr = NonlinearSolveBase.wrap_preconditioners(Pl, Pr, linu) |
| 31 | + cache.lincache.Pl = Pl |
| 32 | + cache.lincache.Pr = Pr |
| 33 | + end |
| 34 | + |
| 35 | + linres = solve!(cache.lincache) |
| 36 | + cache.lincache = linres.cache |
| 37 | + # Unfortunately LinearSolve.jl doesn't have the most uniform ReturnCode handling |
| 38 | + if linres.retcode === ReturnCode.Failure |
| 39 | + structured_mat = ArrayInterface.isstructured(cache.lincache.A) |
| 40 | + is_gpuarray = ArrayInterface.device(cache.lincache.A) isa ArrayInterface.GPU |
| 41 | + |
| 42 | + if !(cache.linsolve isa QRFactorization{ColumnNorm}) && !is_gpuarray && |
| 43 | + !structured_mat |
| 44 | + if verbose |
| 45 | + @warn "Potential Rank Deficient Matrix Detected. Attempting to solve using \ |
| 46 | + Pivoted QR Factorization." |
| 47 | + end |
| 48 | + @assert (A !== nothing)&&(b !== nothing) "This case is not yet supported. \ |
| 49 | + Please open an issue at \ |
| 50 | + https://github.com/SciML/NonlinearSolve.jl" |
| 51 | + if cache.additional_lincache === nothing # First time |
| 52 | + linprob = LinearProblem(A, b; u0 = linres.u) |
| 53 | + cache.additional_lincache = init( |
| 54 | + linprob, QRFactorization(ColumnNorm()); alias_u0 = false, |
| 55 | + alias_A = false, alias_b = false, cache.lincache.Pl, cache.lincache.Pr) |
| 56 | + else |
| 57 | + cache.additional_lincache.A = A |
| 58 | + cache.additional_lincache.b = b |
| 59 | + cache.additional_lincache.Pl = cache.lincache.Pl |
| 60 | + cache.additional_lincache.Pr = cache.lincache.Pr |
| 61 | + end |
| 62 | + linres = solve!(cache.additional_lincache) |
| 63 | + cache.additional_lincache = linres.cache |
| 64 | + linres.retcode === ReturnCode.Failure && |
| 65 | + return LinearSolveResult(; linres.u, success = false) |
| 66 | + return LinearSolveResult(; linres.u) |
| 67 | + elseif !(cache.linsolve isa QRFactorization{ColumnNorm}) |
| 68 | + if verbose |
| 69 | + if structured_mat || is_gpuarray |
| 70 | + mat_desc = structured_mat ? "Structured" : "GPU" |
| 71 | + @warn "Potential Rank Deficient Matrix Detected. But Matrix is \ |
| 72 | + $(mat_desc). Currently, we don't attempt to solve Rank Deficient \ |
| 73 | + $(mat_desc) Matrices. Please open an issue at \ |
| 74 | + https://github.com/SciML/NonlinearSolve.jl" |
| 75 | + end |
| 76 | + end |
| 77 | + end |
| 78 | + return LinearSolveResult(; linres.u, success = false) |
| 79 | + end |
| 80 | + |
| 81 | + return LinearSolveResult(; linres.u) |
| 82 | +end |
| 83 | + |
| 84 | +NonlinearSolveBase.needs_square_A(linsolve, ::Any) = LinearSolve.needs_square_A(linsolve) |
| 85 | + |
| 86 | +update_A!(cache::LinearSolveJLCache, ::Nothing, reuse) = cache |
| 87 | +function update_A!(cache::LinearSolveJLCache, A, reuse) |
| 88 | + return update_A!(cache, Utils.safe_getproperty(cache.linsolve, Val(:alg)), A, reuse) |
| 89 | +end |
| 90 | + |
| 91 | +function update_A!(cache::LinearSolveJLCache, alg, A, reuse) |
| 92 | + # Not a Factorization Algorithm so don't update `nfactors` |
| 93 | + set_lincache_A!(cache.lincache, A) |
| 94 | + return cache |
| 95 | +end |
| 96 | +function update_A!(cache::LinearSolveJLCache, ::LinearSolve.AbstractFactorization, A, reuse) |
| 97 | + reuse && return cache |
| 98 | + set_lincache_A!(cache.lincache, A) |
| 99 | + cache.stats.nfactors += 1 |
| 100 | + return cache |
| 101 | +end |
| 102 | +function update_A!( |
| 103 | + cache::LinearSolveJLCache, alg::LinearSolve.DefaultLinearSolver, A, reuse) |
| 104 | + if alg == |
| 105 | + LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.KrylovJL_GMRES) |
| 106 | + # Force a reset of the cache. This is not properly handled in LinearSolve.jl |
| 107 | + set_lincache_A!(cache.lincache, A) |
| 108 | + return cache |
| 109 | + end |
| 110 | + reuse && return cache |
| 111 | + set_lincache_A!(cache.lincache, A) |
| 112 | + cache.stats.nfactors += 1 |
| 113 | + return cache |
| 114 | +end |
| 115 | + |
| 116 | +function set_lincache_A!(lincache, new_A) |
| 117 | + if !LinearSolve.default_alias_A(lincache.alg, new_A, lincache.b) && |
| 118 | + ArrayInterface.can_setindex(lincache.A) |
| 119 | + copyto!(lincache.A, new_A) |
| 120 | + end |
| 121 | + lincache.A = new_A # important!! triggers special code in `setproperty!` |
| 122 | +end |
| 123 | + |
| 124 | +end |
0 commit comments