Skip to content

Commit 7fc47b5

Browse files
committed
Adding iamax/iamin support using oneMKL sycl::buffer interface alternative
1 parent ece8283 commit 7fc47b5

File tree

6 files changed

+140
-8
lines changed

6 files changed

+140
-8
lines changed

deps/onemkl.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,44 @@ extern "C" void onemklCcopy(syclQueue_t device_queue, int64_t n, const float _Co
105105
reinterpret_cast<std::complex<float> *>(y), incy);
106106
}
107107

108+
extern "C" void onemklDamax(syclQueue_t device_queue, int64_t n, const double *x, int64_t incx, int64_t *result)
109+
{
110+
auto status = oneapi::mkl::blas::column_major::iamax(device_queue->val, n, x, incx, result);
111+
status.wait();
112+
}
113+
extern "C" void onemklSamax(syclQueue_t device_queue, int64_t n, const float *x, int64_t incx, int64_t *result)
114+
{
115+
auto status = oneapi::mkl::blas::column_major::iamax(device_queue->val, n, x, incx, result);
116+
status.wait();
117+
}
118+
extern "C" void onemklZamax(syclQueue_t device_queue, int64_t n, const double _Complex *x, int64_t incx, int64_t *result){
119+
auto status = oneapi::mkl::blas::column_major::iamax(device_queue->val, n, reinterpret_cast<const std::complex<double> *>(x), incx, result);
120+
status.wait();
121+
}
122+
extern "C" void onemklCamax(syclQueue_t device_queue, int64_t n, const float _Complex *x, int64_t incx, int64_t *result){
123+
auto status = oneapi::mkl::blas::column_major::iamax(device_queue->val, n, reinterpret_cast<const std::complex<float> *>(x), incx, result);
124+
status.wait();
125+
}
126+
127+
extern "C" void onemklDamin(syclQueue_t device_queue, int64_t n, const double *x, int64_t incx, int64_t *result)
128+
{
129+
auto status = oneapi::mkl::blas::column_major::iamin(device_queue->val, n, x, incx, result);
130+
status.wait();
131+
}
132+
extern "C" void onemklSamin(syclQueue_t device_queue, int64_t n, const float *x, int64_t incx, int64_t *result){
133+
auto status = oneapi::mkl::blas::column_major::iamin(device_queue->val, n, x, incx, result);
134+
status.wait();
135+
}
136+
extern "C" void onemklZamin(syclQueue_t device_queue, int64_t n, const double _Complex *x, int64_t incx, int64_t *result){
137+
auto status = oneapi::mkl::blas::column_major::iamin(device_queue->val, n, reinterpret_cast<const std::complex<double> *>(x), incx, result);
138+
status.wait();
139+
}
140+
extern "C" void onemklCamin(syclQueue_t device_queue, int64_t n, const float _Complex *x, int64_t incx, int64_t *result){
141+
auto status = oneapi::mkl::blas::column_major::iamin(device_queue->val, n, reinterpret_cast<const std::complex<float> *>(x), incx, result);
142+
status.wait();
143+
}
144+
145+
108146
// other
109147

110148
// oneMKL keeps a cache of SYCL queues and tries to destroy them when unloading the library.

deps/onemkl.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,16 @@ void onemklZcopy(syclQueue_t device_queue, int64_t n, const double _Complex *x,
4848
void onemklCcopy(syclQueue_t device_queue, int64_t n, const float _Complex *x,
4949
int64_t incx, float _Complex *y, int64_t incy);
5050

51+
void onemklDamax(syclQueue_t device_queue, int64_t n, const double *x, int64_t incx, int64_t *result);
52+
void onemklSamax(syclQueue_t device_queue, int64_t n, const float *x, int64_t incx, int64_t *result);
53+
void onemklZamax(syclQueue_t device_queue, int64_t n, const double _Complex *x, int64_t incx, int64_t *result);
54+
void onemklCamax(syclQueue_t device_queue, int64_t n, const float _Complex *x, int64_t incx, int64_t *result);
55+
56+
void onemklDamin(syclQueue_t device_queue, int64_t n, const double *x, int64_t incx, int64_t *result);
57+
void onemklSamin(syclQueue_t device_queue, int64_t n, const float *x, int64_t incx, int64_t *result);
58+
void onemklZamin(syclQueue_t device_queue, int64_t n, const double _Complex *x, int64_t incx, int64_t *result);
59+
void onemklCamin(syclQueue_t device_queue, int64_t n, const float _Complex *x, int64_t incx, int64_t *result);
60+
5161
void onemklDestroy();
5262
#ifdef __cplusplus
5363
}

lib/mkl/libonemkl.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,3 +66,42 @@ function onemklCcopy(device_queue, n, x, incx, y, incy)
6666
y::ZePtr{ComplexF32}, incy::Int64)::Cvoid
6767
end
6868

