Skip to content
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
664480c
cuBlas level 1 method - scal supported
kballeda Oct 31, 2022
5180b9b
scal test case updated in onemkl.jl
kballeda Oct 31, 2022
1307b8d
NITS - cleanup
kballeda Oct 31, 2022
86f9dc7
indentation NITS
kballeda Oct 31, 2022
57495e9
updated with scal - deps
kballeda Oct 31, 2022
28f245f
indentation fixes
kballeda Oct 31, 2022
86591f5
NITS
kballeda Oct 31, 2022
c49620f
NITS
kballeda Oct 31, 2022
990300a
cleanup onemkl.jl
kballeda Oct 31, 2022
1e164db
updated with rmul! and testf usage
kballeda Nov 1, 2022
139de1c
NITS
kballeda Nov 1, 2022
23943bb
testf used for cpu/gpu testing.
kballeda Nov 2, 2022
b51b870
NITS - clenaup & included int specific calls to rmul! diverted to
kballeda Nov 2, 2022
ab0abb2
wrapper alpha turns elttype and support all combinationswq
kballeda Nov 3, 2022
d5a58e0
NITS
kballeda Nov 3, 2022
c39e90c
support for Cs, Zd configs of scal function
kballeda Nov 3, 2022
0a17298
updated with staticcast complex alpha
kballeda Nov 3, 2022
e0aa24e
added onestridedarray
kballeda Nov 4, 2022
b9a7d29
enable tests of complex tye
kballeda Nov 4, 2022
645dad7
Merge branch 'master' into l1_scal
kballeda Nov 7, 2022
1c1b206
updated with Csscal and Zdscal test enabled
kballeda Nov 8, 2022
d423806
NITS
kballeda Nov 8, 2022
071f69e
NITS
kballeda Nov 8, 2022
8168b50
Merge branch 'master' into l1_scal
kballeda Nov 8, 2022
7be9df6
NITS
kballeda Nov 8, 2022
038252f
NITS
kballeda Nov 8, 2022
d07339e
Cleanup of tests
kballeda Nov 8, 2022
cbbd3a4
Instead of isapprox use compare op
kballeda Nov 9, 2022
af68c99
Bug fix: disable f16 check as it is not supported (CI crash)
kballeda Nov 10, 2022
7dd2fa3
Merge branch 'master' into l1_scal
kballeda Nov 10, 2022
41db9d5
Merge branch 'master' into l1_scal
kballeda Nov 22, 2022
578c300
use force flush instead of wait()
kballeda Nov 22, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions deps/onemkl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,33 @@ extern "C" int onemklZgemm(syclQueue_t device_queue, onemklTranspose transA,
return 0;
}

// Support Level-1: SCAL primitive
// C-callable wrapper: forwards a double-precision SCAL request (n elements of
// x, stride incx, scalar alpha) to oneMKL's column-major BLAS on the SYCL
// queue held by device_queue. The value returned by scal() is not used.
extern "C" void onemklDscal(syclQueue_t device_queue, int64_t n, double alpha,
                            double *x, int64_t incx) {
    auto &queue = device_queue->val;
    oneapi::mkl::blas::column_major::scal(queue, n, alpha, x, incx);
}

// C-callable wrapper: forwards a single-precision SCAL request (n elements of
// x, stride incx, scalar alpha) to oneMKL's column-major BLAS on the SYCL
// queue held by device_queue. The value returned by scal() is not used.
extern "C" void onemklSscal(syclQueue_t device_queue, int64_t n, float alpha,
                            float *x, int64_t incx) {
    auto &queue = device_queue->val;
    oneapi::mkl::blas::column_major::scal(queue, n, alpha, x, incx);
}

// C-callable wrapper: SCAL on a single-precision complex vector with a REAL
// (float) scalar alpha. The C99 `float _Complex` pointer is reinterpreted as
// `std::complex<float> *` before being handed to oneMKL's column-major BLAS.
extern "C" void onemklCscal(syclQueue_t device_queue, int64_t n, float alpha,
                            float _Complex *x, int64_t incx) {
    auto &queue = device_queue->val;
    auto *xs = reinterpret_cast<std::complex<float> *>(x);
    oneapi::mkl::blas::column_major::scal(queue, n, alpha, xs, incx);
}

// C-callable wrapper: SCAL on a double-precision complex vector with a REAL
// (double) scalar alpha. The C99 `double _Complex` pointer is reinterpreted as
// `std::complex<double> *` before being handed to oneMKL's column-major BLAS.
extern "C" void onemklZscal(syclQueue_t device_queue, int64_t n, double alpha,
                            double _Complex *x, int64_t incx) {
    auto &queue = device_queue->val;
    auto *xs = reinterpret_cast<std::complex<double> *>(x);
    oneapi::mkl::blas::column_major::scal(queue, n, alpha, xs, incx);
}


// other

Expand Down
6 changes: 6 additions & 0 deletions deps/onemkl.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ int onemklZgemm(syclQueue_t device_queue, onemklTranspose transA,
const double _Complex *B, int64_t ldb, double _Complex beta,
double _Complex *C, int64_t ldc);

// Level-1 BLAS: SCAL wrappers around oneMKL.
// Common arguments: device_queue is the SYCL queue handle, n the element
// count, alpha the scalar, x the vector, incx the stride between elements.
// D/S operate on real double/float vectors. C/Z operate on complex
// float/double vectors but take a REAL alpha of the matching precision.
void onemklDscal(syclQueue_t device_queue, int64_t n, double alpha, double *x, int64_t incx);
void onemklSscal(syclQueue_t device_queue, int64_t n, float alpha, float *x, int64_t incx);
void onemklCscal(syclQueue_t device_queue, int64_t n, float alpha, float _Complex *x, int64_t incx);
void onemklZscal(syclQueue_t device_queue, int64_t n, double alpha, double _Complex *x, int64_t incx);

