Skip to content

Commit 54b7bec

Browse files
Setup Accelerate and MKL for 32-bit, MKL getrf, fix Metal
1 parent 50cf341 commit 54b7bec

File tree

3 files changed

+112
-5
lines changed

3 files changed

+112
-5
lines changed

ext/LinearSolveMKLExt.jl

Lines changed: 71 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,63 @@ function getrf!(A::AbstractMatrix{<:Float64}; ipiv = similar(A, BlasInt, min(siz
2727
A, ipiv, info[], info #Error code is stored in LU factorization type
2828
end
2929

30+
function getrf!(A::AbstractMatrix{<:Float32}; ipiv = similar(A, BlasInt, min(size(A,1),size(A,2))), info = Ref{BlasInt}(), check = false)
31+
require_one_based_indexing(A)
32+
check && chkfinite(A)
33+
chkstride1(A)
34+
m, n = size(A)
35+
lda = max(1,stride(A, 2))
36+
if isempty(ipiv)
37+
ipiv = similar(A, BlasInt, min(size(A,1),size(A,2)))
38+
end
39+
ccall((@blasfunc(sgetrf_), MKL_jll.libmkl_rt), Cvoid,
40+
(Ref{BlasInt}, Ref{BlasInt}, Ptr{Float32},
41+
Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
42+
m, n, A, lda, ipiv, info)
43+
chkargsok(info[])
44+
A, ipiv, info[], info #Error code is stored in LU factorization type
45+
end
46+
47+
function getrs!(trans::AbstractChar, A::AbstractMatrix{<:Float64}, ipiv::AbstractVector{Cint}, B::AbstractVecOrMat{<:Float64}; info = Ref{Cint}())
48+
require_one_based_indexing(A, ipiv, B)
49+
LinearAlgebra.LAPACK.chktrans(trans)
50+
chkstride1(A, B, ipiv)
51+
n = LinearAlgebra.checksquare(A)
52+
if n != size(B, 1)
53+
throw(DimensionMismatch("B has leading dimension $(size(B,1)), but needs $n"))
54+
end
55+
if n != length(ipiv)
56+
throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n"))
57+
end
58+
nrhs = size(B, 2)
59+
ccall(("dgetrs_", MKL_jll.libmkl_rt), Cvoid,
60+
(Ref{UInt8}, Ref{Cint}, Ref{Cint}, Ptr{Float64}, Ref{Cint},
61+
Ptr{Cint}, Ptr{Float64}, Ref{Cint}, Ptr{Cint}, Clong),
62+
trans, n, size(B,2), A, max(1,stride(A,2)), ipiv, B, max(1,stride(B,2)), info, 1)
63+
LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[]))
64+
B
65+
end
66+
67+
function getrs!(trans::AbstractChar, A::AbstractMatrix{<:Float32}, ipiv::AbstractVector{Cint}, B::AbstractVecOrMat{<:Float32}; info = Ref{Cint}())
68+
require_one_based_indexing(A, ipiv, B)
69+
LinearAlgebra.LAPACK.chktrans(trans)
70+
chkstride1(A, B, ipiv)
71+
n = LinearAlgebra.checksquare(A)
72+
if n != size(B, 1)
73+
throw(DimensionMismatch("B has leading dimension $(size(B,1)), but needs $n"))
74+
end
75+
if n != length(ipiv)
76+
throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n"))
77+
end
78+
nrhs = size(B, 2)
79+
ccall(("sgetrs_", MKL_jll.libmkl_rt), Cvoid,
80+
(Ref{UInt8}, Ref{Cint}, Ref{Cint}, Ptr{Float32}, Ref{Cint},
81+
Ptr{Cint}, Ptr{Float32}, Ref{Cint}, Ptr{Cint}, Clong),
82+
trans, n, size(B,2), A, max(1,stride(A,2)), ipiv, B, max(1,stride(B,2)), info, 1)
83+
LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[]))
84+
B
85+
end
86+
3087
default_alias_A(::MKLLUFactorization, ::Any, ::Any) = false
3188
default_alias_b(::MKLLUFactorization, ::Any, ::Any) = false
3289

