Skip to content

Commit 22c377d

Browse files
committed
[oneMKL] Interface lapack routines
1 parent d2a810d commit 22c377d

File tree

7 files changed

+1001
-219
lines changed

7 files changed

+1001
-219
lines changed

deps/src/onemkl.cpp

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -2291,25 +2291,21 @@ extern "C" int64_t onemklZgeqrf_scratchpad_size(syclQueue_t device_queue, int64_
22912291

22922292
// C-callable wrapper: forwards to oneapi::mkl::lapack::geqrf on device_queue->val.
// The float _Complex pointers (a, tau, scratchpad) are reinterpret_cast to
// std::complex<float>* for the call; `{}` is passed as the final argument.
// Always returns 0; the value of `status` does not affect the return value.
// NOTE(review): the "2294-" gutter marker appears to flag the __FORCE_MKL_FLUSH__
// line below as deleted by this commit -- confirm against the actual diff.
extern "C" int onemklCgeqrf(syclQueue_t device_queue, int64_t m, int64_t n, float _Complex *a, int64_t lda, float _Complex *tau, float _Complex *scratchpad, int64_t scratchpad_size) {
22932293
auto status = oneapi::mkl::lapack::geqrf(device_queue->val, m, n, reinterpret_cast<std::complex<float>*>(a), lda, reinterpret_cast<std::complex<float>*>(tau), reinterpret_cast<std::complex<float>*>(scratchpad), scratchpad_size, {});
2294-
__FORCE_MKL_FLUSH__(status);
22952294
return 0;
22962295
}
22972296

22982297
// C-callable wrapper: forwards to oneapi::mkl::lapack::geqrf on device_queue->val,
// passing the double pointers (a, tau, scratchpad) through unchanged and `{}` as
// the final argument.
// Always returns 0; the value of `status` does not affect the return value.
// NOTE(review): the "2300-" gutter marker appears to flag the __FORCE_MKL_FLUSH__
// line below as deleted by this commit -- confirm against the actual diff.
extern "C" int onemklDgeqrf(syclQueue_t device_queue, int64_t m, int64_t n, double *a, int64_t lda, double *tau, double *scratchpad, int64_t scratchpad_size) {
22992298
auto status = oneapi::mkl::lapack::geqrf(device_queue->val, m, n, a, lda, tau, scratchpad, scratchpad_size, {});
2300-
__FORCE_MKL_FLUSH__(status);
23012299
return 0;
23022300
}
23032301

23042302
// C-callable wrapper: forwards to oneapi::mkl::lapack::geqrf on device_queue->val,
// passing the float pointers (a, tau, scratchpad) through unchanged and `{}` as
// the final argument.
// Always returns 0; the value of `status` does not affect the return value.
// NOTE(review): the "2306-" gutter marker appears to flag the __FORCE_MKL_FLUSH__
// line below as deleted by this commit -- confirm against the actual diff.
extern "C" int onemklSgeqrf(syclQueue_t device_queue, int64_t m, int64_t n, float *a, int64_t lda, float *tau, float *scratchpad, int64_t scratchpad_size) {
23052303
auto status = oneapi::mkl::lapack::geqrf(device_queue->val, m, n, a, lda, tau, scratchpad, scratchpad_size, {});
2306-
__FORCE_MKL_FLUSH__(status);
23072304
return 0;
23082305
}
23092306

23102307
// C-callable wrapper: forwards to oneapi::mkl::lapack::geqrf on device_queue->val.
// The double _Complex pointers (a, tau, scratchpad) are reinterpret_cast to
// std::complex<double>* for the call; `{}` is passed as the final argument.
// Always returns 0; the value of `status` does not affect the return value.
// NOTE(review): the "2312-" gutter marker appears to flag the __FORCE_MKL_FLUSH__
// line below as deleted by this commit -- confirm against the actual diff.
extern "C" int onemklZgeqrf(syclQueue_t device_queue, int64_t m, int64_t n, double _Complex *a, int64_t lda, double _Complex *tau, double _Complex *scratchpad, int64_t scratchpad_size) {
23112308
auto status = oneapi::mkl::lapack::geqrf(device_queue->val, m, n, reinterpret_cast<std::complex<double>*>(a), lda, reinterpret_cast<std::complex<double>*>(tau), reinterpret_cast<std::complex<double>*>(scratchpad), scratchpad_size, {});
2312-
__FORCE_MKL_FLUSH__(status);
23132309
return 0;
23142310
}
23152311

lib/mkl/oneMKL.jl

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,9 @@ using ..SYCL: syclQueue_t
1212
using GPUArrays
1313

1414
using LinearAlgebra
15+
using LinearAlgebra: checksquare
16+
using LinearAlgebra.LAPACK: chkargsok, chklapackerror, chktrans, chkside, chkdiag, chkuplo
17+
1518
using SparseArrays
1619

1720
# Exclude Float16 for now, since many oneMKL functions - copy, scal, do not take Float16
@@ -21,6 +24,7 @@ const onemklHalf = Union{Float16,ComplexF16}
2124

2225
include("utils.jl")
2326
include("wrappers_blas.jl")
27+
include("wrappers_lapack.jl")
2428
include("wrappers_sparse.jl")
2529
include("linalg.jl")
2630

lib/mkl/wrappers_blas.jl

Lines changed: 115 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -175,10 +175,10 @@ function symm(side::Char,
175175
end
176176

177177
## syrk
178-
for (fname, elty) in ((:onemklDsyrk,:Float64),
179-
(:onemklSsyrk,:Float32),
180-
(:onemklCsyrk,:ComplexF32),
181-
(:onemklZsyrk,:ComplexF64))
178+
for (fname, elty) in ((:onemklSsyrk, :Float32),
179+
(:onemklDsyrk, :Float64),
180+
(:onemklCsyrk, :ComplexF32),
181+
(:onemklZsyrk, :ComplexF64))
182182
@eval begin
183183
function syrk!(uplo::Char,
184184
trans::Char,
@@ -703,10 +703,28 @@ for (fname, elty) in ((:onemklSger, :Float32),
703703
end
704704
end
705705

706+
# spr
# Generates one `spr!` method per (wrapper, element type) pair listed in the loop
# header: Float32 and Float64 only.
# Arguments: `uplo` and `alpha` are forwarded untouched; `x` and `A` are both
# strided device vectors of the same element type.
# Throws DimensionMismatch when length(x) differs from the derived `n`.
# Returns `A` (the last expression of the method).
# NOTE(review): `A` being a vector whose length maps to n*(n+1)/2 suggests packed
# storage of an n-by-n matrix -- confirm against the oneMKL spr documentation.
707+
for (fname, elty) in ((:onemklSspr, :Float32),
708+
(:onemklDspr, :Float64))
709+
@eval begin
710+
function spr!(uplo::Char,
711+
alpha::Number,
712+
x::oneStridedVector{$elty},
713+
A::oneStridedVector{$elty})
714+
# Derive n from length(A); the rounding yields n whenever length(A) == n*(n+1)/2.
# NOTE(review): a length(A) that is not of that form is not rejected here.
n = round(Int, (sqrt(8*length(A))-1)/2)
715+
length(x) == n || throw(DimensionMismatch("Length of vector must be the same as the matrix dimensions"))
716+
incx = stride(x,1)
717+
# The queue is looked up from x's context and device.
queue = global_queue(context(x), device(x))
718+
$fname(sycl_queue(queue), uplo, n, alpha, x, incx, A)
719+
A
720+
end
721+
end
722+
end
723+
706724
#symv
707725
for (fname, elty) in ((:onemklSsymv,:Float32),
708726
(:onemklDsymv,:Float64))
709-
# Note that the complex symv are not BLAS but auiliary functions in LAPACK
727+
# Note that the complex symv are not BLAS but auxiliary functions in LAPACK
710728
@eval begin
711729
function symv!(uplo::Char,
712730
alpha::Number,
@@ -898,6 +916,67 @@ function gbmv(trans::Char,
898916
gbmv(trans, m, kl, ku, one(T), a, x)
899917
end
900918

919+
# spmv
# Generates one `spmv!` method per (wrapper, element type) pair listed in the loop
# header: Float32 and Float64 only.
# `uplo`, `alpha` and `beta` are forwarded untouched; `A`, `x` and `y` are strided
# device vectors of the same element type.
# Throws DimensionMismatch (with an empty message) unless both x and y have
# length n, where n is derived from length(A).
# Returns `y` (the last expression of the method).
920+
for (fname, elty) in ((:onemklSspmv, :Float32),
921+
(:onemklDspmv, :Float64))
922+
@eval begin
923+
function spmv!(uplo::Char,
924+
alpha::Number,
925+
A::oneStridedVector{$elty},
926+
x::oneStridedVector{$elty},
927+
beta::Number,
928+
y::oneStridedVector{$elty})
929+
# Derive n from length(A); the rounding yields n whenever length(A) == n*(n+1)/2.
# NOTE(review): a length(A) that is not of that form is not rejected here.
n = round(Int, (sqrt(8*length(A))-1)/2)
930+
if n != length(x) || n != length(y)
931+
throw(DimensionMismatch(""))
932+
end
933+
incx = stride(x,1)
934+
incy = stride(y,1)
935+
# The queue is looked up from x's context and device (y is not consulted).
queue = global_queue(context(x), device(x))
936+
$fname(sycl_queue(queue), uplo, n, alpha, A, x, incx, beta, y, incy)
937+
y
938+
end
939+
end
940+
end
941+
942+
# Out-of-place variant of `spmv!`: passes beta = zero(T) and a newly created
# `similar(x)` as the output vector, and returns whatever `spmv!` returns.
function spmv(uplo::Char, alpha::Number,
943+
A::oneStridedVector{T}, x::oneStridedVector{T}) where T
944+
spmv!(uplo, alpha, A, x, zero(T), similar(x))
945+
end
946+
947+
# Convenience method: same as the four-argument `spmv` with alpha = one(T).
function spmv(uplo::Char, A::oneStridedVector{T}, x::oneStridedVector{T}) where T
948+
spmv(uplo, one(T), A, x)
949+
end
950+
951+
# tbsv, (TB) triangular banded matrix solve
# Generates one `tbsv!` method per (wrapper, element type) pair listed in the loop
# header: Float32, Float64, ComplexF32 and ComplexF64.
# `uplo`, `trans`, `diag` and `k` are forwarded untouched to the wrapper.
# Throws DimensionMismatch when 1+k is outside 1:n, when A has fewer than 1+k rows,
# or when length(x) differs from n (n being the number of columns of A).
# Returns `x` (the last expression of the method).
952+
for (fname, elty) in ((:onemklStbsv, :Float32),
953+
(:onemklDtbsv, :Float64),
954+
(:onemklCtbsv, :ComplexF32),
955+
(:onemklZtbsv, :ComplexF64))
956+
@eval begin
957+
function tbsv!(uplo::Char,
958+
trans::Char,
959+
diag::Char,
960+
k::Integer,
961+
A::oneStridedMatrix{$elty},
962+
x::oneStridedVector{$elty})
963+
m, n = size(A)
964+
# Argument validation: band count, row count of A, then length of x.
if !(1<=(1+k)<=n) throw(DimensionMismatch("Incorrect number of bands")) end
965+
if m < 1+k throw(DimensionMismatch("Array A has fewer than 1+k rows")) end
966+
if n != length(x) throw(DimensionMismatch("")) end
967+
# Leading dimension is the column stride of A, clamped to at least 1.
lda = max(1,stride(A,2))
968+
incx = stride(x,1)
969+
# The queue is looked up from x's context and device.
queue = global_queue(context(x), device(x))
970+
$fname(sycl_queue(queue), uplo, trans, diag, n, k, A, lda, x, incx)
971+
x
972+
end
973+
end
974+
end
975+
# Out-of-place variant of `tbsv!`: hands `copy(x)` to `tbsv!` instead of `x`
# itself and returns whatever `tbsv!` returns.
function tbsv(uplo::Char, trans::Char, diag::Char, k::Integer,
976+
A::oneStridedMatrix{T}, x::oneStridedVector{T}) where T
977+
tbsv!(uplo, trans, diag, k, A, copy(x))
978+
end
979+
901980
# tbmv
902981
### tbmv, (TB) triangular banded matrix-vector multiplication
903982
for (fname, elty) in ((:onemklStbmv,:Float32),
@@ -1150,6 +1229,37 @@ function gemm(transA::Char,
11501229
B::oneStridedVecOrMat{T}) where T
11511230
gemm(transA, transB, one(T), A, B)
11521231
end
1232+
1233+
## dgmm
# Generates one `dgmm!` method per (wrapper, element type) pair listed in the loop
# header: Float32, Float64, ComplexF32 and ComplexF64.
# Throws DimensionMismatch (with an empty message) when A and C differ in size,
# when mode == 'L' and length(X) != size(C, 1), or when mode == 'R' and
# length(X) != size(C, 2). Other `mode` values are not validated in this method.
# Returns `C` (the last expression of the method).
1234+
for (fname, elty) in ((:onemklSdgmm, :Float32),
1235+
(:onemklDdgmm, :Float64),
1236+
(:onemklCdgmm, :ComplexF32),
1237+
(:onemklZdgmm, :ComplexF64))
1238+
@eval begin
1239+
function dgmm!(mode::Char,
1240+
A::oneStridedMatrix{$elty},
1241+
X::oneStridedVector{$elty},
1242+
C::oneStridedMatrix{$elty})
1243+
# m and n are taken from the output matrix C; A is checked against them below.
m, n = size(C)
1244+
mA, nA = size(A)
1245+
lx = length(X)
1246+
if ((mA != m) || (nA != n )) throw(DimensionMismatch("")) end
1247+
if ((mode == 'L') && (lx != m)) throw(DimensionMismatch("")) end
1248+
if ((mode == 'R') && (lx != n)) throw(DimensionMismatch("")) end
1249+
# Leading dimensions are the column strides of A and C, clamped to at least 1.
lda = max(1,stride(A,2))
1250+
incx = stride(X,1)
1251+
ldc = max(1,stride(C,2))
1252+
# The queue is looked up from A's context and device.
queue = global_queue(context(A), device(A))
1253+
$fname(sycl_queue(queue), mode, m, n, A, lda, X, incx, C, ldc)
1254+
C
1255+
end
1256+
end
1257+
end
1258+
# Out-of-place variant of `dgmm!`: creates the output with `similar(A, (m,n))`,
# where (m, n) is the size of A, and returns whatever `dgmm!` returns.
function dgmm(mode::Char, A::oneStridedMatrix{T}, X::oneStridedVector{T}) where T
1259+
m,n = size(A)
1260+
dgmm!( mode, A, X, similar(A, (m,n) ) )
1261+
end
1262+
11531263
for (fname, elty) in
11541264
((:onemklSgemmBatchStrided, Float32),
11551265
(:onemklDgemmBatchStrided, Float64),

0 commit comments

Comments
 (0)