From 6ee9e08e6303273ce2aecc425243db822e42f7e4 Mon Sep 17 00:00:00 2001 From: kballeda Date: Thu, 8 Dec 2022 12:12:10 +0530 Subject: [PATCH] tbsv is not working - results mismatch --- deps/src/onemkl.cpp | 63 ++++++++++++++++++++++++++++++++++++++++++ deps/src/onemkl.h | 30 ++++++++++++++++++++ lib/mkl/libonemkl.jl | 43 +++++++++++++++++++++++++++++ lib/mkl/oneMKL.jl | 31 +++++++++++++++++++++ lib/mkl/wrappers.jl | 65 ++++++++++++++++++++++++++++++++++++++++++++ test/onemkl.jl | 51 +++++++++++++++++++++++++++++++++- 6 files changed, 282 insertions(+), 1 deletion(-) diff --git a/deps/src/onemkl.cpp b/deps/src/onemkl.cpp index efc7e031..42ec4a24 100644 --- a/deps/src/onemkl.cpp +++ b/deps/src/onemkl.cpp @@ -24,6 +24,33 @@ oneapi::mkl::transpose convert(onemklTranspose val) { } } +oneapi::mkl::side convert(onemklSide val) { + switch (val) { + case ONEMKL_SIDE_LEFT: + return oneapi::mkl::side::left; + case ONEMKL_SIDE_RIGHT: + return oneapi::mkl::side::right; + } +} + +oneapi::mkl::uplo convert(onemklUplo val) { + switch(val) { + case ONEMKL_UPLO_UPPER: + return oneapi::mkl::uplo::upper; + case ONEMKL_UPLO_LOWER: + return oneapi::mkl::uplo::lower; + } +} + +oneapi::mkl::diag convert(onemklDiag val) { + switch(val) { + case ONEMKL_DIAG_NONUNIT: + return oneapi::mkl::diag::nonunit; + case ONEMKL_DIAG_UNIT: + return oneapi::mkl::diag::unit; + } +} + extern "C" int onemklHgemm(syclQueue_t device_queue, onemklTranspose transA, onemklTranspose transB, int64_t m, int64_t n, int64_t k, sycl::half alpha, const sycl::half *A, int64_t lda, @@ -87,6 +114,42 @@ extern "C" int onemklZgemm(syclQueue_t device_queue, onemklTranspose transA, return 0; } +extern "C" void onemklStbsv(syclQueue_t device_queue, onemklUplo uplo, + onemklTranspose trans, onemklDiag diag, int64_t n, + int64_t k, const float *a, int64_t lda, float *x, int64_t incx) { + auto status = oneapi::mkl::blas::column_major::tbsv(device_queue->val, convert(uplo), convert(trans), + convert(diag), n, k, a, lda, x, incx); + __FORCE_MKL_FLUSH__(status); +} + +extern "C" void onemklDtbsv(syclQueue_t device_queue, onemklUplo uplo, + onemklTranspose trans, onemklDiag diag, int64_t n, + int64_t k, const double *a, int64_t lda, double *x, int64_t incx) { + auto status = oneapi::mkl::blas::column_major::tbsv(device_queue->val, convert(uplo), convert(trans), + convert(diag), n, k, a, lda, x, incx); + __FORCE_MKL_FLUSH__(status); +} + +extern "C" void onemklCtbsv(syclQueue_t device_queue, onemklUplo uplo, + onemklTranspose trans, onemklDiag diag, int64_t n, + int64_t k, const float _Complex *a, int64_t lda, float _Complex *x, + int64_t incx) { + auto status = oneapi::mkl::blas::column_major::tbsv(device_queue->val, convert(uplo), convert(trans), + convert(diag), n, k, reinterpret_cast *>(a), + lda, reinterpret_cast *>(x), incx); + __FORCE_MKL_FLUSH__(status); +} + +extern "C" void onemklZtbsv(syclQueue_t device_queue, onemklUplo uplo, + onemklTranspose trans, onemklDiag diag, int64_t n, + int64_t k, const double _Complex *a, int64_t lda, double _Complex *x, + int64_t incx) { + auto status = oneapi::mkl::blas::column_major::tbsv(device_queue->val, convert(uplo), convert(trans), + convert(diag), n, k, reinterpret_cast *>(a), + lda, reinterpret_cast *>(x), incx); + __FORCE_MKL_FLUSH__(status); +} + extern "C" void onemklSasum(syclQueue_t device_queue, int64_t n, const float *x, int64_t incx, float *result) { diff --git a/deps/src/onemkl.h b/deps/src/onemkl.h index e8fbbe16..b29532d8 100644 --- a/deps/src/onemkl.h +++ b/deps/src/onemkl.h @@ -15,6 +15,21 @@ typedef enum { ONEMLK_TRANSPOSE_CONJTRANS } onemklTranspose; +typedef enum { + ONEMKL_UPLO_UPPER, + ONEMKL_UPLO_LOWER +} onemklUplo; + +typedef enum { + ONEMKL_DIAG_NONUNIT, + ONEMKL_DIAG_UNIT + } onemklDiag; + +typedef enum { + ONEMKL_SIDE_LEFT, + ONEMKL_SIDE_RIGHT +} onemklSide; + // XXX: how to expose half in C? // int onemklHgemm(syclQueue_t device_queue, onemklTranspose transA, // onemklTranspose transB, int64_t m, int64_t n, int64_t k, @@ -39,6 +54,21 @@ int onemklZgemm(syclQueue_t device_queue, onemklTranspose transA, const double _Complex *B, int64_t ldb, double _Complex beta, double _Complex *C, int64_t ldc); +void onemklStbsv(syclQueue_t device_queue, onemklUplo uplo, + onemklTranspose trans, onemklDiag diag, int64_t n, + int64_t k, const float *a, int64_t lda, float *x, int64_t incx); +void onemklDtbsv(syclQueue_t device_queue, onemklUplo uplo, + onemklTranspose trans, onemklDiag diag, int64_t n, + int64_t k, const double *a, int64_t lda, double *x, int64_t incx); +void onemklCtbsv(syclQueue_t device_queue, onemklUplo uplo, + onemklTranspose trans, onemklDiag diag, int64_t n, + int64_t k, const float _Complex *a, int64_t lda, float _Complex *x, + int64_t incx); +void onemklZtbsv(syclQueue_t device_queue, onemklUplo uplo, + onemklTranspose trans, onemklDiag diag, int64_t n, + int64_t k, const double _Complex *a, int64_t lda, double _Complex *x, + int64_t incx); + void onemklSasum(syclQueue_t device_queue, int64_t n, const float *x, int64_t incx, float *result); void onemklDasum(syclQueue_t device_queue, int64_t n, diff --git a/lib/mkl/libonemkl.jl b/lib/mkl/libonemkl.jl index 469feee3..a17c9207 100644 --- a/lib/mkl/libonemkl.jl +++ b/lib/mkl/libonemkl.jl @@ -6,6 +6,21 @@ using CEnum ONEMLK_TRANSPOSE_CONJTRANS = 2 end +@cenum onemklUplo::UInt32 begin + ONEMKL_UPLO_UPPER = 0 + ONEMKL_UPLO_LOWER = 1 +end + +@cenum onemklDiag::UInt32 begin + ONEMKL_DIAG_NONUNIT = 0 + ONEMKL_DIAG_UNIT = 1 +end + +@cenum onemklSide::UInt32 begin + ONEMKL_SIDE_LEFT = 0 + ONEMKL_SIDE_RIGHT = 1 +end + function onemklSgemm(device_queue, transA, transB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc) @ccall liboneapi_support.onemklSgemm(device_queue::syclQueue_t, transA::onemklTranspose, @@ -42,6 +57,34 @@ function onemklZgemm(device_queue, transA, transB, m, n, k, alpha, A, lda, B, ld C::ZePtr{ComplexF64}, ldc::Int64)::Cint end +function onemklStbsv(device_queue, uplo, trans, diag, n, k, a, lda, x, incx) + @ccall liboneapi_support.onemklStbsv(device_queue::syclQueue_t, uplo::onemklUplo, + trans::onemklTranspose, diag::onemklDiag, n::Int64, + k::Int64, a::ZePtr{Cfloat}, lda::Int64, x::ZePtr{Cfloat}, + incx::Int64)::Cvoid +end + +function onemklDtbsv(device_queue, uplo, trans, diag, n, k, a, lda, x, incx) + @ccall liboneapi_support.onemklDtbsv(device_queue::syclQueue_t, uplo::onemklUplo, + trans::onemklTranspose, diag::onemklDiag, n::Int64, + k::Int64, a::ZePtr{Cdouble}, lda::Int64, x::ZePtr{Cdouble}, + incx::Int64)::Cvoid +end + +function onemklCtbsv(device_queue, uplo, trans, diag, n, k, a, lda, x, incx) + @ccall liboneapi_support.onemklCtbsv(device_queue::syclQueue_t, uplo::onemklUplo, + trans::onemklTranspose, diag::onemklDiag, n::Int64, + k::Int64, a::ZePtr{ComplexF32}, lda::Int64, x::ZePtr{ComplexF32}, + incx::Int64)::Cvoid +end + +function onemklZtbsv(device_queue, uplo, trans, diag, n, k, a, lda, x, incx) + @ccall liboneapi_support.onemklZtbsv(device_queue::syclQueue_t, uplo::onemklUplo, + trans::onemklTranspose, diag::onemklDiag, n::Int64, + k::Int64, a::ZePtr{ComplexF64}, lda::Int64, x::ZePtr{ComplexF64}, + incx::Int64)::Cvoid +end + function onemklSasum(device_queue, n, x, incx, result) @ccall liboneapi_support.onemklSasum(device_queue::syclQueue_t, n::Int64, x::ZePtr{Cfloat}, incx::Int64, diff --git a/lib/mkl/oneMKL.jl b/lib/mkl/oneMKL.jl index 7b0b24b2..743bce1e 100644 --- a/lib/mkl/oneMKL.jl +++ b/lib/mkl/oneMKL.jl @@ -18,4 +18,35 @@ const onemklFloat = Union{Float64,Float32,ComplexF64,ComplexF32} include("wrappers.jl") include("linalg.jl") +function band(A::StridedArray, kl, ku) + m, n = size(A) + AB = zeros(eltype(A),kl+ku+1,n) + for j = 1:n + for i = max(1,j-ku):min(m,j+kl) + AB[ku+1-j+i,j] = A[i,j] + end + end + return AB +end + +# convert band storage to general matrix +function unband(AB::StridedArray,m,kl,ku) + bm, n = size(AB) + A = zeros(eltype(AB),m,n) + for j = 1:n + for i = max(1,j-ku):min(m,j+kl) + A[i,j] = AB[ku+1-j+i,j] + end + end + return A +end + +# zero out elements not on matrix bands +function bandex(A::AbstractMatrix,kl,ku) + m, n = size(A) + AB = band(A,kl,ku) + B = unband(AB,m,kl,ku) + return B +end + end diff --git a/lib/mkl/wrappers.jl b/lib/mkl/wrappers.jl index 795deab5..0b62b621 100644 --- a/lib/mkl/wrappers.jl +++ b/lib/mkl/wrappers.jl @@ -14,6 +14,36 @@ function Base.convert(::Type{onemklTranspose}, trans::Char) end end +function Base.convert(::Type{onemklSide}, side::Char) + if side == 'L' + return ONEMKL_SIDE_LEFT + elseif side == 'R' + return ONEMKL_SIDE_RIGHT + else + throw(ArgumentError("Unknown transpose $side")) + end +end + +function Base.convert(::Type{onemklUplo}, uplo::Char) + if uplo == 'U' + return ONEMKL_UPLO_UPPER + elseif uplo == 'L' + return ONEMKL_UPLO_LOWER + else + throw(ArgumentError("Unknown uplo $uplo")) + end +end + +function Base.convert(::Type{onemklDiag}, diag::Char) + if diag == 'N' + return ONEMKL_DIAG_NONUNIT + elseif diag == 'U' + return ONEMKL_DIAG_UNIT + else + throw(ArgumentError("Unknown transpose $diag")) + end +end + # level 1 ## axpy primitive for (fname, elty) in @@ -226,3 +256,38 @@ for (fname, elty) in end end end + +### tbsv, (TB) triangular banded matrix solve +for (fname, elty) in ((:onemklStbsv,:Float32), + (:onemklDtbsv,:Float64), + (:onemklCtbsv,:ComplexF32), + (:onemklZtbsv,:ComplexF64)) + @eval begin + function tbsv!(uplo::Char, + trans::Char, + diag::Char, + k::Integer, + A::oneStridedVecOrMat{$elty}, + x::oneStridedVecOrMat{$elty}) + m, n = size(A) + if !(1<=(1+k)<=n) throw(DimensionMismatch("Incorrect number of bands")) end + if m < 1+k throw(DimensionMismatch("Array A has fewer than 1+k rows")) end + if n != length(x) throw(DimensionMismatch("")) end + lda = max(1,stride(A,2)) + incx = stride(x,1) + queue = global_queue(context(A), device(A)) + $fname(sycl_queue(queue), uplo, trans, diag, n, k, A, lda, x, incx) + x + end + + function tbsv(uplo::Char, + trans::Char, + diag::Char, + k::Integer, + A::oneStridedVecOrMat{$elty}, + x::oneStridedVecOrMat{$elty}) + tbsv!(uplo, trans, diag, k, A, copy(x)) + end + + end +end diff --git a/test/onemkl.jl b/test/onemkl.jl index 4c937196..b2d40146 100644 --- a/test/onemkl.jl +++ b/test/onemkl.jl @@ -1,5 +1,5 @@ using oneAPI -using oneAPI.oneMKL +using oneAPI.oneMKL: band, bandex using LinearAlgebra @@ -67,3 +67,52 @@ m = 20 end end end + +@testset "level 2" begin + @testset for T in intersect(eltypes, [Float32, Float64, ComplexF32, ComplexF64]) +#= + @testset "tbsv!" begin + # generate triangular matrix + A = rand(T,m,m) + # restrict to 3 bands + nbands = 3 + @test m >= 1+nbands + A = bandex(A,0,nbands) + # convert to 'upper' banded storage format + AB = band(A,0,nbands) + d_AB = oneArray(AB) + # construct x + x = rand(T,m) + d_x = oneArray(x) + d_y = copy(d_x) + #tbsv! + oneMKL.tbsv!('U','N','N',nbands,d_AB,d_y) + y = A\x + # compare + h_y = Array(d_y) + @test y ≈ h_y + end +=# + if T == Float32 + @testset "tbsv" begin + # generate triangular matrix + A = rand(T,m,m) + # restrict to 3 bands + nbands = 3 + @test m >= 1+nbands + A = bandex(A,0,nbands) + # convert to 'upper' banded storage format + AB = band(A,0,nbands) + d_AB = oneArray(AB) + # construct x + x = rand(T,m) + d_x = oneArray(x) + d_y = oneMKL.tbsv('U','N','N',nbands,d_AB,d_x) + y = A\x + # compare + h_y = Array(d_y) + @test y ≈ h_y + end + end + end +end \ No newline at end of file