|
"""
```julia
OpenBLASLUFactorization()
```

LU factorization backed by direct calls into OpenBLAS (`getrf` to factorize, `getrs` to
solve). The routines are invoked on `OpenBLAS_jll` itself instead of being dispatched
through Julia's libblastrampoline, which may be faster in some configurations.

## Performance Characteristics

  - Carries the pivot vector and the LAPACK status reference in the solver cache so they
    can be handed back to the factorization routine
  - Issues `ccall`s straight to the OpenBLAS routines
  - May outperform `LUFactorization` when OpenBLAS is well tuned for the hardware
  - Handles `Float32`, `Float64`, `ComplexF32`, and `ComplexF64` element types

## When to Use

  - You want OpenBLAS to be used no matter which BLAS the system is configured with
  - Benchmarks on your hardware show it beating `LUFactorization`
  - You need the same backend (always OpenBLAS) on every system

## Example

```julia
using LinearSolve, LinearAlgebra

A = rand(100, 100)
b = rand(100)
prob = LinearProblem(A, b)
sol = solve(prob, OpenBLASLUFactorization())
```
"""
struct OpenBLASLUFactorization <: AbstractFactorization
end
| 35 | + |
| 36 | +# OpenBLAS methods - OpenBLAS_jll is always available as a standard library |
| 37 | + |
"""
```julia
openblas_getrf!(A::AbstractMatrix{<:ComplexF64}; ipiv, info, check = false)
```

Factorize `A` in place (LU with partial pivoting) by calling OpenBLAS's `zgetrf_` directly.

## Keywords

  - `ipiv`: pivot buffer. It is used as given when its length equals
    `min(size(A, 1), size(A, 2))`; otherwise a new buffer of that length is allocated, so
    an empty placeholder or a buffer sized for another matrix is never passed to LAPACK.
  - `info`: `Ref{BlasInt}` that receives the LAPACK status code.
  - `check`: when `true`, `chkfinite(A)` is run before factorizing.

## Returns

`(A, ipiv, info[], info)`: the overwritten matrix, the pivot vector that was actually
used, the status code, and the reference the code was read from.
"""
function openblas_getrf!(A::AbstractMatrix{<:ComplexF64};
        ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
        info = Ref{BlasInt}(),
        check = false)
    require_one_based_indexing(A)
    check && chkfinite(A)
    chkstride1(A)
    m, n = size(A)
    lda = max(1, stride(A, 2))
    # The routine fills min(m, n) pivots. A shorter buffer would be overrun and a longer
    # one would later be rejected by the solve step, so only an exact fit is reused.
    npivots = min(m, n)
    if length(ipiv) != npivots
        ipiv = similar(A, BlasInt, npivots)
    end
    ccall((@blasfunc(zgetrf_), OpenBLAS_jll.libopenblas), Cvoid,
        (Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF64},
            Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
        m, n, A, lda, ipiv, info)
    chkargsok(info[])
    # The error code is returned, not thrown: it is stored in the LU factorization type.
    return A, ipiv, info[], info
end
| 57 | + |
"""
```julia
openblas_getrf!(A::AbstractMatrix{<:ComplexF32}; ipiv, info, check = false)
```

Factorize `A` in place (LU with partial pivoting) by calling OpenBLAS's `cgetrf_` directly.

## Keywords

  - `ipiv`: pivot buffer. It is used as given when its length equals
    `min(size(A, 1), size(A, 2))`; otherwise a new buffer of that length is allocated, so
    an empty placeholder or a buffer sized for another matrix is never passed to LAPACK.
  - `info`: `Ref{BlasInt}` that receives the LAPACK status code.
  - `check`: when `true`, `chkfinite(A)` is run before factorizing.

## Returns

`(A, ipiv, info[], info)`: the overwritten matrix, the pivot vector that was actually
used, the status code, and the reference the code was read from.
"""
function openblas_getrf!(A::AbstractMatrix{<:ComplexF32};
        ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
        info = Ref{BlasInt}(),
        check = false)
    require_one_based_indexing(A)
    check && chkfinite(A)
    chkstride1(A)
    m, n = size(A)
    lda = max(1, stride(A, 2))
    # The routine fills min(m, n) pivots. A shorter buffer would be overrun and a longer
    # one would later be rejected by the solve step, so only an exact fit is reused.
    npivots = min(m, n)
    if length(ipiv) != npivots
        ipiv = similar(A, BlasInt, npivots)
    end
    ccall((@blasfunc(cgetrf_), OpenBLAS_jll.libopenblas), Cvoid,
        (Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF32},
            Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
        m, n, A, lda, ipiv, info)
    chkargsok(info[])
    # The error code is returned, not thrown: it is stored in the LU factorization type.
    return A, ipiv, info[], info
