Skip to content

Commit 5a97dbb

Browse files
committed
Add OpenBLASLUFactorization implementation
- Implement OpenBLASLUFactorization as a direct wrapper over OpenBLAS_jll
- Add getrf! and getrs! functions for LU factorization and solving
- Support Float32, Float64, ComplexF32, and ComplexF64 types
- Include proper module structure and exports
- Add OpenBLAS_jll as a dependency
- Tests confirm functionality matches existing LUFactorization
1 parent 99c54ec commit 5a97dbb

File tree

4 files changed

+316
-0
lines changed

4 files changed

+316
-0
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
1717
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1818
MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
1919
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
20+
OpenBLAS_jll = "4536629a-c528-5b80-bd46-f80d51c5b363"
2021
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
2122
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
2223
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"

src/LinearSolve.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ else
5959
const usemkl = false
6060
end
6161

62+
using OpenBLAS_jll
63+
6264

6365
@reexport using SciMLBase
6466

@@ -345,6 +347,7 @@ include("extension_algs.jl")
345347
include("factorization.jl")
346348
include("appleaccelerate.jl")
347349
include("mkl.jl")
350+
include("openblas.jl")
348351
include("simplelu.jl")
349352
include("simplegmres.jl")
350353
include("iterative_wrappers.jl")
@@ -461,6 +464,7 @@ export MKLPardisoFactorize, MKLPardisoIterate
461464
export PanuaPardisoFactorize, PanuaPardisoIterate
462465
export PardisoJL
463466
export MKLLUFactorization
467+
export OpenBLASLUFactorization
464468
export AppleAccelerateLUFactorization
465469
export MetalLUFactorization
466470

src/openblas.jl

Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,252 @@
"""
    OpenBLASLUFactorization()

LU factorization algorithm that talks to OpenBLAS directly.

The LAPACK routines are reached through `ccall`s on `OpenBLAS_jll.libopenblas`, so
libblastrampoline is not required. The pivot vector and the LAPACK `info` reference
live in the solver cache and are handed back in on later factorizations, which avoids
allocating fresh workspace for every solve.

Direct wrappers exist for `Float32`, `Float64`, `ComplexF32` and `ComplexF64` matrices.
"""
struct OpenBLASLUFactorization <: AbstractFactorization end
10+
11+
module OpenBLASLU

using LinearAlgebra
using LinearAlgebra: BlasInt, LU, require_one_based_indexing, checksquare
using LinearAlgebra.LAPACK: chkfinite, chkstride1, @blasfunc, chkargsok, chktrans, chklapackerror
using OpenBLAS_jll

"""
    getrf!(A; ipiv, info = Ref{BlasInt}(), check = false) -> (A, ipiv, info[], info)

Call the LAPACK `?getrf` routine from `OpenBLAS_jll.libopenblas` on `A`.

Methods are defined for `Float32`, `Float64`, `ComplexF32` and `ComplexF64` matrices.

# Keywords
- `ipiv`: pivot buffer handed to LAPACK. It is used as-is when its length equals
  `min(size(A)...)`; an empty or otherwise wrongly sized buffer is replaced by a freshly
  allocated one, so LAPACK never writes past the end of a caller-supplied vector.
- `info`: `Ref{BlasInt}` that receives the LAPACK status code.
- `check`: when `true`, `chkfinite(A)` is run before the LAPACK call.

# Returns
The tuple `(A, ipiv, info[], info)`. The status code is returned rather than turned
into an exception for `info > 0`, so callers can store it in an `LU` object.

# Throws
Whatever `require_one_based_indexing`, `chkfinite`, `chkstride1` and `chkargsok`
throw for invalid input.
"""
function getrf! end

