Skip to content

Commit aedfc12

Browse files
authored
Add library wrappers for oneMKL Sparse (#393)
1 parent c576901 commit aedfc12

File tree

8 files changed

+954
-41
lines changed

8 files changed

+954
-41
lines changed

deps/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON)
66

77
project(oneAPISupport)
88

9-
add_library(oneapi_support SHARED src/sycl.h src/sycl.hpp src/sycl.cpp src/onemkl.cpp)
9+
add_library(oneapi_support SHARED src/sycl.h src/sycl.hpp src/sycl.cpp src/onemkl.h src/onemkl.cpp)
1010

1111
target_link_libraries(oneapi_support
1212
mkl_sycl

deps/generate_interfaces.jl

Lines changed: 46 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,12 @@ dict_version = Dict{Int, Char}(1 => 'S', 2 => 'D', 3 => 'C', 4 => 'Z')
99
version_types = Dict{Char, String}('S' => "float",
1010
'D' => "double",
1111
'C' => "std::complex<float>",
12-
'Z' => "std::complex<double>",
13-
'I' => "int32_t",
14-
'L' => "int64_t")
12+
'Z' => "std::complex<double>")
1513

1614
version_types_header = Dict{Char, String}('S' => "float",
1715
'D' => "double",
1816
'C' => "float _Complex",
19-
'Z' => "double _Complex",
20-
'I' => "int32_t",
21-
'L' => "int64_t")
17+
'Z' => "double _Complex")
2218

2319
function generate_headers(library::String, filename::String, output::String)
2420
routines = Dict{String,Int}()
@@ -49,8 +45,10 @@ function generate_headers(library::String, filename::String, output::String)
4945
occursin("heevx", header) && continue # LAPACK routine
5046
occursin("hegvx", header) && continue # LAPACK routine
5147
occursin("(matrix_handle_t handle", header) && continue # SPARSE routine
48+
occursin("update_diagonal_values", header) && continue # SPARSE routine
49+
occursin("get_matmat_data", header) && continue # SPARSE routine
50+
occursin("matmat(", header) && continue # SPARSE routine
5251
occursin("gemvdot", header) && continue # SPARSE routine
53-
occursin("matmat", header) && continue # SPARSE routine
5452

5553
# Check if the routine is a template
5654
template = occursin("template", header)
@@ -104,8 +102,10 @@ function generate_headers(library::String, filename::String, output::String)
104102
header = replace(header, "sycl::buffer<double, 1> &" => "double *")
105103
header = replace(header, "sycl::buffer<std::complex<float>, 1> &" => "float _Complex *")
106104
header = replace(header, "sycl::buffer<std::complex<double>, 1> &" => "double _Complex *")
105+
header = replace(header, "sycl::buffer<std::uint8_t, 1> *" => "uint8_t *")
107106
header = replace(header, "sycl::buffer<int32_t, 1> &" => "int32_t *")
108107
header = replace(header, "sycl::buffer<int64_t, 1> &" => "int64_t *")
108+
header = replace(header, "sycl::buffer<int64_t, 1> *" => "int64_t *")
109109

110110
header = replace(header, "template <>\n" => "")
111111
header = replace(header, "<std::complex<float>>" => "")
@@ -129,6 +129,7 @@ function generate_headers(library::String, filename::String, output::String)
129129
header = replace(header, "oneapi::mkl::layout" => "onemklLayout")
130130
header = replace(header, "oneapi::mkl::index" => "onemklIndex")
131131
header = replace(header, "oneapi::mkl::property" => "onemklProperty")
132+
header = replace(header, "sparse::matmat_descr_t" => "matmat_descr_t")
132133