void onemklDestroy();
#ifdef __cplusplus
}
Expand Down
17 changes: 17 additions & 0 deletions lib/mkl/libonemkl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,20 @@ function onemklZgemm(device_queue, transA, transB, m, n, k, alpha, A, lda, B, ld
B::ZePtr{ComplexF64}, ldb::Int64, beta::ComplexF64,
C::ZePtr{ComplexF64}, ldc::Int64)::Cint
end

# Low-level binding for the C entry point `onemklDscal` in `liboneapi_support`:
# SCAL on a `Cdouble` device vector `x` with a `Cdouble` scalar `alpha`.
# The C function returns `void`, so this returns `nothing`.
function onemklDscal(device_queue, n, alpha, x, incx)
    @ccall liboneapi_support.onemklDscal(
        device_queue::syclQueue_t,
        n::Int64,
        alpha::Cdouble,
        x::ZePtr{Cdouble},
        incx::Int64,
    )::Cvoid
    return nothing
end

# Low-level binding for the C entry point `onemklSscal` in `liboneapi_support`:
# SCAL on a `Cfloat` device vector `x` with a `Cfloat` scalar `alpha`.
# The C function returns `void`, so this returns `nothing`.
function onemklSscal(device_queue, n, alpha, x, incx)
    @ccall liboneapi_support.onemklSscal(
        device_queue::syclQueue_t,
        n::Int64,
        alpha::Cfloat,
        x::ZePtr{Cfloat},
        incx::Int64,
    )::Cvoid
    return nothing
end

# Low-level binding for the C entry point `onemklZscal` in `liboneapi_support`:
# SCAL on a `ComplexF64` device vector `x`.
#
# The C prototype is
#     void onemklZscal(syclQueue_t, int64_t n, double alpha, double _Complex *x, int64_t incx)
# i.e. `alpha` is a REAL double. It must therefore be declared `Cdouble` here;
# declaring it `ComplexF64` passes a two-field struct where the callee expects
# a scalar, which is an ABI mismatch. With `Cdouble`, `ccall` converts the
# argument: real values and complex values with zero imaginary part pass
# through, while a complex `alpha` with a nonzero imaginary part raises
# `InexactError` instead of being silently misread.
function onemklZscal(device_queue, n, alpha, x, incx)
    @ccall liboneapi_support.onemklZscal(
        device_queue::syclQueue_t,
        n::Int64,
        alpha::Cdouble,
        x::ZePtr{ComplexF64},
        incx::Int64,
    )::Cvoid
end

# Low-level binding for the C entry point `onemklCscal` in `liboneapi_support`:
# SCAL on a `ComplexF32` device vector `x`.
#
# The C prototype is
#     void onemklCscal(syclQueue_t, int64_t n, float alpha, float _Complex *x, int64_t incx)
# i.e. `alpha` is a REAL float. It must therefore be declared `Cfloat` here;
# declaring it `ComplexF32` passes a two-field struct where the callee expects
# a scalar, which is an ABI mismatch. With `Cfloat`, `ccall` converts the
# argument: real values and complex values with zero imaginary part pass
# through, while a complex `alpha` with a nonzero imaginary part raises
# `InexactError` instead of being silently misread.
function onemklCscal(device_queue, n, alpha, x, incx)
    @ccall liboneapi_support.onemklCscal(
        device_queue::syclQueue_t,
        n::Int64,
        alpha::Cfloat,
        x::ZePtr{ComplexF32},
        incx::Int64,
    )::Cvoid
end

18 changes: 17 additions & 1 deletion lib/mkl/wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,23 @@ function Base.convert(::Type{onemklTranspose}, trans::Char)
end
end


# level 1
## scal
# Generate one `scal!(n, alpha, x)` method per supported element type. Each
# method looks up the global queue for the context/device of `x`, calls the
# matching low-level oneMKL binding with the first-dimension stride of `x`,
# and returns `x`.
for (fname, elty) in ((:onemklDscal, :Float64),
                      (:onemklSscal, :Float32),
                      (:onemklZscal, :ComplexF64),
                      (:onemklCscal, :ComplexF32))
    @eval function scal!(n::Integer, alpha::Number, x::StridedArray{$elty})
        q = global_queue(context(x), device(x))
        $fname(sycl_queue(q), n, alpha, x, stride(x, 1))
        return x
    end
end

#
# BLAS
Expand Down
22 changes: 22 additions & 0 deletions test/onemkl.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
using oneAPI.oneMKL
using LinearAlgebra

# Problem sizes shared by the tests in this file. Only `m` (the vector length)
# is referenced by the level-1 tests below; `n` and `k` are not used there.
m = 20
n = 35
k = 13


########################
@testset "level 1" begin
    @testset for T in eltypes
        # Host reference vector and its device copy.
        A = rand(T, m)
        gpuA = oneArray(A)

        # Float32 data is scaled with a Float32 literal; every other element
        # type is scaled with a Float64 literal.
        alpha = T === Float32 ? 5f0 : 5.0
        oneMKL.scal!(m, alpha, gpuA)

        # The device result, copied back, must match the host-side scaling.
        _A = Array(gpuA)
        @test A .* 5 ≈ _A
    end
end