"""
```julia
OpenBLASLUFactorization()
```

LU factorization algorithm backed by direct calls into OpenBLAS. The pivot vector and
the LAPACK status reference are kept in the solver cache and handed back to the
low-level routines on every factorization, so repeated solves avoid fresh allocations.
The calls go straight to the OpenBLAS library and therefore do not require
libblastrampoline.
"""
struct OpenBLASLUFactorization <: AbstractFactorization
end
10+
# Low-level LU kernels that `ccall` the LAPACK `getrf`/`getrs` routines of the
# `libopenblas` shipped by `OpenBLAS_jll` directly. Methods are generated for
# `Float64`, `Float32`, `ComplexF64` and `ComplexF32`.
module OpenBLASLU

using LinearAlgebra
using LinearAlgebra: BlasInt, LU, require_one_based_indexing, checksquare
using LinearAlgebra.LAPACK: chkfinite, chkstride1, @blasfunc, chkargsok, chktrans,
                            chklapackerror
using OpenBLAS_jll

"""
    getrf!(A; ipiv, info = Ref{BlasInt}(), check = false) -> (A, ipiv, info[], info)

Compute the pivoted LU factorization of `A` in place by calling the matching
`?getrf_` routine of OpenBLAS.

# Keywords
- `ipiv`: pivot buffer to reuse. It is used as-is only when its length equals
  `min(size(A)...)`; otherwise (including the empty placeholder case) a fresh buffer of
  the right length is allocated, so the LAPACK routine is never handed a buffer of the
  wrong size.
- `info`: `Ref{BlasInt}` that receives the LAPACK status code.
- `check`: when `true`, `chkfinite(A)` is run before factorizing.

# Returns
The tuple `(A, ipiv, info[], info)`: the overwritten matrix, the pivot vector that was
actually used, the status code, and the status reference itself. The status code is
handed to `chkargsok` and then returned rather than interpreted here, so the caller
can store it in an `LU` object.
"""
function getrf! end

"""
    getrs!(trans, A, ipiv, B; info = Ref{BlasInt}()) -> B

Solve the linear system described by `trans` using the LU factors `A` and pivots `ipiv`
(as produced by [`getrf!`](@ref)) by calling the matching `?getrs_` routine of OpenBLAS.
`B` is overwritten with the solution and returned.

# Throws
- `DimensionMismatch` if `A` is not square (via `checksquare`), or if `size(B, 1)` or
  `length(ipiv)` differs from the order of `A`.
- Whatever `chktrans` raises for an unsupported `trans`, and whatever `chklapackerror`
  raises for the status code written to `info`.
"""
function getrs! end

