1
"""
```julia
OpenBLASLUFactorization()
```

LU factorization algorithm that calls OpenBLAS directly. The calls are made in a
way that pre-allocates workspace to avoid allocations, and they do not require
libblastrampoline.
"""
struct OpenBLASLUFactorization <: AbstractFactorization
end
10
+
11
module OpenBLASLU

using LinearAlgebra
using LinearAlgebra: BlasInt, LU, require_one_based_indexing, checksquare
using LinearAlgebra.LAPACK: chkfinite, chkstride1, @blasfunc, chkargsok, chktrans,
                            chklapackerror
using OpenBLAS_jll

"""
    getrf!(A; ipiv, info = Ref{BlasInt}(), check = false) -> (A, ipiv, info[], info)

Call the OpenBLAS `?getrf_` routine matching the element type of `A`
(`ComplexF64`, `ComplexF32`, `Float64` or `Float32`) directly on `A`.

# Keywords
- `ipiv`: pivot buffer handed to LAPACK. It is used as given only when its length
  equals `min(size(A)...)`; otherwise a fresh buffer of that length is allocated,
  so LAPACK never writes past the end of a too-short vector.
- `info`: `Ref{BlasInt}` that receives the LAPACK status code.
- `check`: when `true`, run `chkfinite(A)` before the call.

# Returns
The tuple `(A, ipiv, info[], info)`. The first three entries can be splatted into
the `LU` constructor, which stores the status code.

# Throws
Whatever `chkargsok` throws for the resulting status code, plus the errors raised
by `require_one_based_indexing`, `chkstride1` and (if `check`) `chkfinite`.
"""
function getrf! end

"""
    getrs!(trans, A, ipiv, B; info = Ref{BlasInt}()) -> B

Call the OpenBLAS `?getrs_` routine matching the element type of `A` and `B`,
using the square matrix `A` and pivot vector `ipiv` (as produced by `getrf!`).
`trans` is validated with `chktrans`. `B` is passed to the routine by pointer and
returned.

# Throws
- `DimensionMismatch` if `A` is not square, if `size(B, 1)` differs from the order
  of `A`, or if `length(ipiv)` differs from the order of `A`.
- Whatever `chklapackerror` throws for the resulting status code.
"""
function getrs! end

