Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions deps/src/onemkl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,33 @@ oneapi::mkl::transpose convert(onemklTranspose val) {
}
}

// Translate the C-interface onemklSide flag into the matching
// oneapi::mkl::side enumerator.
// There is deliberately no default label: each enumerator is listed
// explicitly, and control only leaves the switch for a value that is not one
// of the declared enumerators.
oneapi::mkl::side convert(onemklSide val) {
    switch (val) {
    case ONEMKL_SIDE_LEFT:  return oneapi::mkl::side::left;
    case ONEMKL_SIDE_RIGHT: return oneapi::mkl::side::right;
    }
}

// Translate the C-interface onemklUplo flag into the matching
// oneapi::mkl::uplo enumerator.
// No default label: both enumerators are handled explicitly, and control only
// leaves the switch for a value that is not a declared enumerator.
oneapi::mkl::uplo convert(onemklUplo val) {
    switch (val) {
    case ONEMKL_UPLO_UPPER: return oneapi::mkl::uplo::upper;
    case ONEMKL_UPLO_LOWER: return oneapi::mkl::uplo::lower;
    }
}

// Translate the C-interface onemklDiag flag into the matching
// oneapi::mkl::diag enumerator.
// No default label: both enumerators are handled explicitly, and control only
// leaves the switch for a value that is not a declared enumerator.
oneapi::mkl::diag convert(onemklDiag val) {
    switch (val) {
    case ONEMKL_DIAG_NONUNIT: return oneapi::mkl::diag::nonunit;
    case ONEMKL_DIAG_UNIT:    return oneapi::mkl::diag::unit;
    }
}