# One definition per routine, instantiated for each supported element type, so the four
# variants cannot drift apart.
for (getrf, getrs, elty) in ((:dgetrf_, :dgetrs_, :Float64),
                             (:sgetrf_, :sgetrs_, :Float32),
                             (:zgetrf_, :zgetrs_, :ComplexF64),
                             (:cgetrf_, :cgetrs_, :ComplexF32))
    @eval begin
        function getrf!(A::AbstractMatrix{$elty};
                ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
                info = Ref{BlasInt}(),
                check = false)
            require_one_based_indexing(A)
            check && chkfinite(A)
            chkstride1(A)
            m, n = size(A)
            lda = max(1, stride(A, 2))
            # LAPACK writes exactly min(m, n) pivots. A shorter buffer would be
            # overrun and a longer one would yield an inconsistent LU object, so any
            # buffer of the wrong length (including the empty placeholder) is replaced.
            if length(ipiv) != min(m, n)
                ipiv = similar(A, BlasInt, min(m, n))
            end
            ccall((@blasfunc($getrf), OpenBLAS_jll.libopenblas), Cvoid,
                (Ref{BlasInt}, Ref{BlasInt}, Ptr{$elty},
                    Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
                m, n, A, lda, ipiv, info)
            chkargsok(info[])
            return A, ipiv, info[], info # Error code is stored in LU factorization type
        end

        function getrs!(trans::AbstractChar,
                A::AbstractMatrix{$elty},
                ipiv::AbstractVector{BlasInt},
                B::AbstractVecOrMat{$elty};
                info = Ref{BlasInt}())
            require_one_based_indexing(A, ipiv, B)
            chktrans(trans)
            chkstride1(A, B, ipiv)
            n = checksquare(A)
            if n != size(B, 1)
                throw(DimensionMismatch("B has leading dimension $(size(B,1)), but needs $n"))
            end
            if n != length(ipiv)
                throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n"))
            end
            nrhs = size(B, 2)
            # The trailing `1` is the hidden Fortran length of the `trans` character.
            ccall((@blasfunc($getrs), OpenBLAS_jll.libopenblas), Cvoid,
                (Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt},
                    Ptr{BlasInt}, Ptr{$elty}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
                trans, n, nrhs, A, max(1, stride(A, 2)), ipiv, B, max(1, stride(B, 2)),
                info, 1)
            chklapackerror(BlasInt(info[]))
            return B
        end
    end
end

end # module OpenBLASLU
199+
# Both alias defaults are `false` for `OpenBLASLUFactorization`, regardless of the two
# remaining arguments.
function default_alias_A(::OpenBLASLUFactorization, ::Any, ::Any)
    return false
end

function default_alias_b(::OpenBLASLUFactorization, ::Any, ::Any)
    return false
end
202+
# Shared placeholder cache value: a tuple of an `LU` instance built from an empty
# `Float64` matrix and a `Ref{BlasInt}` for the LAPACK status code.
# `let` keeps the temporary matrix local; a top-level `begin` block would leave stray
# non-const globals (`A`, `luinst`) behind in the enclosing module.
const PREALLOCATED_OPENBLAS_LU = let empty_mat = rand(0, 0)
    (ArrayInterface.lu_instance(empty_mat), Ref{BlasInt}())
end
207+
# Generic fallback: none of the problem data is inspected; the shared
# `PREALLOCATED_OPENBLAS_LU` placeholder is returned as the initial cache value.
function LinearSolve.init_cacheval(alg::OpenBLASLUFactorization, A, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol, verbose::LinearVerbosity,
        assumptions::OperatorAssumptions)
    return PREALLOCATED_OPENBLAS_LU
end
213+
# Specialization for `Float32`, `ComplexF32` and `ComplexF64` matrices: build a
# placeholder cache value whose `LU` instance has the element type of `A`, paired with
# a fresh `Ref{BlasInt}` for the LAPACK status code.
function LinearSolve.init_cacheval(alg::OpenBLASLUFactorization,
        A::AbstractMatrix{<:Union{Float32, ComplexF32, ComplexF64}}, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol, verbose::LinearVerbosity,
        assumptions::OperatorAssumptions)
    empty_mat = rand(eltype(A), 0, 0)
    return (ArrayInterface.lu_instance(empty_mat), Ref{BlasInt}())
end
221+
# Solve the cached linear system with the OpenBLAS LU kernels.
#
# When `cache.isfresh` is set, `cache.A` is factorized in place with
# `OpenBLASLU.getrf!`, reusing the pivot vector and status `Ref` held in the cache
# value, and the resulting `(LU, Ref)` tuple is stored back into `cache.cacheval`.
# A factorization for which `issuccess` is false returns early with
# `ReturnCode.Failure`. Otherwise the right-hand side is solved with
# `OpenBLASLU.getrs!` into `cache.u` and a `ReturnCode.Success` solution is returned.
function SciMLBase.solve!(cache::LinearCache, alg::OpenBLASLUFactorization;
        kwargs...)
    mat = convert(AbstractMatrix, cache.A)

    if cache.isfresh
        previous = @get_cacheval(cache, :OpenBLASLUFactorization)
        factors, pivots, status, status_ref = OpenBLASLU.getrf!(
            mat; ipiv = previous[1].ipiv, info = previous[2])
        fact = (LU(factors, pivots, status), status_ref)
        cache.cacheval = fact

        if !LinearAlgebra.issuccess(fact[1])
            return SciMLBase.build_linear_solution(
                alg, cache.u, nothing, cache; retcode = ReturnCode.Failure)
        end
        cache.isfresh = false
    end

    lufact, info = @get_cacheval(cache, :OpenBLASLUFactorization)
    require_one_based_indexing(cache.u, cache.b)
    nrows, ncols = size(lufact, 1), size(lufact, 2)
    if nrows <= ncols
        # Solve directly inside the solution buffer.
        copyto!(cache.u, cache.b)
        OpenBLASLU.getrs!('N', lufact.factors, lufact.ipiv, cache.u; info)
    else
        # Solve on a copy of `b`, then keep only the first `ncols` entries.
        rhs = copy(cache.b)
        OpenBLASLU.getrs!('N', lufact.factors, lufact.ipiv, rhs; info)
        copyto!(cache.u, 1, rhs, 1, ncols)
    end

    return SciMLBase.build_linear_solution(
        alg, cache.u, nothing, cache; retcode = ReturnCode.Success)
end
0 commit comments