Skip to content

Commit 0d96f7f

Browse files
committed
refactor: move LinearSolve wrapper into NonlinearSolveBase
1 parent ecded14 commit 0d96f7f

19 files changed

+360
-292
lines changed

lib/NonlinearSolveBase/Project.toml

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
name = "NonlinearSolveBase"
22
uuid = "be0214bd-f91f-a760-ac4e-3421ce2b2da0"
33
authors = ["Avik Pal <[email protected]> and contributors"]
4-
version = "1.0.0"
4+
version = "1.1.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
8+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
89
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
910
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
1011
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
@@ -15,22 +16,27 @@ FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
1516
FunctionProperties = "f62d2435-5019-4c03-9749-2d4c77af0cbc"
1617
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1718
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
19+
MaybeInplace = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb"
1820
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
1921
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
22+
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
2023
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
2124

2225
[weakdeps]
2326
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
2427
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
28+
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
2529
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2630

2731
[extensions]
2832
NonlinearSolveBaseDiffEqBaseExt = "DiffEqBase"
2933
NonlinearSolveBaseForwardDiffExt = "ForwardDiff"
34+
NonlinearSolveBaseLinearSolveExt = "LinearSolve"
3035
NonlinearSolveBaseSparseArraysExt = "SparseArrays"
3136

