Skip to content

Commit 39eaeee

Browse files
committed
[oneMKL] Interface sparse routines
1 parent e5f6dc7 commit 39eaeee

File tree

10 files changed

+391
-135
lines changed

10 files changed

+391
-135
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1919
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2020
SPIRV_LLVM_Translator_unified_jll = "85f0d8ed-5b39-5caa-b1ae-7472de402361"
2121
SPIRV_Tools_jll = "6ac6d60f-d740-5983-97d7-a4482c0689f4"
22+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2223
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
2324
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2425
oneAPI_Level_Zero_Headers_jll = "f4bc562b-d309-54f8-9efb-476e56f0410d"
@@ -40,7 +41,7 @@ SpecialFunctions = "1.3, 2"
4041
StaticArrays = "1"
4142
julia = "1.8"
4243
oneAPI_Level_Zero_Loader_jll = "1.9"
43-
oneAPI_Support_jll = "~0.3.1"
44+
oneAPI_Support_jll = "~0.3.2"
4445

4546
[extras]
4647
libigc_jll = "94295238-5935-5bd7-bb0f-b00942e9bdd5"

deps/generate_interfaces.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ function generate_headers(library::String, filename::String, output::String)
232232
occursin("int64_t", header) && (suffix = "_64")
233233
end
234234
header = replace(header, "$(name_routine)(" => "onemkl$(version)$(name_routine)$(suffix)(")
235-
if name_routine ("init_matrix_handle", "init_matmat_descr", "release_matmat_descr", "set_matmat_data")
235+
if name_routine ("init_matrix_handle", "init_matmat_descr", "release_matmat_descr", "set_matmat_data", "set_csr_data")
236236
header = replace(header, "void onemkl" => "int onemkl")
237237
end
238238
if library == "sparse"
@@ -369,7 +369,7 @@ function generate_cpp(library::String, filename::String, output::String)
369369
!occursin("scratchpad_size", name) && write(oneapi_cpp, " auto status = oneapi::mkl::$library::$variant$name<$type>($parameters);\n")
370370
occursin("scratchpad_size", name) && write(oneapi_cpp, " int64_t scratchpad_size = oneapi::mkl::$library::$variant$name<$type>($parameters);\n")
371371
else
372-
if name ("init_matrix_handle", "init_matmat_descr", "release_matmat_descr", "set_matmat_data")
372+
if name ("init_matrix_handle", "init_matmat_descr", "release_matmat_descr", "set_matmat_data", "set_csr_data")
373373
write(oneapi_cpp, " auto status = oneapi::mkl::$library::$variant$name($parameters);\n")
374374
else
375375
write(oneapi_cpp, " oneapi::mkl::$library::$variant$name($parameters);\n")
@@ -378,8 +378,8 @@ function generate_cpp(library::String, filename::String, output::String)
378378
if occursin("scratchpad_size", name)
379379
write(oneapi_cpp, " return scratchpad_size;\n")
380380
else
381-
(name ("init_matrix_handle", "init_matmat_descr", "release_matmat_descr", "set_matmat_data")) && write(oneapi_cpp, " __FORCE_MKL_FLUSH__(status);\n")
382-
(name ("init_matrix_handle", "init_matmat_descr", "release_matmat_descr", "set_matmat_data")) && write(oneapi_cpp, " return 0;\n")
381+
(name ("init_matrix_handle", "init_matmat_descr", "release_matmat_descr", "set_matmat_data", "set_csr_data")) && write(oneapi_cpp, " __FORCE_MKL_FLUSH__(status);\n")
382+
(name ("init_matrix_handle", "init_matmat_descr", "release_matmat_descr", "set_matmat_data", "set_csr_data")) && write(oneapi_cpp, " return 0;\n")
383383
end
384384
write(oneapi_cpp, "}")
385385
write(oneapi_cpp, "\n\n")

deps/src/onemkl.cpp

Lines changed: 16 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -3465,52 +3465,36 @@ extern "C" void onemklXsparse_init_matrix_handle(matrix_handle_t *handle) {
34653465
oneapi::mkl::sparse::init_matrix_handle((oneapi::mkl::sparse::matrix_handle_t*) handle);
34663466
}
34673467