133134
# Sanitize the header
134135
header = replace(header, " \\" => "")
@@ -203,10 +204,6 @@ function generate_headers(library::String, filename::String, output::String)
203204
end
204205
end
205206
version = 'X'
206-
if library == "sparse"
207-
version = occursin("int32_t", header) ? 'I' : version
208-
version = occursin("int64_t", header) ? 'L' : version
209-
end
210207
version = occursin("double", header) ? 'D' : version
211208
version = occursin("float", header) ? 'S' : version
212209
version = occursin("float _Complex", header) ? 'C' : version
@@ -218,6 +215,7 @@ function generate_headers(library::String, filename::String, output::String)
218215
versions = ('S', 'D', 'C', 'Z')
219216
mapreduce(x -> startswith(name_routine, x), |, ["or", "sy"]) && !startswith(name_routine, "sytrf") && (versions = ('S', 'D'))
220217
mapreduce(x -> startswith(name_routine, x), |, ["un", "he"]) && (versions = ('C', 'Z'))
218+
(name_routine == "gesvd_scratchpad_size") && (routines[name_routine] > 1) && continue
221219
routines[name_routine] = routines[name_routine] - 1 + length(versions)
222220
for blas_version in versions
223221
copy_header = header
@@ -229,12 +227,14 @@ function generate_headers(library::String, filename::String, output::String)
229227
end
230228
else
231229
if isempty(list_versions)
230+
suffix = ""
232231
if name_routine == "set_csr_data"
233-
occursin("int32_t", header) && (version = "I" * version)
234-
occursin("int64_t", header) && (version = "L" * version)
232+
occursin("int64_t", header) && (suffix = "_64")
233+
end
234+
header = replace(header, "$(name_routine)(" => "onemkl$(version)$(name_routine)$(suffix)(")
235+
if name_routine ("init_matrix_handle", "init_matmat_descr", "release_matmat_descr", "set_matmat_data")
236+
header = replace(header, "void onemkl" => "int onemkl")
235237
end
236-
header = replace(header, name_routine => "onemkl$(version)$(name_routine)")
237-
header = replace(header, "void onemkl" => "int onemkl")
238238
if library == "sparse"
239239
if occursin("std::complex", header)
240240
(version == 'C') && (header = replace(header, "std::complex " => "float _Complex "))
@@ -247,6 +247,9 @@ function generate_headers(library::String, filename::String, output::String)
247247
header = replace(header, "layout " => "onemklLayout ")
248248
header = replace(header, "index_base " => "onemklIndex ")
249249
header = replace(header, "property " => "onemklProperty ")
250+
header = replace(header, "sparse::matrix_view_descr " => "onemklMatrixView ")
251+
header = replace(header, "matrix_view_descr " => "onemklMatrixView ")
252+
header = replace(header, "sparse::matmat_request " => "onemklMatmatRequest ")
250253
header = replace(header, name_routine => "sparse_" * name_routine)
251254
end
252255
push!(signatures, (header, name_routine, version, template))
@@ -280,7 +283,7 @@ function generate_headers(library::String, filename::String, output::String)
280283
# Check the number of methods
281284
blacklist = String[]
282285
for name_routine in keys(routines)
283-
if (routines[name_routine] > 5) && (library != "sparse")
286+
if (routines[name_routine] > 5) && (name_routine != "set_csr_data")
284287
@warn "The routine $(name_routine) has more than 4 methods and will not be interfaced."
285288
push!(blacklist, name_routine)
286289
end
@@ -331,7 +334,10 @@ function generate_cpp(library::String, filename::String, output::String)
331334
parameters = replace(parameters, "syclQueue_t device_queue" => "device_queue->val")
332335
parameters = replace(parameters, "int32_t " => "")
333336
parameters = replace(parameters, "int64_t " => "")
334-
parameters = replace(parameters, "matrix_handle_t " => "")
337+
parameters = replace(parameters, "matrix_handle_t *" => "(oneapi::mkl::sparse::matrix_handle_t*) ")
338+
parameters = replace(parameters, "matrix_handle_t " => "(oneapi::mkl::sparse::matrix_handle_t) ")
339+
parameters = replace(parameters, "matmat_descr_t *" => "(oneapi::mkl::sparse::matmat_descr_t*) ")
340+
parameters = replace(parameters, "matmat_descr_t " => "(oneapi::mkl::sparse::matmat_descr_t) ")
335341
parameters = replace(parameters, "float _Complex *" => "reinterpret_cast<std::complex<float> *>")
336342
parameters = replace(parameters, "double _Complex *" => "reinterpret_cast<std::complex<double> *>")
337343
parameters = replace(parameters, "float _Complex " => "static_cast<std::complex<float> >")
@@ -342,13 +348,15 @@ function generate_cpp(library::String, filename::String, output::String)
342348
parameters = replace(parameters, ", double " => ", ")
343349
parameters = replace(parameters, ", *" => ", ")
344350