extern "C" int onemklHgemm(syclQueue_t device_queue, onemklTranspose transA,
onemklTranspose transB, int64_t m, int64_t n,
int64_t k, sycl::half alpha, const sycl::half *A, int64_t lda,
Expand Down Expand Up @@ -87,6 +114,42 @@ extern "C" int onemklZgemm(syclQueue_t device_queue, onemklTranspose transA,
return 0;
}

// C-linkage entry point for the float tbsv routine.
// Converts the C enum flags to their oneMKL counterparts, forwards everything
// to oneapi::mkl::blas::column_major::tbsv on the queue held in
// device_queue->val, then hands the returned value to __FORCE_MKL_FLUSH__.
extern "C" void onemklStbsv(syclQueue_t device_queue, onemklUplo uplo,
                            onemklTranspose trans, onemklDiag diag, int64_t n,
                            int64_t k, const float *a, int64_t lda, float *x,
                            int64_t incx) {
    const auto mkl_uplo = convert(uplo);
    const auto mkl_trans = convert(trans);
    const auto mkl_diag = convert(diag);
    auto status = oneapi::mkl::blas::column_major::tbsv(
        device_queue->val, mkl_uplo, mkl_trans, mkl_diag, n, k, a, lda, x, incx);
    __FORCE_MKL_FLUSH__(status);
}

// C-linkage entry point for the double tbsv routine.
// Converts the C enum flags to their oneMKL counterparts, forwards everything
// to oneapi::mkl::blas::column_major::tbsv on the queue held in
// device_queue->val, then hands the returned value to __FORCE_MKL_FLUSH__.
extern "C" void onemklDtbsv(syclQueue_t device_queue, onemklUplo uplo,
                            onemklTranspose trans, onemklDiag diag, int64_t n,
                            int64_t k, const double *a, int64_t lda, double *x,
                            int64_t incx) {
    const auto mkl_uplo = convert(uplo);
    const auto mkl_trans = convert(trans);
    const auto mkl_diag = convert(diag);
    auto status = oneapi::mkl::blas::column_major::tbsv(
        device_queue->val, mkl_uplo, mkl_trans, mkl_diag, n, k, a, lda, x, incx);
    __FORCE_MKL_FLUSH__(status);
}

// C-linkage entry point for the single-precision complex tbsv routine.
// The C `float _Complex` pointers are reinterpreted as std::complex<float>
// pointers before being passed to oneapi::mkl::blas::column_major::tbsv; the
// enum flags go through convert(). The value returned by tbsv is handed to
// __FORCE_MKL_FLUSH__.
extern "C" void onemklCtbsv(syclQueue_t device_queue, onemklUplo uplo,
                            onemklTranspose trans, onemklDiag diag, int64_t n,
                            int64_t k, const float _Complex *a, int64_t lda,
                            float _Complex *x, int64_t incx) {
    const auto *a_cx = reinterpret_cast<const std::complex<float> *>(a);
    auto *x_cx = reinterpret_cast<std::complex<float> *>(x);
    auto status = oneapi::mkl::blas::column_major::tbsv(
        device_queue->val, convert(uplo), convert(trans), convert(diag), n, k,
        a_cx, lda, x_cx, incx);
    __FORCE_MKL_FLUSH__(status);
}

// C-linkage entry point for the double-precision complex tbsv routine.
// The C `double _Complex` pointers are reinterpreted as std::complex<double>
// pointers before being passed to oneapi::mkl::blas::column_major::tbsv; the
// enum flags go through convert(). The value returned by tbsv is handed to
// __FORCE_MKL_FLUSH__.
extern "C" void onemklZtbsv(syclQueue_t device_queue, onemklUplo uplo,
                            onemklTranspose trans, onemklDiag diag, int64_t n,
                            int64_t k, const double _Complex *a, int64_t lda,
                            double _Complex *x, int64_t incx) {
    const auto *a_cx = reinterpret_cast<const std::complex<double> *>(a);
    auto *x_cx = reinterpret_cast<std::complex<double> *>(x);
    auto status = oneapi::mkl::blas::column_major::tbsv(
        device_queue->val, convert(uplo), convert(trans), convert(diag), n, k,
        a_cx, lda, x_cx, incx);
    __FORCE_MKL_FLUSH__(status);
}

extern "C" void onemklSasum(syclQueue_t device_queue, int64_t n,
const float *x, int64_t incx,
float *result) {
Expand Down
30 changes: 30 additions & 0 deletions deps/src/onemkl.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,21 @@ typedef enum {
ONEMLK_TRANSPOSE_CONJTRANS
} onemklTranspose;

// Upper/lower triangle selector exposed through the C interface.
// Values are spelled out so they visibly match the bindings on the Julia side.
typedef enum {
    ONEMKL_UPLO_UPPER = 0,
    ONEMKL_UPLO_LOWER = 1
} onemklUplo;

// Non-unit/unit diagonal selector exposed through the C interface.
// Values are spelled out so they visibly match the bindings on the Julia side.
typedef enum {
    ONEMKL_DIAG_NONUNIT = 0,
    ONEMKL_DIAG_UNIT = 1
} onemklDiag;

// Left/right side selector exposed through the C interface.
// Values are spelled out so they visibly match the bindings on the Julia side.
typedef enum {
    ONEMKL_SIDE_LEFT = 0,
    ONEMKL_SIDE_RIGHT = 1
} onemklSide;

// XXX: how to expose half in C?
// int onemklHgemm(syclQueue_t device_queue, onemklTranspose transA,
// onemklTranspose transB, int64_t m, int64_t n, int64_t k,
Expand All @@ -39,6 +54,21 @@ int onemklZgemm(syclQueue_t device_queue, onemklTranspose transA,
const double _Complex *B, int64_t ldb, double _Complex beta,
double _Complex *C, int64_t ldc);

// tbsv (triangular banded solve) entry points, one per element type:
//   S = float, D = double, C = float _Complex, Z = double _Complex.
// Every variant takes the same argument list: the queue handle, the
// uplo/trans/diag flags, the sizes n and k, the read-only buffer `a` with
// leading dimension `lda`, and the writable buffer `x` with increment `incx`.
// None of them returns a value.
void onemklStbsv(syclQueue_t device_queue, onemklUplo uplo,
onemklTranspose trans, onemklDiag diag, int64_t n,
int64_t k, const float *a, int64_t lda, float *x, int64_t incx);
void onemklDtbsv(syclQueue_t device_queue, onemklUplo uplo,
onemklTranspose trans, onemklDiag diag, int64_t n,
int64_t k, const double *a, int64_t lda, double *x, int64_t incx);
void onemklCtbsv(syclQueue_t device_queue, onemklUplo uplo,
onemklTranspose trans, onemklDiag diag, int64_t n,
int64_t k, const float _Complex *a, int64_t lda, float _Complex *x,
int64_t incx);
void onemklZtbsv(syclQueue_t device_queue, onemklUplo uplo,
onemklTranspose trans, onemklDiag diag, int64_t n,
int64_t k, const double _Complex *a, int64_t lda, double _Complex *x,
int64_t incx);

void onemklSasum(syclQueue_t device_queue, int64_t n,
const float *x, int64_t incx, float *result);
void onemklDasum(syclQueue_t device_queue, int64_t n,
Expand Down
43 changes: 43 additions & 0 deletions lib/mkl/libonemkl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,21 @@ using CEnum
ONEMLK_TRANSPOSE_CONJTRANS = 2
end

# Julia mirror of the C enum `onemklUplo`. The integer values follow the
# declaration order of the C header (UPPER first, LOWER second) and must stay
# in sync with it.
@cenum onemklUplo::UInt32 begin
ONEMKL_UPLO_UPPER = 0
ONEMKL_UPLO_LOWER = 1
end

# Julia mirror of the C enum `onemklDiag`. The integer values follow the
# declaration order of the C header (NONUNIT first, UNIT second) and must stay
# in sync with it.
@cenum onemklDiag::UInt32 begin
ONEMKL_DIAG_NONUNIT = 0
ONEMKL_DIAG_UNIT = 1
end

# Julia mirror of the C enum `onemklSide`. The integer values follow the
# declaration order of the C header (LEFT first, RIGHT second) and must stay
# in sync with it.
@cenum onemklSide::UInt32 begin
ONEMKL_SIDE_LEFT = 0
ONEMKL_SIDE_RIGHT = 1
end

function onemklSgemm(device_queue, transA, transB, m, n, k, alpha, A, lda, B, ldb, beta, C,
ldc)
@ccall liboneapi_support.onemklSgemm(device_queue::syclQueue_t, transA::onemklTranspose,
Expand Down Expand Up @@ -42,6 +57,34 @@ function onemklZgemm(device_queue, transA, transB, m, n, k, alpha, A, lda, B, ld
C::ZePtr{ComplexF64}, ldc::Int64)::Cint
end

# Raw binding for the C entry point `onemklStbsv` in `liboneapi_support`
# (Cfloat buffers). `@ccall` converts each argument to the C type annotated
# below; the call returns `nothing`.
function onemklStbsv(device_queue, uplo, trans, diag, n, k, a, lda, x, incx)
    return @ccall liboneapi_support.onemklStbsv(
        device_queue::syclQueue_t,
        uplo::onemklUplo,
        trans::onemklTranspose,
        diag::onemklDiag,
        n::Int64,
        k::Int64,
        a::ZePtr{Cfloat},
        lda::Int64,
        x::ZePtr{Cfloat},
        incx::Int64,
    )::Cvoid
end

# Raw binding for the C entry point `onemklDtbsv` in `liboneapi_support`
# (Cdouble buffers). `@ccall` converts each argument to the C type annotated
# below; the call returns `nothing`.
function onemklDtbsv(device_queue, uplo, trans, diag, n, k, a, lda, x, incx)
    return @ccall liboneapi_support.onemklDtbsv(
        device_queue::syclQueue_t,
        uplo::onemklUplo,
        trans::onemklTranspose,
        diag::onemklDiag,
        n::Int64,
        k::Int64,
        a::ZePtr{Cdouble},
        lda::Int64,
        x::ZePtr{Cdouble},
        incx::Int64,
    )::Cvoid
end

# Raw binding for the C entry point `onemklCtbsv` in `liboneapi_support`
# (ComplexF32 buffers). `@ccall` converts each argument to the C type annotated
# below; the call returns `nothing`.
function onemklCtbsv(device_queue, uplo, trans, diag, n, k, a, lda, x, incx)
    return @ccall liboneapi_support.onemklCtbsv(
        device_queue::syclQueue_t,
        uplo::onemklUplo,
        trans::onemklTranspose,
        diag::onemklDiag,
        n::Int64,
        k::Int64,
        a::ZePtr{ComplexF32},
        lda::Int64,
        x::ZePtr{ComplexF32},
        incx::Int64,
    )::Cvoid
end

# Raw binding for the C entry point `onemklZtbsv` in `liboneapi_support`
# (ComplexF64 buffers). `@ccall` converts each argument to the C type annotated
# below; the call returns `nothing`.
function onemklZtbsv(device_queue, uplo, trans, diag, n, k, a, lda, x, incx)
    return @ccall liboneapi_support.onemklZtbsv(
        device_queue::syclQueue_t,
        uplo::onemklUplo,
        trans::onemklTranspose,
        diag::onemklDiag,
        n::Int64,
        k::Int64,
        a::ZePtr{ComplexF64},
        lda::Int64,
        x::ZePtr{ComplexF64},
        incx::Int64,
    )::Cvoid
end

function onemklSasum(device_queue, n, x, incx, result)
@ccall liboneapi_support.onemklSasum(device_queue::syclQueue_t, n::Int64,
x::ZePtr{Cfloat}, incx::Int64,
Expand Down
31 changes: 31 additions & 0 deletions lib/mkl/oneMKL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,35 @@ const onemklFloat = Union{Float64,Float32,ComplexF64,ComplexF32}
include("wrappers.jl")
include("linalg.jl")

"""
    band(A::StridedArray, kl, ku)

Pack the diagonals of `A` lying between `kl` below and `ku` above the main
diagonal into a `(kl + ku + 1) × n` array, where `n` is the number of columns
of `A`. Entry `A[i, j]` is stored at row `ku + 1 - j + i` of column `j`; slots
with no corresponding entry of `A` are left at zero.
"""
function band(A::StridedArray, kl, ku)
    nrows, ncols = size(A)
    packed = zeros(eltype(A), kl + ku + 1, ncols)
    for col in 1:ncols
        # Row offset that maps a row of `A` in this column to its packed row.
        shift = ku + 1 - col
        for row in max(1, col - ku):min(nrows, col + kl)
            packed[shift + row, col] = A[row, col]
        end
    end
    return packed
end

"""
    unband(AB::StridedArray, m, kl, ku)

Convert band storage back to a general matrix. Returns an `m × n` array (with
`n` the number of columns of `AB`) in which entry `[i, j]` is read from
`AB[ku + 1 - j + i, j]` for every `i` in `max(1, j - ku):min(m, j + kl)`; all
other entries are zero.
"""
function unband(AB::StridedArray, m, kl, ku)
    _, ncols = size(AB)
    dense = zeros(eltype(AB), m, ncols)
    for col in 1:ncols, row in max(1, col - ku):min(m, col + kl)
        dense[row, col] = AB[ku + 1 - col + row, col]
    end
    return dense
end

"""
    bandex(A::AbstractMatrix, kl, ku)

Return a dense copy of `A` in which every element outside the band spanning
`kl` sub-diagonals and `ku` super-diagonals is replaced by zero.

The band is copied directly instead of going through an intermediate
band-storage array, so any `AbstractMatrix` is accepted, as the signature
states, rather than only strided arrays. Indexing assumes `A` is 1-based.
"""
function bandex(A::AbstractMatrix, kl, ku)
    m, n = size(A)
    B = zeros(eltype(A), m, n)
    for j in 1:n, i in max(1, j - ku):min(m, j + kl)
        B[i, j] = A[i, j]
    end
    return B
end

end
65 changes: 65 additions & 0 deletions lib/mkl/wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,36 @@ function Base.convert(::Type{onemklTranspose}, trans::Char)
end
end

# Map the BLAS-style side character ('L' or 'R') onto the `onemklSide` enum.
# Any other character raises an `ArgumentError` naming the offending value.
function Base.convert(::Type{onemklSide}, side::Char)
    if side == 'L'
        return ONEMKL_SIDE_LEFT
    elseif side == 'R'
        return ONEMKL_SIDE_RIGHT
    else
        # The message previously said "transpose", copied from the transpose
        # conversion; it now names the argument that was actually rejected.
        throw(ArgumentError("Unknown side $side"))
    end
end

# Map the BLAS-style triangle character ('U' or 'L') onto the `onemklUplo`
# enum. Any other character raises an `ArgumentError`.
function Base.convert(::Type{onemklUplo}, uplo::Char)
    uplo == 'U' && return ONEMKL_UPLO_UPPER
    uplo == 'L' && return ONEMKL_UPLO_LOWER
    throw(ArgumentError("Unknown uplo $uplo"))
end

# Map the BLAS-style diagonal character ('N' for non-unit, 'U' for unit) onto
# the `onemklDiag` enum. Any other character raises an `ArgumentError` naming
# the offending value.
function Base.convert(::Type{onemklDiag}, diag::Char)
    if diag == 'N'
        return ONEMKL_DIAG_NONUNIT
    elseif diag == 'U'
        return ONEMKL_DIAG_UNIT
    else
        # The message previously said "transpose", copied from the transpose
        # conversion; it now names the argument that was actually rejected.
        throw(ArgumentError("Unknown diag $diag"))
    end
end

# level 1
## axpy primitive
for (fname, elty) in
Expand Down Expand Up @@ -226,3 +256,38 @@ for (fname, elty) in
end
end
end

### tbsv, (TB) triangular banded matrix solve
# One `tbsv!`/`tbsv` method pair is generated per element type, each calling
# the matching `onemkl?tbsv` binding.
for (fname, elty) in ((:onemklStbsv, :Float32),
                      (:onemklDtbsv, :Float64),
                      (:onemklCtbsv, :ComplexF32),
                      (:onemklZtbsv, :ComplexF64))
    @eval begin
        # In-place variant: validates the dimensions, calls the library routine
        # with `x` as its writable buffer, and returns `x`.
        #
        # `k` is the number of off-diagonal bands, so `A` must have at least
        # `1 + k` rows, and `x` must have one element per column of `A`.
        # Throws `DimensionMismatch` when any of those conditions fails.
        function tbsv!(uplo::Char,
                       trans::Char,
                       diag::Char,
                       k::Integer,
                       A::oneStridedVecOrMat{$elty},
                       x::oneStridedVecOrMat{$elty})
            m, n = size(A)
            if !(1 <= (1 + k) <= n)
                throw(DimensionMismatch("Incorrect number of bands"))
            end
            if m < 1 + k
                throw(DimensionMismatch("Array A has fewer than 1+k rows"))
            end
            if n != length(x)
                # Previously thrown with an empty message, which gave callers
                # nothing to diagnose the failure with. `string` is used rather
                # than interpolation to keep `\$` out of the `@eval` block.
                throw(DimensionMismatch(string("A has ", n,
                                               " columns, but x has length ",
                                               length(x))))
            end
            lda = max(1, stride(A, 2))
            incx = stride(x, 1)
            queue = global_queue(context(A), device(A))
            $fname(sycl_queue(queue), uplo, trans, diag, n, k, A, lda, x, incx)
            return x
        end

        # Out-of-place variant: runs `tbsv!` on a copy of `x`, leaving the
        # caller's `x` untouched, and returns that copy.
        function tbsv(uplo::Char,
                      trans::Char,
                      diag::Char,
                      k::Integer,
                      A::oneStridedVecOrMat{$elty},
                      x::oneStridedVecOrMat{$elty})
            return tbsv!(uplo, trans, diag, k, A, copy(x))
        end
    end
end
51 changes: 50 additions & 1 deletion test/onemkl.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using oneAPI
using oneAPI.oneMKL
using oneAPI.oneMKL: band, bandex

using LinearAlgebra

Expand Down Expand Up @@ -67,3 +67,52 @@ m = 20
end
end
end

# Level-2 BLAS tests. `eltypes` and `m` come from outside this section of the
# file. Each case builds a random upper-banded matrix on the host, packs it
# into band storage with `band`, solves on the device, and compares against the
# host solution `A\x`.
@testset "level 2" begin
@testset for T in intersect(eltypes, [Float32, Float64, ComplexF32, ComplexF64])
# NOTE(review): the in-place `tbsv!` test below is disabled with a block
# comment; the reason is not recorded here — confirm before re-enabling.
#=
@testset "tbsv!" begin
# generate triangular matrix
A = rand(T,m,m)
# restrict to 3 bands
nbands = 3
@test m >= 1+nbands
A = bandex(A,0,nbands)
# convert to 'upper' banded storage format
AB = band(A,0,nbands)
d_AB = oneArray(AB)
# construct x
x = rand(T,m)
d_x = oneArray(x)
d_y = copy(d_x)
#tbsv!
oneMKL.tbsv!('U','N','N',nbands,d_AB,d_y)
y = A\x
# compare
h_y = Array(d_y)
@test y ≈ h_y
end
=#
# NOTE(review): the out-of-place test only runs for Float32 even though the
# enclosing loop offers four element types — presumably a temporary
# restriction; verify whether the other types should be covered.
if T == Float32
@testset "tbsv" begin
# generate triangular matrix
A = rand(T,m,m)
# restrict to 3 bands
nbands = 3
@test m >= 1+nbands
A = bandex(A,0,nbands)
# convert to 'upper' banded storage format
AB = band(A,0,nbands)
d_AB = oneArray(AB)
# construct x
x = rand(T,m)
d_x = oneArray(x)
d_y = oneMKL.tbsv('U','N','N',nbands,d_AB,d_x)
y = A\x
# compare
h_y = Array(d_y)
@test y ≈ h_y
end
end
end
end