3468-
extern "C" int onemklSsparse_set_csr_data(syclQueue_t device_queue, matrix_handle_t handle, int32_t num_rows, int32_t num_cols, onemklIndex index, int32_t *row_ptr, int32_t *col_ind, float *val) {
3469-
auto status = oneapi::mkl::sparse::set_csr_data(device_queue->val, (oneapi::mkl::sparse::matrix_handle_t) handle, num_rows, num_cols, convert(index), row_ptr, col_ind, val);
3470-
__FORCE_MKL_FLUSH__(status);
3471-
return 0;
3468+
extern "C" void onemklSsparse_set_csr_data(syclQueue_t device_queue, matrix_handle_t handle, int32_t num_rows, int32_t num_cols, onemklIndex index, int32_t *row_ptr, int32_t *col_ind, float *val) {
3469+
oneapi::mkl::sparse::set_csr_data(device_queue->val, (oneapi::mkl::sparse::matrix_handle_t) handle, num_rows, num_cols, convert(index), row_ptr, col_ind, val);
34723470
}
34733471

3474-
extern "C" int onemklSsparse_set_csr_data_64(syclQueue_t device_queue, matrix_handle_t handle, int64_t num_rows, int64_t num_cols, onemklIndex index, int64_t *row_ptr, int64_t *col_ind, float *val) {
3475-
auto status = oneapi::mkl::sparse::set_csr_data(device_queue->val, (oneapi::mkl::sparse::matrix_handle_t) handle, num_rows, num_cols, convert(index), row_ptr, col_ind, val);
3476-
__FORCE_MKL_FLUSH__(status);
3477-
return 0;
3472+
extern "C" void onemklSsparse_set_csr_data_64(syclQueue_t device_queue, matrix_handle_t handle, int64_t num_rows, int64_t num_cols, onemklIndex index, int64_t *row_ptr, int64_t *col_ind, float *val) {
3473+
oneapi::mkl::sparse::set_csr_data(device_queue->val, (oneapi::mkl::sparse::matrix_handle_t) handle, num_rows, num_cols, convert(index), row_ptr, col_ind, val);
34783474
}
34793475

3480-
extern "C" int onemklDsparse_set_csr_data(syclQueue_t device_queue, matrix_handle_t handle, int32_t num_rows, int32_t num_cols, onemklIndex index, int32_t *row_ptr, int32_t *col_ind, double *val) {
3481-
auto status = oneapi::mkl::sparse::set_csr_data(device_queue->val, (oneapi::mkl::sparse::matrix_handle_t) handle, num_rows, num_cols, convert(index), row_ptr, col_ind, val);
3482-
__FORCE_MKL_FLUSH__(status);
3483-
return 0;
3476+
extern "C" void onemklDsparse_set_csr_data(syclQueue_t device_queue, matrix_handle_t handle, int32_t num_rows, int32_t num_cols, onemklIndex index, int32_t *row_ptr, int32_t *col_ind, double *val) {
3477+
oneapi::mkl::sparse::set_csr_data(device_queue->val, (oneapi::mkl::sparse::matrix_handle_t) handle, num_rows, num_cols, convert(index), row_ptr, col_ind, val);
34843478
}
34853479

3486-
extern "C" int onemklDsparse_set_csr_data_64(syclQueue_t device_queue, matrix_handle_t handle, int64_t num_rows, int64_t num_cols, onemklIndex index, int64_t *row_ptr, int64_t *col_ind, double *val) {
3487-
auto status = oneapi::mkl::sparse::set_csr_data(device_queue->val, (oneapi::mkl::sparse::matrix_handle_t) handle, num_rows, num_cols, convert(index), row_ptr, col_ind, val);
3488-
__FORCE_MKL_FLUSH__(status);
3489-
return 0;
3480+
extern "C" void onemklDsparse_set_csr_data_64(syclQueue_t device_queue, matrix_handle_t handle, int64_t num_rows, int64_t num_cols, onemklIndex index, int64_t *row_ptr, int64_t *col_ind, double *val) {
3481+
oneapi::mkl::sparse::set_csr_data(device_queue->val, (oneapi::mkl::sparse::matrix_handle_t) handle, num_rows, num_cols, convert(index), row_ptr, col_ind, val);
34903482
}
34913483

3492-
extern "C" int onemklCsparse_set_csr_data(syclQueue_t device_queue, matrix_handle_t handle, int32_t num_rows, int32_t num_cols, onemklIndex index, int32_t *row_ptr, int32_t *col_ind, float _Complex *val) {
3493-
auto status = oneapi::mkl::sparse::set_csr_data(device_queue->val, (oneapi::mkl::sparse::matrix_handle_t) handle, num_rows, num_cols, convert(index), row_ptr, col_ind, reinterpret_cast<std::complex<float>*>(val));
3494-
__FORCE_MKL_FLUSH__(status);
3495-
return 0;
3484+
extern "C" void onemklCsparse_set_csr_data(syclQueue_t device_queue, matrix_handle_t handle, int32_t num_rows, int32_t num_cols, onemklIndex index, int32_t *row_ptr, int32_t *col_ind, float _Complex *val) {
3485+
oneapi::mkl::sparse::set_csr_data(device_queue->val, (oneapi::mkl::sparse::matrix_handle_t) handle, num_rows, num_cols, convert(index), row_ptr, col_ind, reinterpret_cast<std::complex<float>*>(val));
34963486
}
34973487

3498-
extern "C" int onemklCsparse_set_csr_data_64(syclQueue_t device_queue, matrix_handle_t handle, int64_t num_rows, int64_t num_cols, onemklIndex index, int64_t *row_ptr, int64_t *col_ind, float _Complex *val) {
3499-
auto status = oneapi::mkl::sparse::set_csr_data(device_queue->val, (oneapi::mkl::sparse::matrix_handle_t) handle, num_rows, num_cols, convert(index), row_ptr, col_ind, reinterpret_cast<std::complex<float>*>(val));
3500-
__FORCE_MKL_FLUSH__(status);
3501-
return 0;
3488+
extern "C" void onemklCsparse_set_csr_data_64(syclQueue_t device_queue, matrix_handle_t handle, int64_t num_rows, int64_t num_cols, onemklIndex index, int64_t *row_ptr, int64_t *col_ind, float _Complex *val) {
3489+
oneapi::mkl::sparse::set_csr_data(device_queue->val, (oneapi::mkl::sparse::matrix_handle_t) handle, num_rows, num_cols, convert(index), row_ptr, col_ind, reinterpret_cast<std::complex<float>*>(val));
35023490
}
35033491

3504-
extern "C" int onemklZsparse_set_csr_data(syclQueue_t device_queue, matrix_handle_t handle, int32_t num_rows, int32_t num_cols, onemklIndex index, int32_t *row_ptr, int32_t *col_ind, double _Complex *val) {
3505-
auto status = oneapi::mkl::sparse::set_csr_data(device_queue->val, (oneapi::mkl::sparse::matrix_handle_t) handle, num_rows, num_cols, convert(index), row_ptr, col_ind, reinterpret_cast<std::complex<double>*>(val));
3506-
__FORCE_MKL_FLUSH__(status);
3507-
return 0;
3492+
extern "C" void onemklZsparse_set_csr_data(syclQueue_t device_queue, matrix_handle_t handle, int32_t num_rows, int32_t num_cols, onemklIndex index, int32_t *row_ptr, int32_t *col_ind, double _Complex *val) {
3493+
oneapi::mkl::sparse::set_csr_data(device_queue->val, (oneapi::mkl::sparse::matrix_handle_t) handle, num_rows, num_cols, convert(index), row_ptr, col_ind, reinterpret_cast<std::complex<double>*>(val));
35083494
}
35093495

3510-
extern "C" int onemklZsparse_set_csr_data_64(syclQueue_t device_queue, matrix_handle_t handle, int64_t num_rows, int64_t num_cols, onemklIndex index, int64_t *row_ptr, int64_t *col_ind, double _Complex *val) {
3511-
auto status = oneapi::mkl::sparse::set_csr_data(device_queue->val, (oneapi::mkl::sparse::matrix_handle_t) handle, num_rows, num_cols, convert(index), row_ptr, col_ind, reinterpret_cast<std::complex<double>*>(val));
3512-
__FORCE_MKL_FLUSH__(status);
3513-
return 0;
3496+
extern "C" void onemklZsparse_set_csr_data_64(syclQueue_t device_queue, matrix_handle_t handle, int64_t num_rows, int64_t num_cols, onemklIndex index, int64_t *row_ptr, int64_t *col_ind, double _Complex *val) {
3497+
oneapi::mkl::sparse::set_csr_data(device_queue->val, (oneapi::mkl::sparse::matrix_handle_t) handle, num_rows, num_cols, convert(index), row_ptr, col_ind, reinterpret_cast<std::complex<double>*>(val));
35143498
}
35153499

35163500
extern "C" int onemklSsparse_gemv(syclQueue_t device_queue, onemklTranspose transpose_flag, float alpha, matrix_handle_t handle, float *x, float beta, float *y) {

deps/src/onemkl.h

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2009,37 +2009,37 @@ int64_t onemklZgels_batch_scratchpad_size(syclQueue_t device_queue, onemklTransp
20092009
// SPARSE
20102010
void onemklXsparse_init_matrix_handle(matrix_handle_t *handle);
20112011

2012-
int onemklSsparse_set_csr_data(syclQueue_t device_queue, matrix_handle_t handle, int32_t
2013-
num_rows, int32_t num_cols, onemklIndex index, int32_t *row_ptr,
2014-
int32_t *col_ind, float *val);
2012+
void onemklSsparse_set_csr_data(syclQueue_t device_queue, matrix_handle_t handle, int32_t
2013+
num_rows, int32_t num_cols, onemklIndex index, int32_t *row_ptr,
2014+
int32_t *col_ind, float *val);
20152015

2016-
int onemklSsparse_set_csr_data_64(syclQueue_t device_queue, matrix_handle_t handle, int64_t
2017-
num_rows, int64_t num_cols, onemklIndex index, int64_t
2018-
*row_ptr, int64_t *col_ind, float *val);
2016+
void onemklSsparse_set_csr_data_64(syclQueue_t device_queue, matrix_handle_t handle, int64_t
2017+
num_rows, int64_t num_cols, onemklIndex index, int64_t
2018+
*row_ptr, int64_t *col_ind, float *val);
20192019

2020-
int onemklDsparse_set_csr_data(syclQueue_t device_queue, matrix_handle_t handle, int32_t
2021-
num_rows, int32_t num_cols, onemklIndex index, int32_t *row_ptr,
2022-
int32_t *col_ind, double *val);
2020+
void onemklDsparse_set_csr_data(syclQueue_t device_queue, matrix_handle_t handle, int32_t
2021+
num_rows, int32_t num_cols, onemklIndex index, int32_t *row_ptr,
2022+
int32_t *col_ind, double *val);
20232023

2024-
int onemklDsparse_set_csr_data_64(syclQueue_t device_queue, matrix_handle_t handle, int64_t
2025-
num_rows, int64_t num_cols, onemklIndex index, int64_t
2026-
*row_ptr, int64_t *col_ind, double *val);
2024+
void onemklDsparse_set_csr_data_64(syclQueue_t device_queue, matrix_handle_t handle, int64_t
2025+
num_rows, int64_t num_cols, onemklIndex index, int64_t
2026+
*row_ptr, int64_t *col_ind, double *val);
20272027

2028-
int onemklCsparse_set_csr_data(syclQueue_t device_queue, matrix_handle_t handle, int32_t
2029-
num_rows, int32_t num_cols, onemklIndex index, int32_t *row_ptr,
2030-
int32_t *col_ind, float _Complex *val);
2028+
void onemklCsparse_set_csr_data(syclQueue_t device_queue, matrix_handle_t handle, int32_t
2029+
num_rows, int32_t num_cols, onemklIndex index, int32_t *row_ptr,
2030+
int32_t *col_ind, float _Complex *val);
20312031

2032-
int onemklCsparse_set_csr_data_64(syclQueue_t device_queue, matrix_handle_t handle, int64_t
2033-
num_rows, int64_t num_cols, onemklIndex index, int64_t
2034-
*row_ptr, int64_t *col_ind, float _Complex *val);
2032+
void onemklCsparse_set_csr_data_64(syclQueue_t device_queue, matrix_handle_t handle, int64_t
2033+
num_rows, int64_t num_cols, onemklIndex index, int64_t
2034+
*row_ptr, int64_t *col_ind, float _Complex *val);
20352035

2036-
int onemklZsparse_set_csr_data(syclQueue_t device_queue, matrix_handle_t handle, int32_t
2037-
num_rows, int32_t num_cols, onemklIndex index, int32_t *row_ptr,
2038-
int32_t *col_ind, double _Complex *val);
2036+
void onemklZsparse_set_csr_data(syclQueue_t device_queue, matrix_handle_t handle, int32_t
2037+
num_rows, int32_t num_cols, onemklIndex index, int32_t *row_ptr,
2038+
int32_t *col_ind, double _Complex *val);
20392039

2040-
int onemklZsparse_set_csr_data_64(syclQueue_t device_queue, matrix_handle_t handle, int64_t
2041-
num_rows, int64_t num_cols, onemklIndex index, int64_t
2042-
*row_ptr, int64_t *col_ind, double _Complex *val);
2040+
void onemklZsparse_set_csr_data_64(syclQueue_t device_queue, matrix_handle_t handle, int64_t
2041+
num_rows, int64_t num_cols, onemklIndex index, int64_t
2042+
*row_ptr, int64_t *col_ind, double _Complex *val);
20432043

20442044
int onemklSsparse_gemv(syclQueue_t device_queue, onemklTranspose transpose_flag, float alpha,
20452045
matrix_handle_t handle, float *x, float beta, float *y);

lib/mkl/oneMKL.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,17 @@ using ..SYCL: syclQueue_t
1111

1212
using GPUArrays
1313

14+
using LinearAlgebra
15+
using SparseArrays
16+
1417
# Exclude Float16 for now, since many oneMKL functions - copy, scal, do not take Float16
1518
const onemklFloat = Union{Float64,Float32,ComplexF64,ComplexF32}
1619
const onemklComplex = Union{ComplexF32,ComplexF64}
1720
const onemklHalf = Union{Float16,ComplexF16}
18-
include("wrappers.jl")
21+
22+
include("utils.jl")
23+
include("wrappers_blas.jl")
24+
include("wrappers_sparse.jl")
1925
include("linalg.jl")
2026

2127
function band(A::StridedArray, kl, ku)

lib/mkl/utils.jl

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
#
2+
# Auxiliary
3+
#
4+
5+
function Base.convert(::Type{onemklSide}, side::Char)
6+
if side == 'L'
7+
return ONEMKL_SIDE_LEFT
8+
elseif side == 'R'
9+
return ONEMKL_SIDE_RIGHT
10+
else
11+
throw(ArgumentError("Unknown transpose $side"))
12+
end
13+
end
14+
15+
function Base.convert(::Type{onemklTranspose}, trans::Char)
16+
if trans == 'N'
17+
return ONEMKL_TRANSPOSE_NONTRANS
18+
elseif trans == 'T'
19+
return ONEMKL_TRANSPOSE_TRANS
20+
elseif trans == 'C'
21+
return ONEMLK_TRANSPOSE_CONJTRANS
22+
else
23+
throw(ArgumentError("Unknown transpose $trans"))
24+
end
25+
end
26+
27+
function Base.convert(::Type{onemklUplo}, uplo::Char)
28+
if uplo == 'U'
29+
return ONEMKL_UPLO_UPPER
30+
elseif uplo == 'L'
31+
return ONEMKL_UPLO_LOWER
32+
else
33+
throw(ArgumentError("Unknown transpose $uplo"))
34+
end
35+
end
36+
37+
function Base.convert(::Type{onemklDiag}, diag::Char)
38+
if diag == 'N'
39+
return ONEMKL_DIAG_NONUNIT
40+
elseif diag == 'U'
41+
return ONEMKL_DIAG_UNIT
42+
else
43+
throw(ArgumentError("Unknown transpose $diag"))
44+
end
45+
end
46+
47+
function Base.convert(::Type{onemklIndex}, index::Char)
48+
if index == 'O'
49+
return ONEMKL_INDEX_ONE
50+
elseif index == 'Z'
51+
return ONEMKL_INDEX_ZERO
52+
else
53+
throw(ArgumentError("Unknown index $index"))
54+
end
55+
end
56+
57+
function Base.convert(::Type{onemklLayout}, index::Char)
58+
if index == 'R'
59+
return ONEMKL_LAYOUT_ROW
60+
elseif index == 'C'
61+
return ONEMKL_LAYOUT_COL
62+
else
63+
throw(ArgumentError("Unknown layout $layout"))
64+
end
65+
end
66+
67+
# create a batch of pointers in device memory from a batch of device arrays
68+
@inline function unsafe_batch(batch::Vector{<:oneArray{T}}) where {T}
69+
ptrs = pointer.(batch)
70+
return oneArray(ptrs)
71+
end

lib/mkl/wrappers.jl renamed to lib/mkl/wrappers_blas.jl

Lines changed: 1 addition & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,65 +1,3 @@
1-
#
2-
# Auxiliary
3-
#
4-
5-
function Base.convert(::Type{onemklSide}, side::Char)
6-
if side == 'L'
7-
return ONEMKL_SIDE_LEFT
8-
elseif side == 'R'
9-
return ONEMKL_SIDE_RIGHT
10-
else
11-
throw(ArgumentError("Unknown transpose $side"))
12-
end
13-
end
14-
15-
function Base.convert(::Type{onemklTranspose}, trans::Char)
16-
if trans == 'N'
17-
return ONEMKL_TRANSPOSE_NONTRANS
18-
elseif trans == 'T'
19-
return ONEMKL_TRANSPOSE_TRANS
20-
elseif trans == 'C'
21-
return ONEMLK_TRANSPOSE_CONJTRANS
22-
else
23-
throw(ArgumentError("Unknown transpose $trans"))
24-
end
25-
end
26-
27-
function Base.convert(::Type{onemklUplo}, uplo::Char)
28-
if uplo == 'U'
29-
return ONEMKL_UPLO_UPPER
30-
elseif uplo == 'L'
31-
return ONEMKL_UPLO_LOWER
32-
else
33-
throw(ArgumentError("Unknown transpose $uplo"))
34-
end
35-
end
36-
37-
function Base.convert(::Type{onemklDiag}, diag::Char)
38-
if diag == 'N'
39-
return ONEMKL_DIAG_NONUNIT
40-
elseif diag == 'U'
41-
return ONEMKL_DIAG_UNIT
42-
else
43-
throw(ArgumentError("Unknown transpose $diag"))
44-
end
45-
end
46-
47-
function Base.convert(::Type{onemklIndex}, index::Char)
48-
if index == 'O'
49-
return ONEMKL_INDEX_ONE
50-
elseif index == 'Z'
51-
return ONEMKL_INDEX_ZERO
52-
else
53-
throw(ArgumentError("Unknown index $index"))
54-
end
55-
end
56-
57-
# create a batch of pointers in device memory from a batch of device arrays
58-
@inline function unsafe_batch(batch::Vector{<:oneArray{T}}) where {T}
59-
ptrs = pointer.(batch)
60-
return oneArray(ptrs)
61-
end
62-
631
## (GE) general matrix-matrix multiplication batched
642
for (fname, elty) in
653
((:onemklDgemmBatched,:Float64),
@@ -1263,4 +1201,4 @@ end
12631201
function gemm_strided_batched(transA::Char, transB::Char, A::AbstractArray{T, 3},
12641202
B::AbstractArray{T,3}) where T
12651203
gemm_strided_batched(transA, transB, one(T), A, B)
1266-
end
1204+
end

0 commit comments

Comments
 (0)