3237
[compat]
3338
ADTypes = "1.9"
39+
Adapt = "4.1.0"
3440
Aqua = "0.8.7"
3541
ArrayInterface = "7.9"
3642
CommonSolve = "0.2.4"
@@ -45,9 +51,12 @@ ForwardDiff = "0.10.36"
4551
FunctionProperties = "0.1.2"
4652
InteractiveUtils = "<0.0.1, 1"
4753
LinearAlgebra = "1.10"
54+
LinearSolve = "2.36.1"
4855
Markdown = "1.10"
56+
MaybeInplace = "0.1.4"
4957
RecursiveArrayTools = "3"
5058
SciMLBase = "2.50"
59+
SciMLOperators = "0.3.10"
5160
SparseArrays = "1.10"
5261
StaticArraysCore = "1.4"
5362
Test = "1.10"
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
module NonlinearSolveBaseLinearSolveExt
2+
3+
using ArrayInterface: ArrayInterface
4+
using CommonSolve: CommonSolve, init, solve!
5+
using LinearAlgebra: ColumnNorm
6+
using LinearSolve: LinearSolve, QRFactorization
7+
using NonlinearSolveBase: NonlinearSolveBase, LinearSolveJLCache, LinearSolveResult, Utils
8+
using SciMLBase: ReturnCode, LinearProblem
9+
10+
function (cache::LinearSolveJLCache)(;
11+
A = nothing, b = nothing, linu = nothing, du = nothing, p = nothing,
12+
cachedata = nothing, reuse_A_if_factorization = false, verbose = true, kwargs...)
13+
cache.stats.nsolve += 1
14+
15+
update_A!(cache, A, reuse_A_if_factorization)
16+
b !== nothing && setproperty!(cache.lincache, :b, b)
17+
linu !== nothing && NonlinearSolveBase.set_lincache_u!(cache, linu)
18+
19+
Plprev = cache.lincache.Pl
20+
Prprev = cache.lincache.Pr
21+
22+
if cache.precs === nothing
23+
Pl, Pr = nothing, nothing
24+
else
25+
Pl, Pr = cache.precs(cache.lincache.A, du, linu, p, nothing,
26+
A !== nothing, Plprev, Prprev, cachedata)
27+
end
28+
29+
if Pl !== nothing || Pr !== nothing
30+
Pl, Pr = NonlinearSolveBase.wrap_preconditioners(Pl, Pr, linu)
31+
cache.lincache.Pl = Pl
32+
cache.lincache.Pr = Pr
33+
end
34+
35+
linres = solve!(cache.lincache)
36+
cache.lincache = linres.cache
37+
# Unfortunately LinearSolve.jl doesn't have the most uniform ReturnCode handling
38+
if linres.retcode === ReturnCode.Failure
39+
structured_mat = ArrayInterface.isstructured(cache.lincache.A)
40+
is_gpuarray = ArrayInterface.device(cache.lincache.A) isa ArrayInterface.GPU
41+
42+
if !(cache.linsolve isa QRFactorization{ColumnNorm}) && !is_gpuarray &&
43+
!structured_mat
44+
if verbose
45+
@warn "Potential Rank Deficient Matrix Detected. Attempting to solve using \
46+
Pivoted QR Factorization."
47+
end
48+
@assert (A !== nothing)&&(b !== nothing) "This case is not yet supported. \
49+
Please open an issue at \
50+
https://github.com/SciML/NonlinearSolve.jl"
51+
if cache.additional_lincache === nothing # First time
52+
linprob = LinearProblem(A, b; u0 = linres.u)
53+
cache.additional_lincache = init(
54+
linprob, QRFactorization(ColumnNorm()); alias_u0 = false,
55+
alias_A = false, alias_b = false, cache.lincache.Pl, cache.lincache.Pr)
56+
else
57+
cache.additional_lincache.A = A
58+
cache.additional_lincache.b = b
59+
cache.additional_lincache.Pl = cache.lincache.Pl
60+
cache.additional_lincache.Pr = cache.lincache.Pr
61+
end
62+
linres = solve!(cache.additional_lincache)
63+
cache.additional_lincache = linres.cache
64+
linres.retcode === ReturnCode.Failure &&
65+
return LinearSolveResult(; linres.u, success = false)
66+
return LinearSolveResult(; linres.u)
67+
elseif !(cache.linsolve isa QRFactorization{ColumnNorm})
68+
if verbose
69+
if structured_mat || is_gpuarray
70+
mat_desc = structured_mat ? "Structured" : "GPU"
71+
@warn "Potential Rank Deficient Matrix Detected. But Matrix is \
72+
$(mat_desc). Currently, we don't attempt to solve Rank Deficient \
73+
$(mat_desc) Matrices. Please open an issue at \
74+
https://github.com/SciML/NonlinearSolve.jl"
75+
end
76+
end
77+
end
78+
return LinearSolveResult(; linres.u, success = false)
79+
end
80+
81+
return LinearSolveResult(; linres.u)
82+
end
83+
84+
NonlinearSolveBase.needs_square_A(linsolve, ::Any) = LinearSolve.needs_square_A(linsolve)
85+
86+
update_A!(cache::LinearSolveJLCache, ::Nothing, reuse) = cache
87+
function update_A!(cache::LinearSolveJLCache, A, reuse)
88+
return update_A!(cache, Utils.safe_getproperty(cache.linsolve, Val(:alg)), A, reuse)
89+
end
90+
91+
function update_A!(cache::LinearSolveJLCache, alg, A, reuse)
92+
# Not a Factorization Algorithm so don't update `nfactors`
93+
set_lincache_A!(cache.lincache, A)
94+
return cache
95+
end
96+
function update_A!(cache::LinearSolveJLCache, ::LinearSolve.AbstractFactorization, A, reuse)
97+
reuse && return cache
98+
set_lincache_A!(cache.lincache, A)
99+
cache.stats.nfactors += 1
100+
return cache
101+
end
102+
function update_A!(
103+
cache::LinearSolveJLCache, alg::LinearSolve.DefaultLinearSolver, A, reuse)
104+
if alg ==
105+
LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.KrylovJL_GMRES)
106+
# Force a reset of the cache. This is not properly handled in LinearSolve.jl
107+
set_lincache_A!(cache.lincache, A)
108+
return cache
109+
end
110+
reuse && return cache
111+
set_lincache_A!(cache.lincache, A)
112+
cache.stats.nfactors += 1
113+
return cache
114+
end
115+
116+
function set_lincache_A!(lincache, new_A)
117+
if !LinearSolve.default_alias_A(lincache.alg, new_A, lincache.b) &&
118+
ArrayInterface.can_setindex(lincache.A)
119+
copyto!(lincache.A, new_A)
120+
end
121+
lincache.A = new_A # important!! triggers special code in `setproperty!`
122+
end
123+
124+
end

lib/NonlinearSolveBase/src/NonlinearSolveBase.jl

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,41 @@
11
module NonlinearSolveBase
22