end
| 77 | + |
"""
```julia
openblas_getrf!(A::AbstractMatrix{<:Float64}; ipiv, info, check = false)
```

Factorize `A` in place (LU with partial pivoting) by calling OpenBLAS's `dgetrf_` directly.

## Keywords

  - `ipiv`: pivot buffer. It is used as given when its length equals
    `min(size(A, 1), size(A, 2))`; otherwise a new buffer of that length is allocated, so
    an empty placeholder or a buffer sized for another matrix is never passed to LAPACK.
  - `info`: `Ref{BlasInt}` that receives the LAPACK status code.
  - `check`: when `true`, `chkfinite(A)` is run before factorizing.

## Returns

`(A, ipiv, info[], info)`: the overwritten matrix, the pivot vector that was actually
used, the status code, and the reference the code was read from.
"""
function openblas_getrf!(A::AbstractMatrix{<:Float64};
        ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
        info = Ref{BlasInt}(),
        check = false)
    require_one_based_indexing(A)
    check && chkfinite(A)
    chkstride1(A)
    m, n = size(A)
    lda = max(1, stride(A, 2))
    # The routine fills min(m, n) pivots. A shorter buffer would be overrun and a longer
    # one would later be rejected by the solve step, so only an exact fit is reused.
    npivots = min(m, n)
    if length(ipiv) != npivots
        ipiv = similar(A, BlasInt, npivots)
    end
    ccall((@blasfunc(dgetrf_), OpenBLAS_jll.libopenblas), Cvoid,
        (Ref{BlasInt}, Ref{BlasInt}, Ptr{Float64},
            Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
        m, n, A, lda, ipiv, info)
    chkargsok(info[])
    # The error code is returned, not thrown: it is stored in the LU factorization type.
    return A, ipiv, info[], info
end
| 97 | + |
"""
```julia
openblas_getrf!(A::AbstractMatrix{<:Float32}; ipiv, info, check = false)
```

Factorize `A` in place (LU with partial pivoting) by calling OpenBLAS's `sgetrf_` directly.

## Keywords

  - `ipiv`: pivot buffer. It is used as given when its length equals
    `min(size(A, 1), size(A, 2))`; otherwise a new buffer of that length is allocated, so
    an empty placeholder or a buffer sized for another matrix is never passed to LAPACK.
  - `info`: `Ref{BlasInt}` that receives the LAPACK status code.
  - `check`: when `true`, `chkfinite(A)` is run before factorizing.

## Returns

`(A, ipiv, info[], info)`: the overwritten matrix, the pivot vector that was actually
used, the status code, and the reference the code was read from.
"""
function openblas_getrf!(A::AbstractMatrix{<:Float32};
        ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
        info = Ref{BlasInt}(),
        check = false)
    require_one_based_indexing(A)
    check && chkfinite(A)
    chkstride1(A)
    m, n = size(A)
    lda = max(1, stride(A, 2))
    # The routine fills min(m, n) pivots. A shorter buffer would be overrun and a longer
    # one would later be rejected by the solve step, so only an exact fit is reused.
    npivots = min(m, n)
    if length(ipiv) != npivots
        ipiv = similar(A, BlasInt, npivots)
    end
    ccall((@blasfunc(sgetrf_), OpenBLAS_jll.libopenblas), Cvoid,
        (Ref{BlasInt}, Ref{BlasInt}, Ptr{Float32},
            Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
        m, n, A, lda, ipiv, info)
    chkargsok(info[])
    # The error code is returned, not thrown: it is stored in the LU factorization type.
    return A, ipiv, info[], info
end
| 117 | + |
"""
```julia
openblas_getrs!(trans, A::AbstractMatrix{<:ComplexF64}, ipiv, B; info = Ref{BlasInt}())
```

Solve a linear system from an existing LU factorization by calling OpenBLAS's `zgetrs_`
directly. `A` holds the factors, `ipiv` the pivots, and `B` the right-hand side(s); `B`
is overwritten with the solution and returned.

`trans` is validated with `LinearAlgebra.LAPACK.chktrans`, and `info` receives the LAPACK
status code, which is then checked with `LinearAlgebra.LAPACK.chklapackerror`.

Throws `DimensionMismatch` when `A` is not square, when `size(B, 1)` differs from the
order of `A`, or when `length(ipiv)` differs from the order of `A`.
"""
function openblas_getrs!(trans::AbstractChar,
        A::AbstractMatrix{<:ComplexF64},
        ipiv::AbstractVector{BlasInt},
        B::AbstractVecOrMat{<:ComplexF64};
        info = Ref{BlasInt}())
    require_one_based_indexing(A, ipiv, B)
    LinearAlgebra.LAPACK.chktrans(trans)
    chkstride1(A, B, ipiv)
    n = LinearAlgebra.checksquare(A)
    if n != size(B, 1)
        throw(DimensionMismatch("B has leading dimension $(size(B,1)), but needs $n"))
    end
    if n != length(ipiv)
        throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n"))
    end
    nrhs = size(B, 2)
    lda = max(1, stride(A, 2))
    ldb = max(1, stride(B, 2))
    # The final `1` is presumably the hidden length of the one-character `trans` argument.
    ccall((@blasfunc(zgetrs_), OpenBLAS_jll.libopenblas), Cvoid,
        (Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF64}, Ref{BlasInt},
            Ptr{BlasInt}, Ptr{ComplexF64}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
        trans, n, nrhs, A, lda, ipiv, B, ldb, info,
        1)
    LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[]))
    return B