@@ -47,8 +104,20 @@ function SciMLBase.solve!(cache::LinearCache, alg::MKLLUFactorization;
47104
cache.cacheval = fact
48105
cache.isfresh = false
49106
end
50-
y = ldiv!(cache.u, @get_cacheval(cache, :MKLLUFactorization)[1], cache.b)
51-
SciMLBase.build_linear_solution(alg, y, nothing, cache)
107+
108+
A, info = @get_cacheval(cache, :MKLLUFactorization)
109+
LinearAlgebra.require_one_based_indexing(cache.u, cache.b)
110+
m, n = size(A, 1), size(A, 2)
111+
if m > n
112+
Bc = copy(cache.b)
113+
getrs!('N', A.factors, A.ipiv, Bc; info)
114+
return copyto!(cache.u, 1, Bc, 1, n)
115+
else
116+
copyto!(cache.u, cache.b)
117+
getrs!('N', A.factors, A.ipiv, cache.u; info)
118+
end
119+
120+
SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
52121
end
53122

54123
end

ext/LinearSolveMetalExt.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ default_alias_b(::MetalLUFactorization, ::Any, ::Any) = false
1111
function LinearSolve.init_cacheval(alg::MetalLUFactorization, A, b, u, Pl, Pr,
1212
maxiters::Int, abstol, reltol, verbose::Bool,
1313
assumptions::OperatorAssumptions)
14-
ArrayInterface.lu_instance(convert(AbstractMatrix, MtlArray(A)))
14+
ArrayInterface.lu_instance(convert(AbstractMatrix, A))
1515
end
1616

1717
function SciMLBase.solve!(cache::LinearCache, alg::MetalLUFactorization;
@@ -21,10 +21,10 @@ function SciMLBase.solve!(cache::LinearCache, alg::MetalLUFactorization;
2121
if cache.isfresh
2222
cacheval = @get_cacheval(cache, :MetalLUFactorization)
2323
res = lu(MtlArray(A))
24-
cache.cacheval = fact
24+
cache.cacheval = LU(Array(res.factors), Array{Int}(res.ipiv), res.info)
2525
cache.isfresh = false
2626
end
27-
y = Array(ldiv!(MtlArray(cache.u), @get_cacheval(cache, :MetalLUFactorization), MtlArray(cache.b)))
27+
y = ldiv!(cache.u, @get_cacheval(cache, :MetalLUFactorization), cache.b)
2828
SciMLBase.build_linear_solution(alg, y, nothing, cache)
2929
end
3030

src/appleaccelerate.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,24 @@ function aa_getrf!(A::AbstractMatrix{<:Float64}; ipiv = similar(A, Cint, min(siz
4444
A, ipiv, BlasInt(info[]), info #Error code is stored in LU factorization type
4545
end
4646

47+
function aa_getrf!(A::AbstractMatrix{<:Float32}; ipiv = similar(A, Cint, min(size(A,1),size(A,2))), info = Ref{Cint}(), check = false)
48+
require_one_based_indexing(A)
49+
check && chkfinite(A)
50+
chkstride1(A)
51+
m, n = size(A)
52+
lda = max(1,stride(A, 2))
53+
if isempty(ipiv)
54+
ipiv = similar(A, Cint, min(size(A,1),size(A,2)))
55+
end
56+
57+
ccall(("sgetrf_", libacc), Cvoid,
58+
(Ref{Cint}, Ref{Cint}, Ptr{Float32},
59+
Ref{Cint}, Ptr{Cint}, Ptr{Cint}),
60+
m, n, A, lda, ipiv, info)
61+
info[] < 0 && throw(ArgumentError("Invalid arguments sent to LAPACK dgetrf_"))
62+
A, ipiv, BlasInt(info[]), info #Error code is stored in LU factorization type
63+
end
64+
4765
function aa_getrs!(trans::AbstractChar, A::AbstractMatrix{<:Float64}, ipiv::AbstractVector{Cint}, B::AbstractVecOrMat{<:Float64}; info = Ref{Cint}())
4866
require_one_based_indexing(A, ipiv, B)
4967
LinearAlgebra.LAPACK.chktrans(trans)
@@ -64,6 +82,26 @@ function aa_getrs!(trans::AbstractChar, A::AbstractMatrix{<:Float64}, ipiv::Abst
6482
B
6583
end
6684

85+
function aa_getrs!(trans::AbstractChar, A::AbstractMatrix{<:Float32}, ipiv::AbstractVector{Cint}, B::AbstractVecOrMat{<:Float32}; info = Ref{Cint}())
86+
require_one_based_indexing(A, ipiv, B)
87+
LinearAlgebra.LAPACK.chktrans(trans)
88+
chkstride1(A, B, ipiv)
89+
n = LinearAlgebra.checksquare(A)
90+
if n != size(B, 1)
91+
throw(DimensionMismatch("B has leading dimension $(size(B,1)), but needs $n"))
92+
end
93+
if n != length(ipiv)
94+
throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n"))
95+
end
96+
nrhs = size(B, 2)
97+
ccall(("sgetrs_", libacc), Cvoid,
98+
(Ref{UInt8}, Ref{Cint}, Ref{Cint}, Ptr{Float32}, Ref{Cint},
99+
Ptr{Cint}, Ptr{Float32}, Ref{Cint}, Ptr{Cint}, Clong),
100+
trans, n, size(B,2), A, max(1,stride(A,2)), ipiv, B, max(1,stride(B,2)), info, 1)
101+
LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[]))
102+
B
103+
end
104+
67105
default_alias_A(::AppleAccelerateLUFactorization, ::Any, ::Any) = false
68106
default_alias_b(::AppleAccelerateLUFactorization, ::Any, ::Any) = false
69107

0 commit comments

Comments
 (0)