1+ module LinearSolveBLISExt
2+
3+ using Libdl
4+ using blis_jll
5+ using LinearAlgebra
6+ using LinearSolve
7+
8+ using LinearAlgebra: BlasInt, LU
9+ using LinearAlgebra. LAPACK: require_one_based_indexing, chkfinite, chkstride1,
10+ @blasfunc , chkargsok
11+ using LinearSolve: ArrayInterface, BLISLUFactorization, @get_cacheval , LinearCache, SciMLBase
12+
13+ const global libblis = dlopen (blis_jll. blis_path)
14+
15+ function getrf! (A:: AbstractMatrix{<:ComplexF64} ;
16+ ipiv = similar (A, BlasInt, min (size (A, 1 ), size (A, 2 ))),
17+ info = Ref {BlasInt} (),
18+ check = false )
19+ require_one_based_indexing (A)
20+ check && chkfinite (A)
21+ chkstride1 (A)
22+ m, n = size (A)
23+ lda = max (1 , stride (A, 2 ))
24+ if isempty (ipiv)
25+ ipiv = similar (A, BlasInt, min (size (A, 1 ), size (A, 2 )))
26+ end
27+ ccall ((@blasfunc (zgetrf_), libblis), Cvoid,
28+ (Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF64},
29+ Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
30+ m, n, A, lda, ipiv, info)
31+ chkargsok (info[])
32+ A, ipiv, info[], info # Error code is stored in LU factorization type
33+ end
34+
35+ function getrf! (A:: AbstractMatrix{<:ComplexF32} ;
36+ ipiv = similar (A, BlasInt, min (size (A, 1 ), size (A, 2 ))),
37+ info = Ref {BlasInt} (),
38+ check = false )
39+ require_one_based_indexing (A)
40+ check && chkfinite (A)
41+ chkstride1 (A)
42+ m, n = size (A)
43+ lda = max (1 , stride (A, 2 ))
44+ if isempty (ipiv)
45+ ipiv = similar (A, BlasInt, min (size (A, 1 ), size (A, 2 )))
46+ end
47+ ccall ((@blasfunc (cgetrf_), libblis), Cvoid,
48+ (Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF32},
49+ Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
50+ m, n, A, lda, ipiv, info)
51+ chkargsok (info[])
52+ A, ipiv, info[], info # Error code is stored in LU factorization type
53+ end
54+
55+ function getrf! (A:: AbstractMatrix{<:Float64} ;
56+ ipiv = similar (A, BlasInt, min (size (A, 1 ), size (A, 2 ))),
57+ info = Ref {BlasInt} (),
58+ check = false )
59+ require_one_based_indexing (A)
60+ check && chkfinite (A)
61+ chkstride1 (A)
62+ m, n = size (A)
63+ lda = max (1 , stride (A, 2 ))
64+ if isempty (ipiv)
65+ ipiv = similar (A, BlasInt, min (size (A, 1 ), size (A, 2 )))
66+ end
67+ ccall ((@blasfunc (dgetrf_), libblis), Cvoid,
68+ (Ref{BlasInt}, Ref{BlasInt}, Ptr{Float64},
69+ Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
70+ m, n, A, lda, ipiv, info)
71+ chkargsok (info[])
72+ A, ipiv, info[], info # Error code is stored in LU factorization type
73+ end
74+
75+ function getrf! (A:: AbstractMatrix{<:Float32} ;
76+ ipiv = similar (A, BlasInt, min (size (A, 1 ), size (A, 2 ))),
77+ info = Ref {BlasInt} (),
78+ check = false )
79+ require_one_based_indexing (A)
80+ check && chkfinite (A)
81+ chkstride1 (A)
82+ m, n = size (A)
83+ lda = max (1 , stride (A, 2 ))
84+ if isempty (ipiv)
85+ ipiv = similar (A, BlasInt, min (size (A, 1 ), size (A, 2 )))
86+ end
87+ ccall ((@blasfunc (sgetrf_), libblis), Cvoid,
88+ (Ref{BlasInt}, Ref{BlasInt}, Ptr{Float32},
89+ Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
90+ m, n, A, lda, ipiv, info)
91+ chkargsok (info[])
92+ A, ipiv, info[], info # Error code is stored in LU factorization type
93+ end
94+
95+ function getrs! (trans:: AbstractChar ,
96+ A:: AbstractMatrix{<:ComplexF64} ,
97+ ipiv:: AbstractVector{BlasInt} ,
98+ B:: AbstractVecOrMat{<:ComplexF64} ;
99+ info = Ref {BlasInt} ())
100+ require_one_based_indexing (A, ipiv, B)
101+ LinearAlgebra. LAPACK. chktrans (trans)
102+ chkstride1 (A, B, ipiv)
103+ n = LinearAlgebra. checksquare (A)
104+ if n != size (B, 1 )
105+ throw (DimensionMismatch (" B has leading dimension $(size (B,1 )) , but needs $n " ))
106+ end
107+ if n != length (ipiv)
108+ throw (DimensionMismatch (" ipiv has length $(length (ipiv)) , but needs to be $n " ))
109+ end
110+ nrhs = size (B, 2 )
111+ ccall ((" zgetrs_" , libblis), Cvoid,
112+ (Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF64}, Ref{BlasInt},
113+ Ptr{BlasInt}, Ptr{ComplexF64}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
114+ trans, n, size (B, 2 ), A, max (1 , stride (A, 2 )), ipiv, B, max (1 , stride (B, 2 )), info,
115+ 1 )
116+ LinearAlgebra. LAPACK. chklapackerror (BlasInt (info[]))
117+ B
118+ end
119+
120+ function getrs! (trans:: AbstractChar ,
121+ A:: AbstractMatrix{<:ComplexF32} ,
122+ ipiv:: AbstractVector{BlasInt} ,
123+ B:: AbstractVecOrMat{<:ComplexF32} ;
124+ info = Ref {BlasInt} ())
125+ require_one_based_indexing (A, ipiv, B)
126+ LinearAlgebra. LAPACK. chktrans (trans)
127+ chkstride1 (A, B, ipiv)
128+ n = LinearAlgebra. checksquare (A)
129+ if n != size (B, 1 )
130+ throw (DimensionMismatch (" B has leading dimension $(size (B,1 )) , but needs $n " ))
131+ end
132+ if n != length (ipiv)
133+ throw (DimensionMismatch (" ipiv has length $(length (ipiv)) , but needs to be $n " ))
134+ end
135+ nrhs = size (B, 2 )
136+ ccall ((" cgetrs_" , libblis), Cvoid,
137+ (Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF32}, Ref{BlasInt},
138+ Ptr{BlasInt}, Ptr{ComplexF32}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
139+ trans, n, size (B, 2 ), A, max (1 , stride (A, 2 )), ipiv, B, max (1 , stride (B, 2 )), info,
140+ 1 )
141+ LinearAlgebra. LAPACK. chklapackerror (BlasInt (info[]))
142+ B
143+ end
144+
145+ function getrs! (trans:: AbstractChar ,
146+ A:: AbstractMatrix{<:Float64} ,
147+ ipiv:: AbstractVector{BlasInt} ,
148+ B:: AbstractVecOrMat{<:Float64} ;
149+ info = Ref {BlasInt} ())
150+ require_one_based_indexing (A, ipiv, B)
151+ LinearAlgebra. LAPACK. chktrans (trans)
152+ chkstride1 (A, B, ipiv)
153+ n = LinearAlgebra. checksquare (A)
154+ if n != size (B, 1 )
155+ throw (DimensionMismatch (" B has leading dimension $(size (B,1 )) , but needs $n " ))
156+ end
157+ if n != length (ipiv)
158+ throw (DimensionMismatch (" ipiv has length $(length (ipiv)) , but needs to be $n " ))
159+ end
160+ nrhs = size (B, 2 )
161+ ccall ((" dgetrs_" , libblis), Cvoid,
162+ (Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{Float64}, Ref{BlasInt},
163+ Ptr{BlasInt}, Ptr{Float64}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
164+ trans, n, size (B, 2 ), A, max (1 , stride (A, 2 )), ipiv, B, max (1 , stride (B, 2 )), info,
165+ 1 )
166+ LinearAlgebra. LAPACK. chklapackerror (BlasInt (info[]))
167+ B
168+ end
169+
170+ function getrs! (trans:: AbstractChar ,
171+ A:: AbstractMatrix{<:Float32} ,
172+ ipiv:: AbstractVector{BlasInt} ,
173+ B:: AbstractVecOrMat{<:Float32} ;
174+ info = Ref {BlasInt} ())
175+ require_one_based_indexing (A, ipiv, B)
176+ LinearAlgebra. LAPACK. chktrans (trans)
177+ chkstride1 (A, B, ipiv)
178+ n = LinearAlgebra. checksquare (A)
179+ if n != size (B, 1 )
180+ throw (DimensionMismatch (" B has leading dimension $(size (B,1 )) , but needs $n " ))
181+ end
182+ if n != length (ipiv)
183+ throw (DimensionMismatch (" ipiv has length $(length (ipiv)) , but needs to be $n " ))
184+ end
185+ nrhs = size (B, 2 )
186+ ccall ((" sgetrs_" , libblis), Cvoid,
187+ (Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{Float32}, Ref{BlasInt},
188+ Ptr{BlasInt}, Ptr{Float32}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
189+ trans, n, size (B, 2 ), A, max (1 , stride (A, 2 )), ipiv, B, max (1 , stride (B, 2 )), info,
190+ 1 )
191+ LinearAlgebra. LAPACK. chklapackerror (BlasInt (info[]))
192+ B
193+ end
194+
195+ default_alias_A (:: BLISLUFactorization , :: Any , :: Any ) = false
196+ default_alias_b (:: BLISLUFactorization , :: Any , :: Any ) = false
197+
198+ const PREALLOCATED_BLIS_LU = begin
199+ A = rand (0 , 0 )
200+ luinst = ArrayInterface. lu_instance (A), Ref {BlasInt} ()
201+ end
202+
203+ function LinearSolve. init_cacheval (alg:: BLISLUFactorization , A, b, u, Pl, Pr,
204+ maxiters:: Int , abstol, reltol, verbose:: Bool ,
205+ assumptions:: OperatorAssumptions )
206+ PREALLOCATED_BLIS_LU
207+ end
208+
209+ function LinearSolve. init_cacheval (alg:: BLISLUFactorization , A:: AbstractMatrix{<:Union{Float32,ComplexF32,ComplexF64}} , b, u, Pl, Pr,
210+ maxiters:: Int , abstol, reltol, verbose:: Bool ,
211+ assumptions:: OperatorAssumptions )
212+ A = rand (eltype (A), 0 , 0 )
213+ ArrayInterface. lu_instance (A), Ref {BlasInt} ()
214+ end
215+
216+ function SciMLBase. solve! (cache:: LinearCache , alg:: BLISLUFactorization ;
217+ kwargs... )
218+ A = cache. A
219+ A = convert (AbstractMatrix, A)
220+ if cache. isfresh
221+ cacheval = @get_cacheval (cache, :BLISLUFactorization )
222+ res = getrf! (A; ipiv = cacheval[1 ]. ipiv, info = cacheval[2 ])
223+ fact = LU (res[1 : 3 ]. .. ), res[4 ]
224+ cache. cacheval = fact
225+ cache. isfresh = false
226+ end
227+
228+ y = ldiv! (cache. u, @get_cacheval (cache, :BLISLUFactorization )[1 ], cache. b)
229+ SciMLBase. build_linear_solution (alg, y, nothing , cache)
230+
231+ #=
232+ A, info = @get_cacheval(cache, :BLISLUFactorization)
233+ LinearAlgebra.require_one_based_indexing(cache.u, cache.b)
234+ m, n = size(A, 1), size(A, 2)
235+ if m > n
236+ Bc = copy(cache.b)
237+ getrs!('N', A.factors, A.ipiv, Bc; info)
238+ return copyto!(cache.u, 1, Bc, 1, n)
239+ else
240+ copyto!(cache.u, cache.b)
241+ getrs!('N', A.factors, A.ipiv, cache.u; info)
242+ end
243+
244+ SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
245+ =#
246+ end
247+
248+ end
0 commit comments