69+
function onemklSamax(device_queue, n, x, incx, result)
70+
@ccall liboneapi_support.onemklSamax(device_queue::syclQueue_t, n::Int64,
71+
x::ZePtr{Cfloat}, incx::Int64, result::RefOrZeRef{Int64})::Cvoid
72+
end
73+
74+
function onemklDamax(device_queue, n, x, incx, result)
75+
@ccall liboneapi_support.onemklDamax(device_queue::syclQueue_t, n::Int64,
76+
x::ZePtr{Cdouble}, incx::Int64, result::RefOrZeRef{Int64})::Cvoid
77+
end
78+
79+
function onemklCamax(device_queue, n, x, incx, result)
80+
@ccall liboneapi_support.onemklCamax(device_queue::syclQueue_t, n::Int64,
81+
x::ZePtr{ComplexF32}, incx::Int64,result::RefOrZeRef{Int64})::Cvoid
82+
end
83+
84+
function onemklZamax(device_queue, n, x, incx, result)
85+
@ccall liboneapi_support.onemklZamax(device_queue::syclQueue_t, n::Int64,
86+
x::ZePtr{ComplexF64}, incx::Int64, result::RefOrZeRef{Int64})::Cvoid
87+
end
88+
89+
function onemklSamin(device_queue, n, x, incx, result)
90+
@ccall liboneapi_support.onemklSamin(device_queue::syclQueue_t, n::Int64,
91+
x::ZePtr{Cfloat}, incx::Int64, result::RefOrZeRef{Int64})::Cvoid
92+
end
93+
94+
function onemklDamin(device_queue, n, x, incx, result)
95+
@ccall liboneapi_support.onemklDamin(device_queue::syclQueue_t, n::Int64,
96+
x::ZePtr{Cdouble}, incx::Int64, result::RefOrZeRef{Int64})::Cvoid
97+
end
98+
99+
function onemklCamin(device_queue, n, x, incx, result)
100+
@ccall liboneapi_support.onemklCamin(device_queue::syclQueue_t, n::Int64,
101+
x::ZePtr{ComplexF32}, incx::Int64,result::RefOrZeRef{Int64})::Cvoid
102+
end
103+
104+
function onemklZamin(device_queue, n, x, incx, result)
105+
@ccall liboneapi_support.onemklZamin(device_queue::syclQueue_t, n::Int64,
106+
x::ZePtr{ComplexF64}, incx::Int64, result::RefOrZeRef{Int64})::Cvoid
107+
end

lib/mkl/oneMKL.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@ using GPUArrays
1212

1313
include("libonemkl.jl")
1414

15-
const onemklFloat = Union{Float64,Float32,Float16,ComplexF64,ComplexF32}
15+
# Remove Float16 for now since not all oneMKL functions support intersect
16+
# Revisit this later as a seperate task.
17+
#const onemklFloat = Union{Float64,Float32,Float16,ComplexF64,ComplexF32}
18+
const onemklFloat = Union{Float64,Float32,ComplexF64,ComplexF32}
1619

1720
include("wrappers.jl")
1821
include("linalg.jl")

lib/mkl/wrappers.jl

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,50 @@ for (fname, elty) in
2929
(:onemklCcopy,:ComplexF32))
3030
@eval begin
3131
function copy!(n::Integer,
32-
x::StridedArray{$elty},
33-
y::StridedArray{$elty})
32+
x::oneStridedArray{$elty},
33+
y::oneStridedArray{$elty})
3434
queue = global_queue(context(x), device(x))
3535
$fname(sycl_queue(queue), n, x, stride(x, 1), y, stride(y, 1))
3636
y
3737
end
3838
end
3939
end
4040

41+
## iamax
42+
for (fname, elty) in
43+
((:onemklDamax,:Float64),
44+
(:onemklSamax,:Float32),
45+
(:onemklZamax,:ComplexF64),
46+
(:onemklCamax,:ComplexF32))
47+
@eval begin
48+
function iamax(x::oneStridedArray{$elty})
49+
result = oneArray{Int64}([0]);
50+
n = length(x)
51+
queue = global_queue(context(x), device(x))
52+
$fname(sycl_queue(queue), n, x, stride(x, 1), result)
53+
res = Array(result)
54+
return res[1]+1
55+
end
56+
end
57+
end
58+
59+
# iamin
60+
for (fname, elty) in
61+
((:onemklDamin,:Float64),
62+
(:onemklSamin,:Float32),
63+
(:onemklZamin,:ComplexF64),
64+
(:onemklCamin,:ComplexF32))
65+
@eval begin
66+
function iamin(x::oneStridedArray{$elty})
67+
result = oneArray{Int64}([0]);
68+
n = length(x)
69+
queue = global_queue(context(x), device(x))
70+
$fname(sycl_queue(queue), n, x, stride(x, 1), result)
71+
res = Array(result)
72+
return res[1] + 1
73+
end
74+
end
75+
end
4176

4277
# level 3
4378

test/onemkl.jl

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,17 @@ k = 13
99

1010
############################################################################################
1111
@testset "level 1" begin
12-
@testset for T in intersect(eltypes, [Float32, Float64, ComplexF32, ComplexF64])
13-
A = oneArray(rand(T, m))
14-
B = oneArray{T}(undef, m)
15-
oneMKL.copy!(m,A,B)
16-
@test Array(A) == Array(B)
12+
@testset for T in eltypes
13+
if T <: oneMKL.onemklFloat
14+
A = oneArray(rand(T, m))
15+
B = oneArray{T}(undef, m)
16+
oneMKL.copy!(m,A,B)
17+
@test Array(A) == Array(B)
18+
19+
a = convert.(T, [1.0, 2.0, -0.8, 5.0, 3.0])
20+
ca = oneArray(a)
21+
@test BLAS.iamax(a) == oneMKL.iamax(ca)
22+
@test oneMKL.iamin(ca) == 3
23+
end
1724
end # level 1 testset
1825
end

0 commit comments

Comments
 (0)