1
+ module LinearSolveSparseArrays
2
+
3
+ using LinearSolve, LinearAlgebra
4
+ using SparseArrays
5
+ using SparseArrays: AbstractSparseMatrixCSC, nonzeros, rowvals, getcolptr
6
+
7
+ # Can't `using KLU` because cannot have a dependency in there without
8
+ # requiring the user does `using KLU`
9
+ # But there's no reason to require it because SparseArrays will already
10
+ # load SuiteSparse and thus all of the underlying KLU code
11
+ include (" ../src/KLU/klu.jl" )
12
+
13
+ LinearSolve. issparsematrixcsc (A:: AbstractSparseMatrixCSC ) = true
14
+
15
+ function LinearSolve. handle_sparsematrixcsc_lu (A:: AbstractSparseMatrixCSC )
16
+ lu (SparseMatrixCSC (size (A)... , getcolptr (A), rowvals (A), nonzeros (A)),
17
+ check = false )
18
+ end
19
+
20
+ function LinearSolve. init_cacheval (alg:: GenericFactorization ,
21
+ A:: Union {Hermitian{T, <: SparseMatrixCSC },
22
+ Symmetric{T, <: SparseMatrixCSC }}, b, u, Pl, Pr,
23
+ maxiters:: Int , abstol, reltol, verbose:: Bool ,
24
+ assumptions:: OperatorAssumptions ) where {T}
25
+ newA = copy (convert (AbstractMatrix, A))
26
+ LinearSolve. do_factorization (alg, newA, b, u)
27
+ end
28
+
29
+ const PREALLOCATED_UMFPACK = SparseArrays. UMFPACK. UmfpackLU (SparseMatrixCSC (0 , 0 , [1 ],
30
+ Int[], Float64[]))
31
+
32
+ function LinearSolve. init_cacheval (alg:: UMFPACKFactorization , A:: SparseMatrixCSC{Float64, Int} , b, u,
33
+ Pl, Pr,
34
+ maxiters:: Int , abstol, reltol,
35
+ verbose:: Bool , assumptions:: OperatorAssumptions )
36
+ PREALLOCATED_UMFPACK
37
+ end
38
+
39
+ function LinearSolve. init_cacheval (alg:: UMFPACKFactorization , A:: AbstractSparseArray , b, u, Pl, Pr,
40
+ maxiters:: Int , abstol,
41
+ reltol,
42
+ verbose:: Bool , assumptions:: OperatorAssumptions )
43
+ A = convert (AbstractMatrix, A)
44
+ zerobased = SparseArrays. getcolptr (A)[1 ] == 0
45
+ return SparseArrays. UMFPACK. UmfpackLU (SparseMatrixCSC (size (A)... , getcolptr (A),
46
+ rowvals (A), nonzeros (A)))
47
+ end
48
+
49
+ function SciMLBase. solve! (cache:: LinearCache , alg:: UMFPACKFactorization ; kwargs... )
50
+ A = cache. A
51
+ A = convert (AbstractMatrix, A)
52
+ if cache. isfresh
53
+ cacheval = LinearSolve. @get_cacheval (cache, :UMFPACKFactorization )
54
+ if alg. reuse_symbolic
55
+ # Caches the symbolic factorization: https://github.com/JuliaLang/julia/pull/33738
56
+ if alg. check_pattern && pattern_changed (cacheval, A)
57
+ fact = lu (
58
+ SparseMatrixCSC (size (A)... , getcolptr (A), rowvals (A),
59
+ nonzeros (A)),
60
+ check = false )
61
+ else
62
+ fact = lu! (cacheval,
63
+ SparseMatrixCSC (size (A)... , getcolptr (A), rowvals (A),
64
+ nonzeros (A)), check = false )
65
+ end
66
+ else
67
+ fact = lu (SparseMatrixCSC (size (A)... , getcolptr (A), rowvals (A), nonzeros (A)),
68
+ check = false )
69
+ end
70
+ cache. cacheval = fact
71
+ cache. isfresh = false
72
+ end
73
+
74
+ F = LinearSolve. @get_cacheval (cache, :UMFPACKFactorization )
75
+ if F. status == SparseArrays. UMFPACK. UMFPACK_OK
76
+ y = ldiv! (cache. u, F, cache. b)
77
+ SciMLBase. build_linear_solution (alg, y, nothing , cache)
78
+ else
79
+ SciMLBase. build_linear_solution (
80
+ alg, cache. u, nothing , cache; retcode = ReturnCode. Infeasible)
81
+ end
82
+ end
83
+
84
+ const PREALLOCATED_KLU = KLU. KLUFactorization (SparseMatrixCSC (0 , 0 , [1 ], Int[],
85
+ Float64[]))
86
+
87
+ function LinearSolve. init_cacheval (alg:: KLUFactorization , A:: SparseMatrixCSC{Float64, Int} , b, u, Pl,
88
+ Pr,
89
+ maxiters:: Int , abstol, reltol,
90
+ verbose:: Bool , assumptions:: OperatorAssumptions )
91
+ PREALLOCATED_KLU
92
+ end
93
+
94
+ function LinearSolve. init_cacheval (alg:: KLUFactorization , A:: AbstractSparseArray , b, u, Pl, Pr,
95
+ maxiters:: Int , abstol,
96
+ reltol,
97
+ verbose:: Bool , assumptions:: OperatorAssumptions )
98
+ A = convert (AbstractMatrix, A)
99
+ return KLU. KLUFactorization (SparseMatrixCSC (size (A)... , getcolptr (A), rowvals (A),
100
+ nonzeros (A)))
101
+ end
102
+
103
+ # TODO : guard this against errors
104
+ function SciMLBase. solve! (cache:: LinearCache , alg:: KLUFactorization ; kwargs... )
105
+ A = cache. A
106
+ A = convert (AbstractMatrix, A)
107
+ if cache. isfresh
108
+ cacheval = LinearSolve. @get_cacheval (cache, :KLUFactorization )
109
+ if alg. reuse_symbolic
110
+ if alg. check_pattern && pattern_changed (cacheval, A)
111
+ fact = KLU. klu (
112
+ SparseMatrixCSC (size (A)... , getcolptr (A), rowvals (A),
113
+ nonzeros (A)),
114
+ check = false )
115
+ else
116
+ fact = KLU. klu! (cacheval, nonzeros (A), check = false )
117
+ end
118
+ else
119
+ # New fact each time since the sparsity pattern can change
120
+ # and thus it needs to reallocate
121
+ fact = KLU. klu (SparseMatrixCSC (size (A)... , getcolptr (A), rowvals (A),
122
+ nonzeros (A)))
123
+ end
124
+ cache. cacheval = fact
125
+ cache. isfresh = false
126
+ end
127
+ F = LinearSolve. @get_cacheval (cache, :KLUFactorization )
128
+ if F. common. status == KLU. KLU_OK
129
+ y = ldiv! (cache. u, F, cache. b)
130
+ SciMLBase. build_linear_solution (alg, y, nothing , cache)
131
+ else
132
+ SciMLBase. build_linear_solution (
133
+ alg, cache. u, nothing , cache; retcode = ReturnCode. Infeasible)
134
+ end
135
+ end
136
+
137
+ const PREALLOCATED_CHOLMOD = cholesky (SparseMatrixCSC (0 , 0 , [1 ], Int[], Float64[]))
138
+
139
+ function LinearSolve. init_cacheval (alg:: CHOLMODFactorization ,
140
+ A:: Union{SparseMatrixCSC{T, Int}, Symmetric{T, SparseMatrixCSC{T, Int}}} , b, u,
141
+ Pl, Pr,
142
+ maxiters:: Int , abstol, reltol,
143
+ verbose:: Bool , assumptions:: OperatorAssumptions ) where {T < :
144
+ Union{Float32, Float64}}
145
+ PREALLOCATED_CHOLMOD
146
+ end
147
+
148
+ function LinearSolve. init_cacheval (alg:: NormalCholeskyFactorization ,
149
+ A:: Union {AbstractSparseArray, GPUArraysCore. AnyGPUArray,
150
+ Symmetric{<: Number , <: AbstractSparseArray }}, b, u, Pl, Pr,
151
+ maxiters:: Int , abstol, reltol, verbose:: Bool ,
152
+ assumptions:: OperatorAssumptions )
153
+ ArrayInterface. cholesky_instance (convert (AbstractMatrix, A))
154
+ end
155
+
156
+ # Specialize QR for the non-square case
157
+ # Missing ldiv! definitions: https://github.com/JuliaSparse/SparseArrays.jl/issues/242
158
+ function LinearSolve. _ldiv! (x:: Vector ,
159
+ A:: Union {SparseArrays. QR, LinearAlgebra. QRCompactWY,
160
+ SparseArrays. SPQR. QRSparse,
161
+ SparseArrays. CHOLMOD. Factor}, b:: Vector )
162
+ x .= A \ b
163
+ end
164
+
165
+ function LinearSolve. _ldiv! (x:: AbstractVector ,
166
+ A:: Union {SparseArrays. QR, LinearAlgebra. QRCompactWY,
167
+ SparseArrays. SPQR. QRSparse,
168
+ SparseArrays. CHOLMOD. Factor}, b:: AbstractVector )
169
+ x .= A \ b
170
+ end
171
+
172
+ # Ambiguity removal
173
+ function LinearSolve. _ldiv! (:: SVector ,
174
+ A:: Union {SparseArrays. CHOLMOD. Factor, LinearAlgebra. QR,
175
+ LinearAlgebra. QRCompactWY, SparseArrays. SPQR. QRSparse},
176
+ b:: AbstractVector )
177
+ (A \ b)
178
+ end
179
+ function LinearSolve. _ldiv! (:: SVector ,
180
+ A:: Union {SparseArrays. CHOLMOD. Factor, LinearAlgebra. QR,
181
+ LinearAlgebra. QRCompactWY, SparseArrays. SPQR. QRSparse},
182
+ b:: SVector )
183
+ (A \ b)
184
+ end
185
+
186
+ function pattern_changed (fact, A:: SparseArrays.SparseMatrixCSC )
187
+ ! (SparseArrays. decrement (SparseArrays. getcolptr (A)) ==
188
+ fact. colptr && SparseArrays. decrement (SparseArrays. getrowval (A)) ==
189
+ fact. rowval)
190
+ end
191
+
192
+ function LinearSolve. defaultalg (A:: AbstractSparseMatrixCSC{<:Union{Float64, ComplexF64}, Ti} , b,
193
+ assump:: OperatorAssumptions{Bool} ) where {Ti}
194
+ if assump. issq
195
+ if length (b) <= 10_000 && length (nonzeros (A)) / length (A) < 2e-4
196
+ LinearSolve. DefaultLinearSolver (LinearSolve. DefaultAlgorithmChoice. KLUFactorization)
197
+ else
198
+ LinearSolve. DefaultLinearSolver (LinearSolve. DefaultAlgorithmChoice. UMFPACKFactorization)
199
+ end
200
+ else
201
+ LinearSolve. DefaultLinearSolver (LinearSolve. DefaultAlgorithmChoice. QRFactorization)
202
+ end
203
+ end
204
+
205
+ LinearSolve. PrecompileTools. @compile_workload begin
206
+ A = sprand (4 , 4 , 0.3 ) + I
207
+ b = rand (4 )
208
+ prob = LinearProblem (A, b)
209
+ sol = solve (prob, KLUFactorization ())
210
+ sol = solve (prob, UMFPACKFactorization ())
211
+ end
212
+
213
+ end
0 commit comments