Skip to content

Commit d4da31b

Browse files
committed
Fix stale imports and add compat entry for OpenBLAS_jll
- Remove the redundant `using LinearAlgebra: BlasInt, LU, require_one_based_indexing, checksquare` import; these names remain reachable through the module brought in by `using LinearAlgebra`
- Qualify all uses of `BlasInt`, `LU`, `require_one_based_indexing`, and `checksquare` with the `LinearAlgebra.` prefix
- Add `OpenBLAS_jll = "0.3"` to the compat section in `Project.toml`
- Apply JuliaFormatter with SciMLStyle to ensure consistent formatting
1 parent 676c27f commit d4da31b

File tree

2 files changed

+73
-65
lines changed

2 files changed

+73
-65
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ MPI = "0.20"
113113
Markdown = "1.10"
114114
Metal = "1.4"
115115
MultiFloats = "2.3"
116+
OpenBLAS_jll = "0.3"
116117
Pardiso = "1"
117118
Pkg = "1.10"
118119
PrecompileTools = "1.2"

src/openblas.jl

Lines changed: 72 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,22 @@
33
OpenBLASLUFactorization()
44
```
55
6-
A direct wrapper over OpenBLAS's LU factorization (`getrf!` and `getrs!`).
7-
This solver makes direct calls to OpenBLAS_jll without going through Julia's
6+
A direct wrapper over OpenBLAS's LU factorization (`getrf!` and `getrs!`).
7+
This solver makes direct calls to OpenBLAS_jll without going through Julia's
88
libblastrampoline, which can provide performance benefits in certain configurations.
99
1010
## Performance Characteristics
1111
12-
- Pre-allocates workspace to avoid allocations during solving
13-
- Makes direct `ccall`s to OpenBLAS routines
14-
- Can be faster than `LUFactorization` when OpenBLAS is well-optimized for the hardware
15-
- Supports `Float32`, `Float64`, `ComplexF32`, and `ComplexF64` element types
12+
- Pre-allocates workspace to avoid allocations during solving
13+
- Makes direct `ccall`s to OpenBLAS routines
14+
- Can be faster than `LUFactorization` when OpenBLAS is well-optimized for the hardware
15+
- Supports `Float32`, `Float64`, `ComplexF32`, and `ComplexF64` element types
1616
1717
## When to Use
1818
19-
- When you want to ensure OpenBLAS is used regardless of the system BLAS configuration
20-
- When benchmarking shows better performance than `LUFactorization` on your specific hardware
21-
- When you need consistent behavior across different systems (always uses OpenBLAS)
19+
- When you want to ensure OpenBLAS is used regardless of the system BLAS configuration
20+
- When benchmarking shows better performance than `LUFactorization` on your specific hardware
21+
- When you need consistent behavior across different systems (always uses OpenBLAS)
2222
2323
## Example
2424
@@ -36,96 +36,96 @@ struct OpenBLASLUFactorization <: AbstractFactorization end
3636
module OpenBLASLU
3737

3838
using LinearAlgebra
39-
using LinearAlgebra: BlasInt, LU, require_one_based_indexing, checksquare
40-
using LinearAlgebra.LAPACK: chkfinite, chkstride1, @blasfunc, chkargsok, chktrans, chklapackerror
39+
using LinearAlgebra.LAPACK: chkfinite, chkstride1, @blasfunc, chkargsok, chktrans,
40+
chklapackerror
4141
using OpenBLAS_jll
4242

4343
function getrf!(A::AbstractMatrix{<:ComplexF64};
44-
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
45-
info = Ref{BlasInt}(),
44+
ipiv = similar(A, LinearAlgebra.BlasInt, min(size(A, 1), size(A, 2))),
45+
info = Ref{LinearAlgebra.BlasInt}(),
4646
check = false)
47-
require_one_based_indexing(A)
47+
LinearAlgebra.require_one_based_indexing(A)
4848
check && chkfinite(A)
4949
chkstride1(A)
5050
m, n = size(A)
5151
lda = max(1, stride(A, 2))
5252
if isempty(ipiv)
53-
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2)))
53+
ipiv = similar(A, LinearAlgebra.BlasInt, min(size(A, 1), size(A, 2)))
5454
end
5555
ccall((@blasfunc(zgetrf_), OpenBLAS_jll.libopenblas), Cvoid,
56-
(Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF64},
57-
Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
56+
(Ref{LinearAlgebra.BlasInt}, Ref{LinearAlgebra.BlasInt}, Ptr{ComplexF64},
57+
Ref{LinearAlgebra.BlasInt}, Ptr{LinearAlgebra.BlasInt}, Ptr{LinearAlgebra.BlasInt}),
5858
m, n, A, lda, ipiv, info)
5959
chkargsok(info[])
6060
A, ipiv, info[], info #Error code is stored in LU factorization type
6161
end
6262

6363
function getrf!(A::AbstractMatrix{<:ComplexF32};
64-
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
65-
info = Ref{BlasInt}(),
64+
ipiv = similar(A, LinearAlgebra.BlasInt, min(size(A, 1), size(A, 2))),
65+
info = Ref{LinearAlgebra.BlasInt}(),
6666
check = false)
67-
require_one_based_indexing(A)
67+
LinearAlgebra.require_one_based_indexing(A)
6868
check && chkfinite(A)
6969
chkstride1(A)
7070
m, n = size(A)
7171
lda = max(1, stride(A, 2))
7272
if isempty(ipiv)
73-
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2)))
73+
ipiv = similar(A, LinearAlgebra.BlasInt, min(size(A, 1), size(A, 2)))
7474
end
7575
ccall((@blasfunc(cgetrf_), OpenBLAS_jll.libopenblas), Cvoid,
76-
(Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF32},
77-
Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
76+
(Ref{LinearAlgebra.BlasInt}, Ref{LinearAlgebra.BlasInt}, Ptr{ComplexF32},
77+
Ref{LinearAlgebra.BlasInt}, Ptr{LinearAlgebra.BlasInt}, Ptr{LinearAlgebra.BlasInt}),
7878
m, n, A, lda, ipiv, info)
7979
chkargsok(info[])
8080
A, ipiv, info[], info #Error code is stored in LU factorization type
8181
end
8282

8383
function getrf!(A::AbstractMatrix{<:Float64};
84-
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
85-
info = Ref{BlasInt}(),
84+
ipiv = similar(A, LinearAlgebra.BlasInt, min(size(A, 1), size(A, 2))),
85+
info = Ref{LinearAlgebra.BlasInt}(),
8686
check = false)
87-
require_one_based_indexing(A)
87+
LinearAlgebra.require_one_based_indexing(A)
8888
check && chkfinite(A)
8989
chkstride1(A)
9090
m, n = size(A)
9191
lda = max(1, stride(A, 2))
9292
if isempty(ipiv)
93-
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2)))
93+
ipiv = similar(A, LinearAlgebra.BlasInt, min(size(A, 1), size(A, 2)))
9494
end
9595
ccall((@blasfunc(dgetrf_), OpenBLAS_jll.libopenblas), Cvoid,
96-
(Ref{BlasInt}, Ref{BlasInt}, Ptr{Float64},
97-
Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
96+
(Ref{LinearAlgebra.BlasInt}, Ref{LinearAlgebra.BlasInt}, Ptr{Float64},
97+
Ref{LinearAlgebra.BlasInt}, Ptr{LinearAlgebra.BlasInt}, Ptr{LinearAlgebra.BlasInt}),
9898
m, n, A, lda, ipiv, info)
9999
chkargsok(info[])
100100
A, ipiv, info[], info #Error code is stored in LU factorization type
101101
end
102102

103103
function getrf!(A::AbstractMatrix{<:Float32};
104-
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
105-
info = Ref{BlasInt}(),
104+
ipiv = similar(A, LinearAlgebra.BlasInt, min(size(A, 1), size(A, 2))),
105+
info = Ref{LinearAlgebra.BlasInt}(),
106106
check = false)
107-
require_one_based_indexing(A)
107+
LinearAlgebra.require_one_based_indexing(A)
108108
check && chkfinite(A)
109109
chkstride1(A)
110110
m, n = size(A)
111111
lda = max(1, stride(A, 2))
112112
if isempty(ipiv)
113-
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2)))
113+
ipiv = similar(A, LinearAlgebra.BlasInt, min(size(A, 1), size(A, 2)))
114114
end
115115
ccall((@blasfunc(sgetrf_), OpenBLAS_jll.libopenblas), Cvoid,
116-
(Ref{BlasInt}, Ref{BlasInt}, Ptr{Float32},
117-
Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
116+
(Ref{LinearAlgebra.BlasInt}, Ref{LinearAlgebra.BlasInt}, Ptr{Float32},
117+
Ref{LinearAlgebra.BlasInt}, Ptr{LinearAlgebra.BlasInt}, Ptr{LinearAlgebra.BlasInt}),
118118
m, n, A, lda, ipiv, info)
119119
chkargsok(info[])
120120
A, ipiv, info[], info #Error code is stored in LU factorization type
121121
end
122122

123123
function getrs!(trans::AbstractChar,
124124
A::AbstractMatrix{<:ComplexF64},
125-
ipiv::AbstractVector{BlasInt},
125+
ipiv::AbstractVector{LinearAlgebra.BlasInt},
126126
B::AbstractVecOrMat{<:ComplexF64};
127-
info = Ref{BlasInt}())
128-
require_one_based_indexing(A, ipiv, B)
127+
info = Ref{LinearAlgebra.BlasInt}())
128+
LinearAlgebra.require_one_based_indexing(A, ipiv, B)
129129
LinearAlgebra.LAPACK.chktrans(trans)
130130
chkstride1(A, B, ipiv)
131131
n = LinearAlgebra.checksquare(A)
@@ -137,20 +137,22 @@ function getrs!(trans::AbstractChar,
137137
end
138138
nrhs = size(B, 2)
139139
ccall((@blasfunc(zgetrs_), OpenBLAS_jll.libopenblas), Cvoid,
140-
(Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF64}, Ref{BlasInt},
141-
Ptr{BlasInt}, Ptr{ComplexF64}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
140+
(Ref{UInt8}, Ref{LinearAlgebra.BlasInt}, Ref{LinearAlgebra.BlasInt},
141+
Ptr{ComplexF64}, Ref{LinearAlgebra.BlasInt},
142+
Ptr{LinearAlgebra.BlasInt}, Ptr{ComplexF64}, Ref{LinearAlgebra.BlasInt},
143+
Ptr{LinearAlgebra.BlasInt}, Clong),
142144
trans, n, size(B, 2), A, max(1, stride(A, 2)), ipiv, B, max(1, stride(B, 2)), info,
143145
1)
144-
LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[]))
146+
LinearAlgebra.LAPACK.chklapackerror(LinearAlgebra.BlasInt(info[]))
145147
B
146148
end
147149

148150
function getrs!(trans::AbstractChar,
149151
A::AbstractMatrix{<:ComplexF32},
150-
ipiv::AbstractVector{BlasInt},
152+
ipiv::AbstractVector{LinearAlgebra.BlasInt},
151153
B::AbstractVecOrMat{<:ComplexF32};
152-
info = Ref{BlasInt}())
153-
require_one_based_indexing(A, ipiv, B)
154+
info = Ref{LinearAlgebra.BlasInt}())
155+
LinearAlgebra.require_one_based_indexing(A, ipiv, B)
154156
LinearAlgebra.LAPACK.chktrans(trans)
155157
chkstride1(A, B, ipiv)
156158
n = LinearAlgebra.checksquare(A)
@@ -162,20 +164,22 @@ function getrs!(trans::AbstractChar,
162164
end
163165
nrhs = size(B, 2)
164166
ccall((@blasfunc(cgetrs_), OpenBLAS_jll.libopenblas), Cvoid,
165-
(Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF32}, Ref{BlasInt},
166-
Ptr{BlasInt}, Ptr{ComplexF32}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
167+
(Ref{UInt8}, Ref{LinearAlgebra.BlasInt}, Ref{LinearAlgebra.BlasInt},
168+
Ptr{ComplexF32}, Ref{LinearAlgebra.BlasInt},
169+
Ptr{LinearAlgebra.BlasInt}, Ptr{ComplexF32}, Ref{LinearAlgebra.BlasInt},
170+
Ptr{LinearAlgebra.BlasInt}, Clong),
167171
trans, n, size(B, 2), A, max(1, stride(A, 2)), ipiv, B, max(1, stride(B, 2)), info,
168172
1)
169-
LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[]))
173+
LinearAlgebra.LAPACK.chklapackerror(LinearAlgebra.BlasInt(info[]))
170174
B
171175
end
172176

173177
function getrs!(trans::AbstractChar,
174178
A::AbstractMatrix{<:Float64},
175-
ipiv::AbstractVector{BlasInt},
179+
ipiv::AbstractVector{LinearAlgebra.BlasInt},
176180
B::AbstractVecOrMat{<:Float64};
177-
info = Ref{BlasInt}())
178-
require_one_based_indexing(A, ipiv, B)
181+
info = Ref{LinearAlgebra.BlasInt}())
182+
LinearAlgebra.require_one_based_indexing(A, ipiv, B)
179183
LinearAlgebra.LAPACK.chktrans(trans)
180184
chkstride1(A, B, ipiv)
181185
n = LinearAlgebra.checksquare(A)
@@ -187,20 +191,21 @@ function getrs!(trans::AbstractChar,
187191
end
188192
nrhs = size(B, 2)
189193
ccall((@blasfunc(dgetrs_), OpenBLAS_jll.libopenblas), Cvoid,
190-
(Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{Float64}, Ref{BlasInt},
191-
Ptr{BlasInt}, Ptr{Float64}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
194+
(Ref{UInt8}, Ref{LinearAlgebra.BlasInt}, Ref{LinearAlgebra.BlasInt},
195+
Ptr{Float64}, Ref{LinearAlgebra.BlasInt},
196+
Ptr{LinearAlgebra.BlasInt}, Ptr{Float64}, Ref{LinearAlgebra.BlasInt}, Ptr{LinearAlgebra.BlasInt}, Clong),
192197
trans, n, size(B, 2), A, max(1, stride(A, 2)), ipiv, B, max(1, stride(B, 2)), info,
193198
1)
194-
LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[]))
199+
LinearAlgebra.LAPACK.chklapackerror(LinearAlgebra.BlasInt(info[]))
195200
B
196201
end
197202

198203
function getrs!(trans::AbstractChar,
199204
A::AbstractMatrix{<:Float32},
200-
ipiv::AbstractVector{BlasInt},
205+
ipiv::AbstractVector{LinearAlgebra.BlasInt},
201206
B::AbstractVecOrMat{<:Float32};
202-
info = Ref{BlasInt}())
203-
require_one_based_indexing(A, ipiv, B)
207+
info = Ref{LinearAlgebra.BlasInt}())
208+
LinearAlgebra.require_one_based_indexing(A, ipiv, B)
204209
LinearAlgebra.LAPACK.chktrans(trans)
205210
chkstride1(A, B, ipiv)
206211
n = LinearAlgebra.checksquare(A)
@@ -212,11 +217,12 @@ function getrs!(trans::AbstractChar,
212217
end
213218
nrhs = size(B, 2)
214219
ccall((@blasfunc(sgetrs_), OpenBLAS_jll.libopenblas), Cvoid,
215-
(Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{Float32}, Ref{BlasInt},
216-
Ptr{BlasInt}, Ptr{Float32}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
220+
(Ref{UInt8}, Ref{LinearAlgebra.BlasInt}, Ref{LinearAlgebra.BlasInt},
221+
Ptr{Float32}, Ref{LinearAlgebra.BlasInt},
222+
Ptr{LinearAlgebra.BlasInt}, Ptr{Float32}, Ref{LinearAlgebra.BlasInt}, Ptr{LinearAlgebra.BlasInt}, Clong),
217223
trans, n, size(B, 2), A, max(1, stride(A, 2)), ipiv, B, max(1, stride(B, 2)), info,
218224
1)
219-
LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[]))
225+
LinearAlgebra.LAPACK.chklapackerror(LinearAlgebra.BlasInt(info[]))
220226
B
221227
end
222228

@@ -227,7 +233,7 @@ default_alias_b(::OpenBLASLUFactorization, ::Any, ::Any) = false
227233

228234
const PREALLOCATED_OPENBLAS_LU = begin
229235
A = rand(0, 0)
230-
luinst = ArrayInterface.lu_instance(A), Ref{BlasInt}()
236+
luinst = ArrayInterface.lu_instance(A), Ref{LinearAlgebra.BlasInt}()
231237
end
232238

233239
function LinearSolve.init_cacheval(alg::OpenBLASLUFactorization, A, b, u, Pl, Pr,
@@ -241,7 +247,7 @@ function LinearSolve.init_cacheval(alg::OpenBLASLUFactorization,
241247
maxiters::Int, abstol, reltol, verbose::LinearVerbosity,
242248
assumptions::OperatorAssumptions)
243249
A = rand(eltype(A), 0, 0)
244-
ArrayInterface.lu_instance(A), Ref{BlasInt}()
250+
ArrayInterface.lu_instance(A), Ref{LinearAlgebra.BlasInt}()
245251
end
246252

247253
function SciMLBase.solve!(cache::LinearCache, alg::OpenBLASLUFactorization;
@@ -251,7 +257,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::OpenBLASLUFactorization;
251257
if cache.isfresh
252258
cacheval = @get_cacheval(cache, :OpenBLASLUFactorization)
253259
res = OpenBLASLU.getrf!(A; ipiv = cacheval[1].ipiv, info = cacheval[2])
254-
fact = LU(res[1:3]...), res[4]
260+
fact = LinearAlgebra.LU(res[1:3]...), res[4]
255261
cache.cacheval = fact
256262

257263
if !LinearAlgebra.issuccess(fact[1])
@@ -262,7 +268,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::OpenBLASLUFactorization;
262268
end
263269

264270
A, info = @get_cacheval(cache, :OpenBLASLUFactorization)
265-
require_one_based_indexing(cache.u, cache.b)
271+
LinearAlgebra.require_one_based_indexing(cache.u, cache.b)
266272
m, n = size(A, 1), size(A, 2)
267273
if m > n
268274
Bc = copy(cache.b)
@@ -273,5 +279,6 @@ function SciMLBase.solve!(cache::LinearCache, alg::OpenBLASLUFactorization;
273279
OpenBLASLU.getrs!('N', A.factors, A.ipiv, cache.u; info)
274280
end
275281

276-
SciMLBase.build_linear_solution(alg, cache.u, nothing, cache; retcode = ReturnCode.Success)
277-
end
282+
SciMLBase.build_linear_solution(
283+
alg, cache.u, nothing, cache; retcode = ReturnCode.Success)
284+
end

0 commit comments

Comments
 (0)