Skip to content

Commit f78754f

Browse files
handle the getrs as well and support more lapack installations
1 parent 94be106 commit f78754f

File tree

2 files changed

+44
-9
lines changed

2 files changed

+44
-9
lines changed

ext/LinearSolveMKLExt.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ function getrf!(A::AbstractMatrix{<:Float64}; ipiv = similar(A, BlasInt, min(siz
2424
Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
2525
m, n, A, lda, ipiv, info)
2626
chkargsok(info[])
27-
A, ipiv, info[] #Error code is stored in LU factorization type
27+
A, ipiv, info[], info #Error code is stored in LU factorization type
2828
end
2929

3030
default_alias_A(::MKLLUFactorization, ::Any, ::Any) = false
@@ -33,7 +33,7 @@ default_alias_b(::MKLLUFactorization, ::Any, ::Any) = false
3333
function LinearSolve.init_cacheval(alg::MKLLUFactorization, A, b, u, Pl, Pr,
3434
maxiters::Int, abstol, reltol, verbose::Bool,
3535
assumptions::OperatorAssumptions)
36-
ArrayInterface.lu_instance(convert(AbstractMatrix, A))
36+
ArrayInterface.lu_instance(convert(AbstractMatrix, A)), Ref{BlasInt}()
3737
end
3838

3939
function SciMLBase.solve!(cache::LinearCache, alg::MKLLUFactorization;
@@ -42,11 +42,12 @@ function SciMLBase.solve!(cache::LinearCache, alg::MKLLUFactorization;
4242
A = convert(AbstractMatrix, A)
4343
if cache.isfresh
4444
cacheval = @get_cacheval(cache, :MKLLUFactorization)
45-
fact = LU(getrf!(A; ipiv = cacheval.ipiv)...)
45+
res = getrf!(A; ipiv = cacheval[1].ipiv, info = cacheval[2])
46+
fact = LU(res[1:3]...), res[4]
4647
cache.cacheval = fact
4748
cache.isfresh = false
4849
end
49-
y = ldiv!(cache.u, @get_cacheval(cache, :MKLLUFactorization), cache.b)
50+
y = ldiv!(cache.u, @get_cacheval(cache, :MKLLUFactorization)[1], cache.b)
5051
SciMLBase.build_linear_solution(alg, y, nothing, cache)
5152
end
5253

src/appleaccelerate.jl

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,27 @@ function aa_getrf!(A::AbstractMatrix{<:Float64}; ipiv = similar(A, Cint, min(siz
4141
Ref{Cint}, Ptr{Cint}, Ptr{Cint}),
4242
m, n, A, lda, ipiv, info)
4343
info[] < 0 && throw(ArgumentError("Invalid arguments sent to LAPACK dgetrf_"))
44-
A, Vector{BlasInt}(ipiv), BlasInt(info[]) #Error code is stored in LU factorization type
44+
A, ipiv, BlasInt(info[]), info #Error code is stored in LU factorization type
45+
end
46+
47+
function aa_getrs!(trans::AbstractChar, A::AbstractMatrix{<:Float64}, ipiv::AbstractVector{Cint}, B::AbstractVecOrMat{<:Float64}; info = Ref{Cint}())
48+
require_one_based_indexing(A, ipiv, B)
49+
LinearAlgebra.LAPACK.chktrans(trans)
50+
chkstride1(A, B, ipiv)
51+
n = LinearAlgebra.checksquare(A)
52+
if n != size(B, 1)
53+
throw(DimensionMismatch("B has leading dimension $(size(B,1)), but needs $n"))
54+
end
55+
if n != length(ipiv)
56+
throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n"))
57+
end
58+
nrhs = size(B, 2)
59+
ccall(("dgetrs_", libacc), Cvoid,
60+
(Ref{UInt8}, Ref{Cint}, Ref{Cint}, Ptr{Float64}, Ref{Cint},
61+
Ptr{Cint}, Ptr{Float64}, Ref{Cint}, Ptr{Cint}, Clong),
62+
trans, n, size(B,2), A, max(1,stride(A,2)), ipiv, B, max(1,stride(B,2)), info, 1)
63+
LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[]))
64+
B
4565
end
4666

4767
default_alias_A(::AppleAccelerateLUFactorization, ::Any, ::Any) = false
@@ -50,7 +70,8 @@ default_alias_b(::AppleAccelerateLUFactorization, ::Any, ::Any) = false
5070
function LinearSolve.init_cacheval(alg::AppleAccelerateLUFactorization, A, b, u, Pl, Pr,
5171
maxiters::Int, abstol, reltol, verbose::Bool,
5272
assumptions::OperatorAssumptions)
53-
ArrayInterface.lu_instance(convert(AbstractMatrix, A))
73+
luinst = ArrayInterface.lu_instance(convert(AbstractMatrix, A))
74+
LU(luinst.factors,similar(A, Cint, 0), luinst.info), Ref{Cint}()
5475
end
5576

5677
function SciMLBase.solve!(cache::LinearCache, alg::AppleAccelerateLUFactorization;
@@ -59,10 +80,23 @@ function SciMLBase.solve!(cache::LinearCache, alg::AppleAccelerateLUFactorizatio
5980
A = convert(AbstractMatrix, A)
6081
if cache.isfresh
6182
cacheval = @get_cacheval(cache, :AppleAccelerateLUFactorization)
62-
fact = LU(aa_getrf!(A; ipiv = cacheval.ipiv)...)
83+
res = aa_getrf!(A; ipiv = cacheval[1].ipiv, info = cacheval[2])
84+
fact = LU(res[1:3]...), res[4]
6385
cache.cacheval = fact
6486
cache.isfresh = false
6587
end
66-
y = ldiv!(cache.u, @get_cacheval(cache, :AppleAccelerateLUFactorization), cache.b)
67-
SciMLBase.build_linear_solution(alg, y, nothing, cache)
88+
89+
A, info = @get_cacheval(cache, :AppleAccelerateLUFactorization)
90+
LinearAlgebra.require_one_based_indexing(cache.u, cache.b)
91+
m, n = size(A, 1), size(A, 2)
92+
if m > n
93+
Bc = copy(cache.b)
94+
aa_getrs!('N', A.factors, A.ipiv, Bc; info)
95+
return copyto!(cache.u, 1, Bc, 1, n)
96+
else
97+
copyto!(cache.u, cache.b)
98+
aa_getrs!('N', A.factors, A.ipiv, cache.u; info)
99+
end
100+
101+
SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
68102
end

0 commit comments

Comments
 (0)