end
| 142 | + |
"""
```julia
openblas_getrs!(trans, A::AbstractMatrix{<:ComplexF32}, ipiv, B; info = Ref{BlasInt}())
```

Solve a linear system from an existing LU factorization by calling OpenBLAS's `cgetrs_`
directly. `A` holds the factors, `ipiv` the pivots, and `B` the right-hand side(s); `B`
is overwritten with the solution and returned.

`trans` is validated with `LinearAlgebra.LAPACK.chktrans`, and `info` receives the LAPACK
status code, which is then checked with `LinearAlgebra.LAPACK.chklapackerror`.

Throws `DimensionMismatch` when `A` is not square, when `size(B, 1)` differs from the
order of `A`, or when `length(ipiv)` differs from the order of `A`.
"""
function openblas_getrs!(trans::AbstractChar,
        A::AbstractMatrix{<:ComplexF32},
        ipiv::AbstractVector{BlasInt},
        B::AbstractVecOrMat{<:ComplexF32};
        info = Ref{BlasInt}())
    require_one_based_indexing(A, ipiv, B)
    LinearAlgebra.LAPACK.chktrans(trans)
    chkstride1(A, B, ipiv)
    n = LinearAlgebra.checksquare(A)
    if n != size(B, 1)
        throw(DimensionMismatch("B has leading dimension $(size(B,1)), but needs $n"))
    end
    if n != length(ipiv)
        throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n"))
    end
    nrhs = size(B, 2)
    lda = max(1, stride(A, 2))
    ldb = max(1, stride(B, 2))
    # The final `1` is presumably the hidden length of the one-character `trans` argument.
    ccall((@blasfunc(cgetrs_), OpenBLAS_jll.libopenblas), Cvoid,
        (Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF32}, Ref{BlasInt},
            Ptr{BlasInt}, Ptr{ComplexF32}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
        trans, n, nrhs, A, lda, ipiv, B, ldb, info,
        1)
    LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[]))
    return B