"""
    getrs!(trans, A, ipiv, B; info = Ref{BlasInt}()) -> B

Call the LAPACK `?getrs` routine from `OpenBLAS_jll.libopenblas` with the factors `A`
and pivots `ipiv`, as produced by [`getrf!`](@ref), and the right-hand side(s) `B`.
`B` is passed to LAPACK in place and returned.

Methods are defined for `Float32`, `Float64`, `ComplexF32` and `ComplexF64`.

# Arguments
- `trans`: transposition flag, validated with `chktrans`.
- `A`: square matrix of LU factors.
- `ipiv`: pivot vector; its length must equal the order of `A`.
- `B`: vector or matrix whose first dimension equals the order of `A`.

# Keywords
- `info`: `Ref{BlasInt}` that receives the LAPACK status code.

# Throws
- `DimensionMismatch` if `A` is not square, or if `B` or `ipiv` do not match its order.
- Whatever `chklapackerror` throws for a non-zero LAPACK status code.
"""
function getrs! end

# One (factorize, solve, element type) triple per supported BLAS scalar type. The
# methods differ only in the LAPACK symbol and the pointer element type, so they are
# generated from a single template.
for (getrf, getrs, elty) in ((:zgetrf_, :zgetrs_, :ComplexF64),
    (:cgetrf_, :cgetrs_, :ComplexF32),
    (:dgetrf_, :dgetrs_, :Float64),
    (:sgetrf_, :sgetrs_, :Float32))
    @eval begin
        function getrf!(A::AbstractMatrix{<:$elty};
                ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
                info = Ref{BlasInt}(),
                check = false)
            require_one_based_indexing(A)
            check && chkfinite(A)
            chkstride1(A)
            m, n = size(A)
            lda = max(1, stride(A, 2))
            # LAPACK writes min(m, n) pivots. Replace any buffer of a different
            # length (including the empty placeholder) instead of letting the
            # routine overrun it or leave stale trailing entries behind.
            if length(ipiv) != min(m, n)
                ipiv = similar(A, BlasInt, min(m, n))
            end
            ccall((@blasfunc($getrf), OpenBLAS_jll.libopenblas), Cvoid,
                (Ref{BlasInt}, Ref{BlasInt}, Ptr{$elty},
                    Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
                m, n, A, lda, ipiv, info)
            chkargsok(info[])
            return A, ipiv, info[], info # Error code is stored in LU factorization type
        end

        function getrs!(trans::AbstractChar,
                A::AbstractMatrix{<:$elty},
                ipiv::AbstractVector{BlasInt},
                B::AbstractVecOrMat{<:$elty};
                info = Ref{BlasInt}())
            require_one_based_indexing(A, ipiv, B)
            chktrans(trans)
            chkstride1(A, B, ipiv)
            n = checksquare(A)
            if n != size(B, 1)
                throw(DimensionMismatch("B has leading dimension $(size(B,1)), but needs $n"))
            end
            if n != length(ipiv)
                throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n"))
            end
            nrhs = size(B, 2)
            # The trailing `Clong` is the hidden Fortran length of the `trans`
            # character argument.
            ccall((@blasfunc($getrs), OpenBLAS_jll.libopenblas), Cvoid,
                (Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt},
                    Ptr{BlasInt}, Ptr{$elty}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
                trans, n, nrhs, A, max(1, stride(A, 2)), ipiv, B,
                max(1, stride(B, 2)), info, 1)
            chklapackerror(BlasInt(info[]))
            return B
        end
    end
end

end # module OpenBLASLU
199+
200+
# Both default-alias queries answer `false` for OpenBLASLUFactorization, regardless of
# the two trailing arguments.
function default_alias_A(::OpenBLASLUFactorization, ::Any, ::Any)
    return false
end

function default_alias_b(::OpenBLASLUFactorization, ::Any, ::Any)
    return false
end
202+
203+
# Shared zero-size cache value returned by the generic `init_cacheval` method: a
# placeholder LU instance built from an empty `Float64` matrix, paired with the `Ref`
# that receives the LAPACK `info` code.
#
# `let` is used rather than `begin` because a top-level `begin` block does not open a
# scope; the temporaries would otherwise become non-constant module globals.
const PREALLOCATED_OPENBLAS_LU = let A = rand(0, 0)
    (ArrayInterface.lu_instance(A), Ref{BlasInt}())
end
207+
208+
# Generic cache initializer for OpenBLASLUFactorization: hands back the shared
# `PREALLOCATED_OPENBLAS_LU` tuple without inspecting any of its arguments.
function LinearSolve.init_cacheval(alg::OpenBLASLUFactorization, A, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol, verbose::LinearVerbosity,
        assumptions::OperatorAssumptions)
    return PREALLOCATED_OPENBLAS_LU
end
213+
214+
# Cache initializer for the non-Float64 BLAS element types: builds a placeholder LU
# instance from an empty matrix with the same element type as `A`, paired with a fresh
# `Ref` for the LAPACK `info` code.
function LinearSolve.init_cacheval(alg::OpenBLASLUFactorization,
        A::AbstractMatrix{<:Union{Float32, ComplexF32, ComplexF64}}, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol, verbose::LinearVerbosity,
        assumptions::OperatorAssumptions)
    empty_mat = rand(eltype(A), 0, 0)
    lu_placeholder = ArrayInterface.lu_instance(empty_mat)
    info_ref = Ref{BlasInt}()
    return lu_placeholder, info_ref
end
221+
222+
# Solve the cached linear system with the direct OpenBLAS wrappers.
#
# When `cache.isfresh` is set, the matrix is factorized with `OpenBLASLU.getrf!`,
# reusing the pivot vector and `info` reference stored in the cache value, and the
# resulting `(LU, info)` tuple is written back to `cache.cacheval`. A failed
# factorization returns early with `ReturnCode.Failure` and leaves `isfresh` untouched.
# Otherwise the stored factors are applied to `cache.b` with `OpenBLASLU.getrs!` and the
# solution is placed in `cache.u`.
function SciMLBase.solve!(cache::LinearCache, alg::OpenBLASLUFactorization;
        kwargs...)
    mat = convert(AbstractMatrix, cache.A)
    if cache.isfresh
        prev_lu, prev_info = @get_cacheval(cache, :OpenBLASLUFactorization)
        factors, pivots, code, info_ref = OpenBLASLU.getrf!(
            mat; ipiv = prev_lu.ipiv, info = prev_info)
        fact = (LU(factors, pivots, code), info_ref)
        cache.cacheval = fact

        if !LinearAlgebra.issuccess(fact[1])
            return SciMLBase.build_linear_solution(
                alg, cache.u, nothing, cache; retcode = ReturnCode.Failure)
        end
        cache.isfresh = false
    end

    lufact, info = @get_cacheval(cache, :OpenBLASLUFactorization)
    require_one_based_indexing(cache.u, cache.b)
    nrows, ncols = size(lufact, 1), size(lufact, 2)
    if nrows > ncols
        # Solve on a copy of `b`, then move the first `ncols` entries into `u`.
        rhs = copy(cache.b)
        OpenBLASLU.getrs!('N', lufact.factors, lufact.ipiv, rhs; info = info)
        copyto!(cache.u, 1, rhs, 1, ncols)
    else
        # Solve in place inside `u` after seeding it with `b`.
        copyto!(cache.u, cache.b)
        OpenBLASLU.getrs!('N', lufact.factors, lufact.ipiv, cache.u; info = info)
    end

    return SciMLBase.build_linear_solution(
        alg, cache.u, nothing, cache; retcode = ReturnCode.Success)
end

test_openblas.jl

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
using LinearAlgebra
2+
using LinearSolve
3+
using Test
4+
5+
@testset "OpenBLASLUFactorization Tests" begin
    # (testset name, element type, tolerance): single-precision types get the looser
    # bound. Each case checks the residual of the OpenBLAS solution and its agreement
    # with the stock LUFactorization on a random 10x10 system.
    cases = (
        ("Float64", Float64, 1e-10),
        ("Float32", Float32, 1e-5),
        ("ComplexF64", ComplexF64, 1e-10),
        ("ComplexF32", ComplexF32, 1e-5)
    )

    @testset "$label" for (label, T, tol) in cases
        A = rand(T, 10, 10)
        b = rand(T, 10)
        prob = LinearProblem(A, b)

        sol_openblas = solve(prob, OpenBLASLUFactorization())
        sol_default = solve(prob, LUFactorization())

        @test norm(A * sol_openblas.u - b) < tol
        @test norm(sol_openblas.u - sol_default.u) < tol
    end
end

println("All tests passed!")

0 commit comments

Comments
 (0)