33
using ADTypes: ADTypes, AbstractADType
4+
using Adapt: WrappedArray
45
using ArrayInterface: ArrayInterface
5-
using CommonSolve: CommonSolve
6+
using CommonSolve: CommonSolve, init
67
using Compat: @compat
78
using ConcreteStructs: @concrete
89
using DifferentiationInterface: DifferentiationInterface
910
using EnzymeCore: EnzymeCore
1011
using FastClosures: @closure
1112
using FunctionProperties: hasbranching
12-
using LinearAlgebra: norm
13+
using LinearAlgebra: Diagonal, norm, ldiv!
1314
using Markdown: @doc_str
15+
using MaybeInplace: @bb
1416
using RecursiveArrayTools: AbstractVectorOfArray, ArrayPartition
1517
using SciMLBase: SciMLBase, ReturnCode, AbstractODEIntegrator, AbstractNonlinearProblem,
16-
NonlinearProblem, NonlinearLeastSquaresProblem, AbstractNonlinearFunction,
17-
@add_kwonly, StandardNonlinearProblem, NullParameters, isinplace,
18-
warn_paramtype
19-
using StaticArraysCore: StaticArray
18+
AbstractNonlinearAlgorithm, AbstractNonlinearFunction,
19+
NonlinearProblem, NonlinearLeastSquaresProblem, StandardNonlinearProblem,
20+
NullParameters, NLStats, LinearProblem, isinplace, warn_paramtype,
21+
@add_kwonly
22+
using SciMLOperators: AbstractSciMLOperator, IdentityOperator
23+
using StaticArraysCore: StaticArray, SMatrix, SArray, MArray
2024

2125
const DI = DifferentiationInterface
2226

2327
include("public.jl")
2428
include("utils.jl")
2529

30+
include("abstract_types.jl")
31+
2632
include("immutable_problem.jl")
2733
include("common_defaults.jl")
2834
include("termination_conditions.jl")
2935

3036
include("autodiff.jl")
37+
include("jacobian.jl")
38+
include("linear_solve.jl")
3139

3240
# Unexported Public API
3341
@compat(public, (L2_NORM, Linf_NORM, NAN_CHECK, UNITLESS_ABS2, get_tolerance))
@@ -36,6 +44,10 @@ include("autodiff.jl")
3644
(select_forward_mode_autodiff, select_reverse_mode_autodiff,
3745
select_jacobian_autodiff))
3846

47+
# public for NonlinearSolve.jl to use
48+
@compat(public, (InternalAPI, supports_line_search, supports_trust_region, set_du!))
49+
@compat(public, (construct_linear_solver, needs_square_A))
50+
3951
export RelTerminationMode, AbsTerminationMode,
4052
NormTerminationMode, RelNormTerminationMode, AbsNormTerminationMode,
4153
RelNormSafeTerminationMode, AbsNormSafeTerminationMode,
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
module InternalAPI
2+
3+
function init end
4+
function solve! end
5+
6+
end
7+
8+
abstract type AbstractDescentDirection end
9+
10+
supports_line_search(::AbstractDescentDirection) = false
11+
supports_trust_region(::AbstractDescentDirection) = false
12+
13+
function get_linear_solver(alg::AbstractDescentDirection)
14+
return Utils.safe_getproperty(alg, Val(:linsolve))
15+
end
16+
17+
abstract type AbstractDescentCache end
18+
19+
SciMLBase.get_du(cache::AbstractDescentCache) = cache.δu
20+
SciMLBase.get_du(cache::AbstractDescentCache, ::Val{1}) = SciMLBase.get_du(cache)
21+
SciMLBase.get_du(cache::AbstractDescentCache, ::Val{N}) where {N} = cache.δus[N - 1]
22+
set_du!(cache::AbstractDescentCache, δu) = (cache.δu = δu)
23+
set_du!(cache::AbstractDescentCache, δu, ::Val{1}) = set_du!(cache, δu)
24+
set_du!(cache::AbstractDescentCache, δu, ::Val{N}) where {N} = (cache.δus[N - 1] = δu)
25+
26+
function last_step_accepted(cache::AbstractDescentCache)
27+
hasfield(typeof(cache), :last_step_accepted) && return cache.last_step_accepted
28+
return true
29+
end
30+
31+
abstract type AbstractNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end
32+
33+
get_name(alg::AbstractNonlinearSolveAlgorithm) = Utils.safe_getproperty(alg, Val(:name))
34+
35+
function concrete_jac(alg::AbstractNonlinearSolveAlgorithm)
36+
return concrete_jac(Utils.safe_getproperty(alg, Val(:concrete_jac)))
37+
end
38+
concrete_jac(::Missing) = missing
39+
concrete_jac(v::Bool) = v
40+
concrete_jac(::Val{false}) = false
41+
concrete_jac(::Val{true}) = true
42+
43+
abstract type AbstractNonlinearSolveCache end
44+
45+
abstract type AbstractLinearSolverCache end
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+

0 commit comments

Comments
 (0)