end
| 167 | + |
"""
```julia
openblas_getrs!(trans, A::AbstractMatrix{<:Float64}, ipiv, B; info = Ref{BlasInt}())
```

Solve a linear system from an existing LU factorization by calling OpenBLAS's `dgetrs_`
directly. `A` holds the factors, `ipiv` the pivots, and `B` the right-hand side(s); `B`
is overwritten with the solution and returned.

`trans` is validated with `LinearAlgebra.LAPACK.chktrans`, and `info` receives the LAPACK
status code, which is then checked with `LinearAlgebra.LAPACK.chklapackerror`.

Throws `DimensionMismatch` when `A` is not square, when `size(B, 1)` differs from the
order of `A`, or when `length(ipiv)` differs from the order of `A`.
"""
function openblas_getrs!(trans::AbstractChar,
        A::AbstractMatrix{<:Float64},
        ipiv::AbstractVector{BlasInt},
        B::AbstractVecOrMat{<:Float64};
        info = Ref{BlasInt}())
    require_one_based_indexing(A, ipiv, B)
    LinearAlgebra.LAPACK.chktrans(trans)
    chkstride1(A, B, ipiv)
    n = LinearAlgebra.checksquare(A)
    if n != size(B, 1)
        throw(DimensionMismatch("B has leading dimension $(size(B,1)), but needs $n"))
    end
    if n != length(ipiv)
        throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n"))
    end
    nrhs = size(B, 2)
    lda = max(1, stride(A, 2))
    ldb = max(1, stride(B, 2))
    # The final `1` is presumably the hidden length of the one-character `trans` argument.
    ccall((@blasfunc(dgetrs_), OpenBLAS_jll.libopenblas), Cvoid,
        (Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{Float64}, Ref{BlasInt},
            Ptr{BlasInt}, Ptr{Float64}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
        trans, n, nrhs, A, lda, ipiv, B, ldb, info,
        1)
    LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[]))
    return B
end
| 192 | + |
"""
```julia
openblas_getrs!(trans, A::AbstractMatrix{<:Float32}, ipiv, B; info = Ref{BlasInt}())
```

Solve a linear system from an existing LU factorization by calling OpenBLAS's `sgetrs_`
directly. `A` holds the factors, `ipiv` the pivots, and `B` the right-hand side(s); `B`
is overwritten with the solution and returned.

`trans` is validated with `LinearAlgebra.LAPACK.chktrans`, and `info` receives the LAPACK
status code, which is then checked with `LinearAlgebra.LAPACK.chklapackerror`.

Throws `DimensionMismatch` when `A` is not square, when `size(B, 1)` differs from the
order of `A`, or when `length(ipiv)` differs from the order of `A`.
"""
function openblas_getrs!(trans::AbstractChar,
        A::AbstractMatrix{<:Float32},
        ipiv::AbstractVector{BlasInt},
        B::AbstractVecOrMat{<:Float32};
        info = Ref{BlasInt}())
    require_one_based_indexing(A, ipiv, B)
    LinearAlgebra.LAPACK.chktrans(trans)
    chkstride1(A, B, ipiv)
    n = LinearAlgebra.checksquare(A)
    if n != size(B, 1)
        throw(DimensionMismatch("B has leading dimension $(size(B,1)), but needs $n"))
    end
    if n != length(ipiv)
        throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n"))
    end
    nrhs = size(B, 2)
    lda = max(1, stride(A, 2))
    ldb = max(1, stride(B, 2))
    # The final `1` is presumably the hidden length of the one-character `trans` argument.
    ccall((@blasfunc(sgetrs_), OpenBLAS_jll.libopenblas), Cvoid,
        (Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{Float32}, Ref{BlasInt},
            Ptr{BlasInt}, Ptr{Float32}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
        trans, n, nrhs, A, lda, ipiv, B, ldb, info,
        1)
    LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[]))
    return B
end
| 217 | + |
# `default_alias_A` answers `false` for `OpenBLASLUFactorization`, whatever the other two
# arguments are. NOTE(review): presumably the default for the `alias_A` solver option;
# confirm against the caller.
function default_alias_A(::OpenBLASLUFactorization, ::Any, ::Any)
    return false