# One method pair per BLAS element type. Only the LAPACK symbol names and the
# element type differ between them, so they are generated from a single template.
for (getrf, getrs, elty) in ((:zgetrf_, :zgetrs_, :ComplexF64),
                             (:cgetrf_, :cgetrs_, :ComplexF32),
                             (:dgetrf_, :dgetrs_, :Float64),
                             (:sgetrf_, :sgetrs_, :Float32))
    @eval begin
        function getrf!(A::AbstractMatrix{<:$elty};
                ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
                info = Ref{BlasInt}(),
                check = false)
            require_one_based_indexing(A)
            check && chkfinite(A)
            chkstride1(A)
            m, n = size(A)
            lda = max(1, stride(A, 2))
            # A caller-supplied buffer (e.g. one cached from a matrix of another
            # size) must have exactly min(m, n) entries; a shorter one would be
            # overrun by LAPACK, so replace anything that does not match.
            if length(ipiv) != min(m, n)
                ipiv = similar(A, BlasInt, min(m, n))
            end
            ccall((@blasfunc($getrf), OpenBLAS_jll.libopenblas), Cvoid,
                (Ref{BlasInt}, Ref{BlasInt}, Ptr{$elty},
                    Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
                m, n, A, lda, ipiv, info)
            chkargsok(info[])
            return A, ipiv, info[], info # Error code is stored in LU factorization type
        end

        function getrs!(trans::AbstractChar,
                A::AbstractMatrix{<:$elty},
                ipiv::AbstractVector{BlasInt},
                B::AbstractVecOrMat{<:$elty};
                info = Ref{BlasInt}())
            require_one_based_indexing(A, ipiv, B)
            chktrans(trans)
            chkstride1(A, B, ipiv)
            n = checksquare(A)
            if n != size(B, 1)
                throw(DimensionMismatch(
                    "B has leading dimension $(size(B,1)), but needs $n"))
            end
            if n != length(ipiv)
                throw(DimensionMismatch(
                    "ipiv has length $(length(ipiv)), but needs to be $n"))
            end
            # The trailing Clong is the hidden Fortran length of the `trans`
            # character argument.
            ccall((@blasfunc($getrs), OpenBLAS_jll.libopenblas), Cvoid,
                (Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt},
                    Ptr{BlasInt}, Ptr{$elty}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
                trans, n, size(B, 2), A, max(1, stride(A, 2)), ipiv, B,
                max(1, stride(B, 2)), info, 1)
            chklapackerror(BlasInt(info[]))
            return B
        end
    end
end

end # module OpenBLASLU
199
+
200
# Both aliasing queries answer `false` for the OpenBLAS LU algorithm, regardless of
# the two remaining arguments.
# NOTE(review): presumably this means "do not alias A / b by default"; confirm
# against the definitions of `default_alias_A` / `default_alias_b`.
function default_alias_A(::OpenBLASLUFactorization, ::Any, ::Any)
    return false
end

function default_alias_b(::OpenBLASLUFactorization, ::Any, ::Any)
    return false
end
202
+
203
# Placeholder cache value for the OpenBLAS LU algorithm: a tuple of the LU instance
# that `ArrayInterface.lu_instance` builds for an empty 0×0 matrix and a
# `Ref{BlasInt}` slot for the LAPACK status code. It is constructed once when this
# file is loaded.
# NOTE(review): `begin ... end` does not introduce a scope, so `A` and `luinst` are
# also assigned as (non-const) globals here; a `let` block would avoid that if
# nothing relies on those bindings. Confirm before changing.
const PREALLOCATED_OPENBLAS_LU = begin
    A = rand(0, 0)
    luinst = ArrayInterface.lu_instance(A), Ref{BlasInt}()
end
207
+
208
# Generic fallback: every argument combination that is not handled by a more
# specific method gets the shared `PREALLOCATED_OPENBLAS_LU` placeholder as its
# initial cache value. None of the arguments are inspected.
function LinearSolve.init_cacheval(alg::OpenBLASLUFactorization, A, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol, verbose::LinearVerbosity,
        assumptions::OperatorAssumptions)
    return PREALLOCATED_OPENBLAS_LU
end
213
+
214
# Float32 / ComplexF32 / ComplexF64 matrices: build a fresh placeholder whose LU
# instance has the same element type as `A`, paired with a `Ref{BlasInt}` slot for
# the LAPACK status code. Only `eltype(A)` is read from the arguments.
function LinearSolve.init_cacheval(alg::OpenBLASLUFactorization,
        A::AbstractMatrix{<:Union{Float32, ComplexF32, ComplexF64}}, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol, verbose::LinearVerbosity,
        assumptions::OperatorAssumptions)
    placeholder = rand(eltype(A), 0, 0)
    lu_template = ArrayInterface.lu_instance(placeholder)
    return (lu_template, Ref{BlasInt}())
end
221
+
222
# Solve the cached linear system with the OpenBLAS LU backend.
#
# When `cache.isfresh` is set, the matrix is factorized with `OpenBLASLU.getrf!`,
# reusing the pivot vector and status `Ref` held in the current cache value, and the
# resulting `(LU, Ref)` tuple is stored back into `cache.cacheval`. If the
# factorization is not successful, a `ReturnCode.Failure` solution is returned and
# `isfresh` is left untouched. Otherwise the stored factors are applied to `cache.b`
# via `OpenBLASLU.getrs!`, the result lands in `cache.u`, and a
# `ReturnCode.Success` solution is returned.
function SciMLBase.solve!(cache::LinearCache, alg::OpenBLASLUFactorization;
        kwargs...)
    mat = convert(AbstractMatrix, cache.A)

    if cache.isfresh
        previous = @get_cacheval(cache, :OpenBLASLUFactorization)
        res = OpenBLASLU.getrf!(mat; ipiv = previous[1].ipiv, info = previous[2])
        # res is (factors, ipiv, info code, info Ref); the first three build the LU.
        fact = (LU(res[1], res[2], res[3]), res[4])
        cache.cacheval = fact

        if !LinearAlgebra.issuccess(fact[1])
            return SciMLBase.build_linear_solution(
                alg, cache.u, nothing, cache; retcode = ReturnCode.Failure)
        end
        cache.isfresh = false
    end

    lufact, info = @get_cacheval(cache, :OpenBLASLUFactorization)
    require_one_based_indexing(cache.u, cache.b)
    nrows, ncols = size(lufact, 1), size(lufact, 2)
    if nrows <= ncols
        # Solve in place inside the solution vector.
        copyto!(cache.u, cache.b)
        OpenBLASLU.getrs!('N', lufact.factors, lufact.ipiv, cache.u; info = info)
    else
        # More rows than columns: solve on a copy of b, then keep the first
        # `ncols` entries.
        rhs = copy(cache.b)
        OpenBLASLU.getrs!('N', lufact.factors, lufact.ipiv, rhs; info = info)
        copyto!(cache.u, 1, rhs, 1, ncols)
    end

    return SciMLBase.build_linear_solution(
        alg, cache.u, nothing, cache; retcode = ReturnCode.Success)
end
0 commit comments