345-
for type in ("onemklTranspose", "onemklSide", "onemklUplo", "onemklDiag", "onemklGenerate",
346-
"onemklJob", "onemklJobsvd", "onemklCompz", "onemklRangev", "onemklIndex", "onemklProperty")
347-
parameters = replace(parameters, Regex("$type ([a-z_]+),") => SubstitutionString("convert(\\1),"))
348-
parameters = replace(parameters, Regex(", $type ([a-z_]+)") => SubstitutionString(", convert(\\1)"))
351+
for type in ("onemklTranspose", "onemklSide", "onemklUplo", "onemklDiag",
352+
"onemklGenerate", "onemklLayout", "onemklJob", "onemklJobsvd",
353+
"onemklCompz", "onemklRangev", "onemklIndex", "onemklProperty",
354+
"onemklMatrixView", "onemklMatmatRequest")
355+
parameters = replace(parameters, Regex("$type ([A-Za-z_]+),") => SubstitutionString("convert(\\1),"))
356+
parameters = replace(parameters, Regex(", $type ([A-Za-z_]+)") => SubstitutionString(", convert(\\1)"))
349357
end
350-
parameters = replace(parameters, r" >([a-z]+)" => s" >(\1)")
351-
parameters = replace(parameters, r" \*>([a-z]+)" => s"*>(\1)")
358+
parameters = replace(parameters, r" >([A-Za-z_]+)" => s" >(\1)")
359+
parameters = replace(parameters, r" \*>([A-Za-z_]+)" => s"*>(\1)")
352360

353361
variant = ""
354362
if library == "blas"
@@ -361,13 +369,17 @@ function generate_cpp(library::String, filename::String, output::String)
361369
!occursin("scratchpad_size", name) && write(oneapi_cpp, " auto status = oneapi::mkl::$library::$variant$name<$type>($parameters);\n")
362370
occursin("scratchpad_size", name) && write(oneapi_cpp, " int64_t scratchpad_size = oneapi::mkl::$library::$variant$name<$type>($parameters);\n")
363371
else
364-
write(oneapi_cpp, " auto status = oneapi::mkl::$library::$variant$name($parameters);\n")
372+
if name ("init_matrix_handle", "init_matmat_descr", "release_matmat_descr", "set_matmat_data")
373+
write(oneapi_cpp, " auto status = oneapi::mkl::$library::$variant$name($parameters);\n")
374+
else
375+
write(oneapi_cpp, " oneapi::mkl::$library::$variant$name($parameters);\n")
376+
end
365377
end
366378
if occursin("scratchpad_size", name)
367379
write(oneapi_cpp, " return scratchpad_size;\n")
368380
else
369-
write(oneapi_cpp, " __FORCE_MKL_FLUSH__(status);\n")
370-
write(oneapi_cpp, " return 0;\n")
381+
(name ("init_matrix_handle", "init_matmat_descr", "release_matmat_descr", "set_matmat_data")) && write(oneapi_cpp, " __FORCE_MKL_FLUSH__(status);\n")
382+
(name ("init_matrix_handle", "init_matmat_descr", "release_matmat_descr", "set_matmat_data")) && write(oneapi_cpp, " return 0;\n")
371383
end
372384
write(oneapi_cpp, "}")
373385
write(oneapi_cpp, "\n\n")
@@ -377,7 +389,7 @@ end
377389

378390
generate_headers("lapack", lapack, "onemkl_lapack.h")
379391
generate_headers("blas", blas, "onemkl_blas.h")
380-
# generate_headers("sparse", sparse, "onemkl_sparse.h")
392+
generate_headers("sparse", sparse, "onemkl_sparse.h")
381393

382394
io = open("src/onemkl.h", "w")
383395
headers_prologue = read("onemkl_prologue.h", String)
@@ -388,16 +400,16 @@ write(io, headers_blas)
388400
headers_lapack = read("onemkl_lapack.h", String)
389401
write(io, "// LAPACK\n")
390402
write(io, headers_lapack)
391-
# headers_sparse = read("onemkl_sparse.h", String)
392-
# write(io, "// SPARSE\n")
393-
# write(io, headers_sparse)
403+
headers_sparse = read("onemkl_sparse.h", String)
404+
write(io, "// SPARSE\n")
405+
write(io, headers_sparse)
394406
headers_epilogue = read("onemkl_epilogue.h", String)
395407
write(io, headers_epilogue)
396408
close(io)
397409

398410
generate_cpp("lapack", lapack, "onemkl_lapack.cpp")
399411
generate_cpp("blas", blas, "onemkl_blas.cpp")
400-
# generate_cpp("sparse", sparse, "onemkl_sparse.cpp")
412+
generate_cpp("sparse", sparse, "onemkl_sparse.cpp")
401413

