Skip to content

Commit a89490a

Browse files
committed
Refactor OpenBLAS implementation to match MKL pattern
- Remove separate module structure, follow MKL's pattern exactly - Conditionally check for OpenBLAS_jll availability using is_available() - Keep OpenBLAS_jll as a dependency (required even for stdlib packages) - Simplify implementation without nested @static checks - Tests conditionally run based on LinearSolve.useopenblas flag
1 parent ea53e7f commit a89490a

File tree

5 files changed

+78
-85
lines changed

5 files changed

+78
-85
lines changed

src/LinearSolve.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,9 @@ else
5959
const usemkl = false
6060
end
6161

62+
# OpenBLAS_jll is a standard library, always available
6263
using OpenBLAS_jll
64+
const useopenblas = OpenBLAS_jll.is_available()
6365

6466

6567
@reexport using SciMLBase

src/openblas.jl

Lines changed: 64 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -33,99 +33,94 @@ sol = solve(prob, OpenBLASLUFactorization())
3333
"""
3434
struct OpenBLASLUFactorization <: AbstractFactorization end
3535

36-
module OpenBLASLU
36+
# OpenBLAS methods - OpenBLAS_jll is always available as a standard library
3737

38-
using LinearAlgebra
39-
using LinearAlgebra.LAPACK: chkfinite, chkstride1, @blasfunc, chkargsok, chktrans,
40-
chklapackerror
41-
using OpenBLAS_jll
42-
43-
function getrf!(A::AbstractMatrix{<:ComplexF64};
44-
ipiv = similar(A, LinearAlgebra.BlasInt, min(size(A, 1), size(A, 2))),
45-
info = Ref{LinearAlgebra.BlasInt}(),
38+
function openblas_getrf!(A::AbstractMatrix{<:ComplexF64};
39+
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
40+
info = Ref{BlasInt}(),
4641
check = false)
47-
LinearAlgebra.require_one_based_indexing(A)
42+
require_one_based_indexing(A)
4843
check && chkfinite(A)
4944
chkstride1(A)
5045
m, n = size(A)
5146
lda = max(1, stride(A, 2))
5247
if isempty(ipiv)
53-
ipiv = similar(A, LinearAlgebra.BlasInt, min(size(A, 1), size(A, 2)))
48+
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2)))
5449
end
5550
ccall((@blasfunc(zgetrf_), OpenBLAS_jll.libopenblas), Cvoid,
56-
(Ref{LinearAlgebra.BlasInt}, Ref{LinearAlgebra.BlasInt}, Ptr{ComplexF64},
57-
Ref{LinearAlgebra.BlasInt}, Ptr{LinearAlgebra.BlasInt}, Ptr{LinearAlgebra.BlasInt}),
51+
(Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF64},
52+
Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
5853
m, n, A, lda, ipiv, info)
5954
chkargsok(info[])
6055
A, ipiv, info[], info #Error code is stored in LU factorization type
6156
end
6257

63-
function getrf!(A::AbstractMatrix{<:ComplexF32};
64-
ipiv = similar(A, LinearAlgebra.BlasInt, min(size(A, 1), size(A, 2))),
65-
info = Ref{LinearAlgebra.BlasInt}(),
58+
function openblas_getrf!(A::AbstractMatrix{<:ComplexF32};
59+
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
60+
info = Ref{BlasInt}(),
6661
check = false)
67-
LinearAlgebra.require_one_based_indexing(A)
62+
require_one_based_indexing(A)
6863
check && chkfinite(A)
6964
chkstride1(A)
7065
m, n = size(A)
7166
lda = max(1, stride(A, 2))
7267
if isempty(ipiv)
73-
ipiv = similar(A, LinearAlgebra.BlasInt, min(size(A, 1), size(A, 2)))
68+
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2)))
7469
end
7570
ccall((@blasfunc(cgetrf_), OpenBLAS_jll.libopenblas), Cvoid,
76-
(Ref{LinearAlgebra.BlasInt}, Ref{LinearAlgebra.BlasInt}, Ptr{ComplexF32},
77-
Ref{LinearAlgebra.BlasInt}, Ptr{LinearAlgebra.BlasInt}, Ptr{LinearAlgebra.BlasInt}),
71+
(Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF32},
72+
Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
7873
m, n, A, lda, ipiv, info)
7974
chkargsok(info[])
8075
A, ipiv, info[], info #Error code is stored in LU factorization type
8176
end
8277

83-
function getrf!(A::AbstractMatrix{<:Float64};
84-
ipiv = similar(A, LinearAlgebra.BlasInt, min(size(A, 1), size(A, 2))),
85-
info = Ref{LinearAlgebra.BlasInt}(),
78+
function openblas_getrf!(A::AbstractMatrix{<:Float64};
79+
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
80+
info = Ref{BlasInt}(),
8681
check = false)
87-
LinearAlgebra.require_one_based_indexing(A)
82+
require_one_based_indexing(A)
8883
check && chkfinite(A)
8984
chkstride1(A)
9085
m, n = size(A)
9186
lda = max(1, stride(A, 2))
9287
if isempty(ipiv)
93-
ipiv = similar(A, LinearAlgebra.BlasInt, min(size(A, 1), size(A, 2)))
88+
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2)))
9489
end
9590
ccall((@blasfunc(dgetrf_), OpenBLAS_jll.libopenblas), Cvoid,
96-
(Ref{LinearAlgebra.BlasInt}, Ref{LinearAlgebra.BlasInt}, Ptr{Float64},
97-
Ref{LinearAlgebra.BlasInt}, Ptr{LinearAlgebra.BlasInt}, Ptr{LinearAlgebra.BlasInt}),
91+
(Ref{BlasInt}, Ref{BlasInt}, Ptr{Float64},
92+
Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
9893
m, n, A, lda, ipiv, info)
9994
chkargsok(info[])
10095
A, ipiv, info[], info #Error code is stored in LU factorization type
10196
end
10297

103-
function getrf!(A::AbstractMatrix{<:Float32};
104-
ipiv = similar(A, LinearAlgebra.BlasInt, min(size(A, 1), size(A, 2))),
105-
info = Ref{LinearAlgebra.BlasInt}(),
98+
function openblas_getrf!(A::AbstractMatrix{<:Float32};
99+
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
100+
info = Ref{BlasInt}(),
106101
check = false)
107-
LinearAlgebra.require_one_based_indexing(A)
102+
require_one_based_indexing(A)
108103
check && chkfinite(A)
109104
chkstride1(A)
110105
m, n = size(A)
111106
lda = max(1, stride(A, 2))
112107
if isempty(ipiv)
113-
ipiv = similar(A, LinearAlgebra.BlasInt, min(size(A, 1), size(A, 2)))
108+
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2)))
114109
end
115110
ccall((@blasfunc(sgetrf_), OpenBLAS_jll.libopenblas), Cvoid,
116-
(Ref{LinearAlgebra.BlasInt}, Ref{LinearAlgebra.BlasInt}, Ptr{Float32},
117-
Ref{LinearAlgebra.BlasInt}, Ptr{LinearAlgebra.BlasInt}, Ptr{LinearAlgebra.BlasInt}),
111+
(Ref{BlasInt}, Ref{BlasInt}, Ptr{Float32},
112+
Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
118113
m, n, A, lda, ipiv, info)
119114
chkargsok(info[])
120115
A, ipiv, info[], info #Error code is stored in LU factorization type
121116
end
122117

123-
function getrs!(trans::AbstractChar,
118+
function openblas_getrs!(trans::AbstractChar,
124119
A::AbstractMatrix{<:ComplexF64},
125-
ipiv::AbstractVector{LinearAlgebra.BlasInt},
120+
ipiv::AbstractVector{BlasInt},
126121
B::AbstractVecOrMat{<:ComplexF64};
127-
info = Ref{LinearAlgebra.BlasInt}())
128-
LinearAlgebra.require_one_based_indexing(A, ipiv, B)
122+
info = Ref{BlasInt}())
123+
require_one_based_indexing(A, ipiv, B)
129124
LinearAlgebra.LAPACK.chktrans(trans)
130125
chkstride1(A, B, ipiv)
131126
n = LinearAlgebra.checksquare(A)
@@ -137,22 +132,20 @@ function getrs!(trans::AbstractChar,
137132
end
138133
nrhs = size(B, 2)
139134
ccall((@blasfunc(zgetrs_), OpenBLAS_jll.libopenblas), Cvoid,
140-
(Ref{UInt8}, Ref{LinearAlgebra.BlasInt}, Ref{LinearAlgebra.BlasInt},
141-
Ptr{ComplexF64}, Ref{LinearAlgebra.BlasInt},
142-
Ptr{LinearAlgebra.BlasInt}, Ptr{ComplexF64}, Ref{LinearAlgebra.BlasInt},
143-
Ptr{LinearAlgebra.BlasInt}, Clong),
135+
(Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF64}, Ref{BlasInt},
136+
Ptr{BlasInt}, Ptr{ComplexF64}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
144137
trans, n, size(B, 2), A, max(1, stride(A, 2)), ipiv, B, max(1, stride(B, 2)), info,
145138
1)
146-
LinearAlgebra.LAPACK.chklapackerror(LinearAlgebra.BlasInt(info[]))
139+
LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[]))
147140
B
148141
end
149142

150-
function getrs!(trans::AbstractChar,
143+
function openblas_getrs!(trans::AbstractChar,
151144
A::AbstractMatrix{<:ComplexF32},
152-
ipiv::AbstractVector{LinearAlgebra.BlasInt},
145+
ipiv::AbstractVector{BlasInt},
153146
B::AbstractVecOrMat{<:ComplexF32};
154-
info = Ref{LinearAlgebra.BlasInt}())
155-
LinearAlgebra.require_one_based_indexing(A, ipiv, B)
147+
info = Ref{BlasInt}())
148+
require_one_based_indexing(A, ipiv, B)
156149
LinearAlgebra.LAPACK.chktrans(trans)
157150
chkstride1(A, B, ipiv)
158151
n = LinearAlgebra.checksquare(A)
@@ -164,22 +157,20 @@ function getrs!(trans::AbstractChar,
164157
end
165158
nrhs = size(B, 2)
166159
ccall((@blasfunc(cgetrs_), OpenBLAS_jll.libopenblas), Cvoid,
167-
(Ref{UInt8}, Ref{LinearAlgebra.BlasInt}, Ref{LinearAlgebra.BlasInt},
168-
Ptr{ComplexF32}, Ref{LinearAlgebra.BlasInt},
169-
Ptr{LinearAlgebra.BlasInt}, Ptr{ComplexF32}, Ref{LinearAlgebra.BlasInt},
170-
Ptr{LinearAlgebra.BlasInt}, Clong),
160+
(Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF32}, Ref{BlasInt},
161+
Ptr{BlasInt}, Ptr{ComplexF32}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
171162
trans, n, size(B, 2), A, max(1, stride(A, 2)), ipiv, B, max(1, stride(B, 2)), info,
172163
1)
173-
LinearAlgebra.LAPACK.chklapackerror(LinearAlgebra.BlasInt(info[]))
164+
LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[]))
174165
B
175166
end
176167

177-
function getrs!(trans::AbstractChar,
168+
function openblas_getrs!(trans::AbstractChar,
178169
A::AbstractMatrix{<:Float64},
179-
ipiv::AbstractVector{LinearAlgebra.BlasInt},
170+
ipiv::AbstractVector{BlasInt},
180171
B::AbstractVecOrMat{<:Float64};
181-
info = Ref{LinearAlgebra.BlasInt}())
182-
LinearAlgebra.require_one_based_indexing(A, ipiv, B)
172+
info = Ref{BlasInt}())
173+
require_one_based_indexing(A, ipiv, B)
183174
LinearAlgebra.LAPACK.chktrans(trans)
184175
chkstride1(A, B, ipiv)
185176
n = LinearAlgebra.checksquare(A)
@@ -191,21 +182,20 @@ function getrs!(trans::AbstractChar,
191182
end
192183
nrhs = size(B, 2)
193184
ccall((@blasfunc(dgetrs_), OpenBLAS_jll.libopenblas), Cvoid,
194-
(Ref{UInt8}, Ref{LinearAlgebra.BlasInt}, Ref{LinearAlgebra.BlasInt},
195-
Ptr{Float64}, Ref{LinearAlgebra.BlasInt},
196-
Ptr{LinearAlgebra.BlasInt}, Ptr{Float64}, Ref{LinearAlgebra.BlasInt}, Ptr{LinearAlgebra.BlasInt}, Clong),
185+
(Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{Float64}, Ref{BlasInt},
186+
Ptr{BlasInt}, Ptr{Float64}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
197187
trans, n, size(B, 2), A, max(1, stride(A, 2)), ipiv, B, max(1, stride(B, 2)), info,
198188
1)
199-
LinearAlgebra.LAPACK.chklapackerror(LinearAlgebra.BlasInt(info[]))
189+
LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[]))
200190
B
201191
end
202192

203-
function getrs!(trans::AbstractChar,
193+
function openblas_getrs!(trans::AbstractChar,
204194
A::AbstractMatrix{<:Float32},
205-
ipiv::AbstractVector{LinearAlgebra.BlasInt},
195+
ipiv::AbstractVector{BlasInt},
206196
B::AbstractVecOrMat{<:Float32};
207-
info = Ref{LinearAlgebra.BlasInt}())
208-
LinearAlgebra.require_one_based_indexing(A, ipiv, B)
197+
info = Ref{BlasInt}())
198+
require_one_based_indexing(A, ipiv, B)
209199
LinearAlgebra.LAPACK.chktrans(trans)
210200
chkstride1(A, B, ipiv)
211201
n = LinearAlgebra.checksquare(A)
@@ -217,23 +207,20 @@ function getrs!(trans::AbstractChar,
217207
end
218208
nrhs = size(B, 2)
219209
ccall((@blasfunc(sgetrs_), OpenBLAS_jll.libopenblas), Cvoid,
220-
(Ref{UInt8}, Ref{LinearAlgebra.BlasInt}, Ref{LinearAlgebra.BlasInt},
221-
Ptr{Float32}, Ref{LinearAlgebra.BlasInt},
222-
Ptr{LinearAlgebra.BlasInt}, Ptr{Float32}, Ref{LinearAlgebra.BlasInt}, Ptr{LinearAlgebra.BlasInt}, Clong),
210+
(Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{Float32}, Ref{BlasInt},
211+
Ptr{BlasInt}, Ptr{Float32}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
223212
trans, n, size(B, 2), A, max(1, stride(A, 2)), ipiv, B, max(1, stride(B, 2)), info,
224213
1)
225-
LinearAlgebra.LAPACK.chklapackerror(LinearAlgebra.BlasInt(info[]))
214+
LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[]))
226215
B
227216
end
228217

229-
end # module OpenBLASLU
230-
231218
default_alias_A(::OpenBLASLUFactorization, ::Any, ::Any) = false
232219
default_alias_b(::OpenBLASLUFactorization, ::Any, ::Any) = false
233220

234221
const PREALLOCATED_OPENBLAS_LU = begin
235222
A = rand(0, 0)
236-
luinst = ArrayInterface.lu_instance(A), Ref{LinearAlgebra.BlasInt}()
223+
luinst = ArrayInterface.lu_instance(A), Ref{BlasInt}()
237224
end
238225

239226
function LinearSolve.init_cacheval(alg::OpenBLASLUFactorization, A, b, u, Pl, Pr,
@@ -247,7 +234,7 @@ function LinearSolve.init_cacheval(alg::OpenBLASLUFactorization,
247234
maxiters::Int, abstol, reltol, verbose::LinearVerbosity,
248235
assumptions::OperatorAssumptions)
249236
A = rand(eltype(A), 0, 0)
250-
ArrayInterface.lu_instance(A), Ref{LinearAlgebra.BlasInt}()
237+
ArrayInterface.lu_instance(A), Ref{BlasInt}()
251238
end
252239

253240
function SciMLBase.solve!(cache::LinearCache, alg::OpenBLASLUFactorization;
@@ -256,8 +243,8 @@ function SciMLBase.solve!(cache::LinearCache, alg::OpenBLASLUFactorization;
256243
A = convert(AbstractMatrix, A)
257244
if cache.isfresh
258245
cacheval = @get_cacheval(cache, :OpenBLASLUFactorization)
259-
res = OpenBLASLU.getrf!(A; ipiv = cacheval[1].ipiv, info = cacheval[2])
260-
fact = LinearAlgebra.LU(res[1:3]...), res[4]
246+
res = openblas_getrf!(A; ipiv = cacheval[1].ipiv, info = cacheval[2])
247+
fact = LU(res[1:3]...), res[4]
261248
cache.cacheval = fact
262249

263250
if !LinearAlgebra.issuccess(fact[1])
@@ -268,15 +255,15 @@ function SciMLBase.solve!(cache::LinearCache, alg::OpenBLASLUFactorization;
268255
end
269256

270257
A, info = @get_cacheval(cache, :OpenBLASLUFactorization)
271-
LinearAlgebra.require_one_based_indexing(cache.u, cache.b)
258+
require_one_based_indexing(cache.u, cache.b)
272259
m, n = size(A, 1), size(A, 2)
273260
if m > n
274261
Bc = copy(cache.b)
275-
OpenBLASLU.getrs!('N', A.factors, A.ipiv, Bc; info)
262+
openblas_getrs!('N', A.factors, A.ipiv, Bc; info)
276263
copyto!(cache.u, 1, Bc, 1, n)
277264
else
278265
copyto!(cache.u, cache.b)
279-
OpenBLASLU.getrs!('N', A.factors, A.ipiv, cache.u; info)
266+
openblas_getrs!('N', A.factors, A.ipiv, cache.u; info)
280267
end
281268

282269
SciMLBase.build_linear_solution(

test/basictests.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -287,8 +287,10 @@ end
287287
push!(test_algs, MKLLUFactorization())
288288
end
289289

290-
# Always test OpenBLAS since it's a direct dependency
291-
push!(test_algs, OpenBLASLUFactorization())
290+
# Test OpenBLAS if available
291+
if LinearSolve.useopenblas
292+
push!(test_algs, OpenBLASLUFactorization())
293+
end
292294

293295
# Test BLIS if extension is available
294296
if Base.get_extension(LinearSolve, :LinearSolveBLISExt) !== nothing

test/preferences.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -190,11 +190,13 @@ using Preferences
190190
println("✅ MKLLUFactorization confirmed working")
191191
end
192192

193-
# Test OpenBLAS (always available as a dependency)
194-
sol_openblas = solve(prob, OpenBLASLUFactorization())
195-
@test sol_openblas.retcode == ReturnCode.Success
196-
@test norm(A * sol_openblas.u - b) < 1e-8
197-
println("✅ OpenBLASLUFactorization confirmed working")
193+
# Test OpenBLAS if available
194+
if LinearSolve.useopenblas
195+
sol_openblas = solve(prob, OpenBLASLUFactorization())
196+
@test sol_openblas.retcode == ReturnCode.Success
197+
@test norm(A * sol_openblas.u - b) < 1e-8
198+
println("✅ OpenBLASLUFactorization confirmed working")
199+
end
198200

199201
# Test Apple Accelerate if available
200202
if LinearSolve.appleaccelerate_isavailable()

test/resolve.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ for alg in vcat(InteractiveUtils.subtypes(AbstractDenseFactorization),
2020
(!(alg == AppleAccelerateLUFactorization) ||
2121
LinearSolve.appleaccelerate_isavailable()) &&
2222
(!(alg == MKLLUFactorization) || LinearSolve.usemkl) &&
23-
(!(alg == OpenBLASLUFactorization) || true) # OpenBLAS is always available as a dependency
23+
(!(alg == OpenBLASLUFactorization) || LinearSolve.useopenblas)
2424
A = [1.0 2.0; 3.0 4.0]
2525
alg in [KLUFactorization, UMFPACKFactorization, SparspakFactorization] &&
2626
(A = sparse(A))

0 commit comments

Comments
 (0)