end
# `default_alias_b` answers `false` for `OpenBLASLUFactorization`, whatever the other two
# arguments are. NOTE(review): presumably the default for the `alias_b` solver option;
# confirm against the caller.
function default_alias_b(::OpenBLASLUFactorization, ::Any, ::Any)
    return false
end
| 220 | + |
# Placeholder cache value: an LU instance built from an empty `Float64` matrix, paired with
# the `Ref` that receives the LAPACK status code. `let` is used instead of `begin` because
# `begin` opens no scope, so the temporaries would otherwise be left behind as
# non-constant module-level globals.
const PREALLOCATED_OPENBLAS_LU = let
    A = rand(0, 0)
    (ArrayInterface.lu_instance(A), Ref{BlasInt}())
end
| 225 | + |
# Fallback cache initializer for `OpenBLASLUFactorization`: every argument is ignored and
# the shared `PREALLOCATED_OPENBLAS_LU` placeholder is handed back.
function LinearSolve.init_cacheval(alg::OpenBLASLUFactorization, A, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol, verbose::LinearVerbosity,
        assumptions::OperatorAssumptions)
    return PREALLOCATED_OPENBLAS_LU
end
| 231 | + |
# Cache initializer for `Float32`, `ComplexF32` and `ComplexF64` matrices: builds an LU
# instance from an empty matrix of the same element type and pairs it with a fresh `Ref`
# for the LAPACK status code.
function LinearSolve.init_cacheval(alg::OpenBLASLUFactorization,
        A::AbstractMatrix{<:Union{Float32, ComplexF32, ComplexF64}}, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol, verbose::LinearVerbosity,
        assumptions::OperatorAssumptions)
    empty_mat = rand(eltype(A), 0, 0)
    lu_placeholder = ArrayInterface.lu_instance(empty_mat)
    return (lu_placeholder, Ref{BlasInt}())
end
| 239 | + |
# Solve the cached linear problem with OpenBLAS.
#
# When `cache.isfresh` is set, `cache.A` is factorized in place with `openblas_getrf!`
# (reusing the pivot vector and status reference held in the cache value) and the
# resulting `(LU, info)` pair is stored back into `cache.cacheval`. A failed
# factorization returns a solution with `ReturnCode.Failure` and leaves `isfresh` set.
# Otherwise the right-hand side is solved with `openblas_getrs!`, the result is written
# to `cache.u`, and a solution with `ReturnCode.Success` is returned.
function SciMLBase.solve!(cache::LinearCache, alg::OpenBLASLUFactorization;
        kwargs...)
    mat = convert(AbstractMatrix, cache.A)
    if cache.isfresh
        previous = @get_cacheval(cache, :OpenBLASLUFactorization)
        factors, pivots, code, status = openblas_getrf!(
            mat; ipiv = previous[1].ipiv, info = previous[2])
        refreshed = LU(factors, pivots, code), status
        cache.cacheval = refreshed

        if !LinearAlgebra.issuccess(refreshed[1])
            return SciMLBase.build_linear_solution(
                alg, cache.u, nothing, cache; retcode = ReturnCode.Failure)
        end
        cache.isfresh = false
    end

    lufact, status = @get_cacheval(cache, :OpenBLASLUFactorization)
    require_one_based_indexing(cache.u, cache.b)
    nrows, ncols = size(lufact, 1), size(lufact, 2)
    if nrows <= ncols
        # Solve directly in the output vector.
        copyto!(cache.u, cache.b)
        openblas_getrs!('N', lufact.factors, lufact.ipiv, cache.u; info = status)
    else
        # NOTE(review): `openblas_getrs!` calls `checksquare` on the factors, so this
        # tall-matrix path looks like it would throw before solving; confirm.
        rhs = copy(cache.b)
        openblas_getrs!('N', lufact.factors, lufact.ipiv, rhs; info = status)
        copyto!(cache.u, 1, rhs, 1, ncols)
    end

    return SciMLBase.build_linear_solution(
        alg, cache.u, nothing, cache; retcode = ReturnCode.Success)
end
0 commit comments