402414
io = open("src/onemkl.cpp", "w")
403415
cpp_prologue = read("onemkl_prologue.cpp", String)
@@ -408,9 +420,9 @@ write(io, cpp_blas)
408420
cpp_lapack = read("onemkl_lapack.cpp", String)
409421
write(io, "// LAPACK\n")
410422
write(io, cpp_lapack)
411-
# cpp_sparse = read("onemkl_sparse.cpp", String)
412-
# write(io, "// SPARSE\n")
413-
# write(io, cpp_sparse)
423+
cpp_sparse = read("onemkl_sparse.cpp", String)
424+
write(io, "// SPARSE\n")
425+
write(io, cpp_sparse)
414426
cpp_epilogue = read("onemkl_epilogue.cpp", String)
415427
write(io, cpp_epilogue)
416428
close(io)

deps/onemkl_prologue.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,36 @@ oneapi::mkl::sparse::property convert(onemklProperty val) {
179179
}
180180
}
181181

182+
oneapi::mkl::sparse::matrix_view_descr convert(onemklMatrixView val) {
183+
switch (val) {
184+
case ONEMKL_MATRIX_VIEW_GENERAL:
185+
return oneapi::mkl::sparse::matrix_view_descr::general;
186+
}
187+
}
188+
189+
oneapi::mkl::sparse::matmat_request convert(onemklMatmatRequest val) {
190+
switch (val) {
191+
case ONEMKL_MATMAT_REQUEST_GET_WORK_ESTIMATION_BUF_SIZE:
192+
return oneapi::mkl::sparse::matmat_request::get_work_estimation_buf_size;
193+
case ONEMKL_MATMAT_REQUEST_WORK_ESTIMATION:
194+
return oneapi::mkl::sparse::matmat_request::work_estimation;
195+
case ONEMKL_MATMAT_REQUEST_GET_COMPUTE_STRUCTURE_BUF_SIZE:
196+
return oneapi::mkl::sparse::matmat_request::get_compute_structure_buf_size;
197+
case ONEMKL_MATMAT_REQUEST_COMPUTE_STRUCTURE:
198+
return oneapi::mkl::sparse::matmat_request::compute_structure;
199+
case ONEMKL_MATMAT_REQUEST_FINALIZE_STRUCTURE:
200+
return oneapi::mkl::sparse::matmat_request::finalize_structure;
201+
case ONEMKL_MATMAT_REQUEST_GET_COMPUTE_BUF_SIZE:
202+
return oneapi::mkl::sparse::matmat_request::get_compute_buf_size;
203+
case ONEMKL_MATMAT_REQUEST_COMPUTE:
204+
return oneapi::mkl::sparse::matmat_request::compute;
205+
case ONEMKL_MATMAT_REQUEST_GET_NNZ:
206+
return oneapi::mkl::sparse::matmat_request::get_nnz;
207+
case ONEMKL_MATMAT_REQUEST_FINALIZE:
208+
return oneapi::mkl::sparse::matmat_request::finalize;
209+
}
210+
}
211+
182212
// gemm
183213
// https://spec.oneapi.io/versions/1.0-rev-1/elements/oneMKL/source/domains/blas/gemm.html
184214
class gemmBatchInfo {

deps/onemkl_prologue.h

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,27 @@ typedef enum {
104104
ONEMKL_PROPERTY_SORTED,
105105
} onemklProperty;
106106

107-
// I need help :(
108-
typedef struct MatrixHandle_st *MatrixHandle_t;
107+
typedef enum {
108+
ONEMKL_MATRIX_VIEW_GENERAL,
109+
} onemklMatrixView;
110+
111+
typedef enum {
112+
ONEMKL_MATMAT_REQUEST_GET_WORK_ESTIMATION_BUF_SIZE,
113+
ONEMKL_MATMAT_REQUEST_WORK_ESTIMATION,
114+
ONEMKL_MATMAT_REQUEST_GET_COMPUTE_STRUCTURE_BUF_SIZE,
115+
ONEMKL_MATMAT_REQUEST_COMPUTE_STRUCTURE,
116+
ONEMKL_MATMAT_REQUEST_FINALIZE_STRUCTURE,
117+
ONEMKL_MATMAT_REQUEST_GET_COMPUTE_BUF_SIZE,
118+
ONEMKL_MATMAT_REQUEST_COMPUTE,
119+
ONEMKL_MATMAT_REQUEST_GET_NNZ,
120+
ONEMKL_MATMAT_REQUEST_FINALIZE,
121+
} onemklMatmatRequest;
122+
123+
struct matrix_handle;
124+
typedef struct matrix_handle *matrix_handle_t;
125+
126+
struct matmat_descr;
127+
typedef struct matmat_descr *matmat_descr_t;
109128

110129
int onemklHgemmBatched(syclQueue_t device_queue, onemklTranspose transa,
111130
onemklTranspose transb, int64_t *m,

0 commit comments

Comments
 (0)