Skip to content

Commit 0bca405

Browse files
committed
We must use the buffer version of set_csr_data
1 parent 6c6018a commit 0bca405

File tree

2 files changed

+10
-18
lines changed

2 files changed

+10
-18
lines changed

deps/generate_interfaces.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,7 @@ function generate_cpp(library::String, filename::String, output::String)
366366
!occursin("scratchpad_size", name) && write(oneapi_cpp, " auto status = oneapi::mkl::$library::$variant$name<$type>($parameters, {});\n")
367367
occursin("scratchpad_size", name) && write(oneapi_cpp, " int64_t scratchpad_size = oneapi::mkl::$library::$variant$name<$type>($parameters);\n")
368368
else
369-
if name ("init_matrix_handle", "init_matmat_descr", "release_matmat_descr", "set_matmat_data")
369+
if name ("init_matrix_handle", "init_matmat_descr", "release_matmat_descr", "set_matmat_data", "set_csr_data")
370370
write(oneapi_cpp, " auto status = oneapi::mkl::$library::$variant$name($parameters, {});\n")
371371
else
372372
write(oneapi_cpp, " oneapi::mkl::$library::$variant$name($parameters);\n")
@@ -375,7 +375,7 @@ function generate_cpp(library::String, filename::String, output::String)
375375
if occursin("scratchpad_size", name)
376376
write(oneapi_cpp, " return scratchpad_size;\n")
377377
else
378-
(name ("init_matrix_handle", "init_matmat_descr", "release_matmat_descr", "set_matmat_data")) && write(oneapi_cpp, " __FORCE_MKL_FLUSH__(status);\n")
378+
(name ("init_matrix_handle", "init_matmat_descr", "release_matmat_descr", "set_matmat_data", "set_csr_data")) && write(oneapi_cpp, " __FORCE_MKL_FLUSH__(status);\n")
379379
write(oneapi_cpp, " return 0;\n")
380380
end
381381
write(oneapi_cpp, "}")

deps/src/onemkl.cpp

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3467,50 +3467,42 @@ extern "C" int onemklXsparse_init_matrix_handle(matrix_handle_t *handle) {
34673467
}
34683468

34693469
extern "C" int onemklSsparse_set_csr_data(syclQueue_t device_queue, matrix_handle_t handle, int32_t num_rows, int32_t num_cols, onemklIndex index, int32_t *row_ptr, int32_t *col_ind, float *val) {
3470-
auto status = oneapi::mkl::sparse::set_csr_data(device_queue->val, (oneapi::mkl::sparse::matrix_handle_t) handle, num_rows, num_cols, convert(index), row_ptr, col_ind, val, {});
3471-
__FORCE_MKL_FLUSH__(status);
3470+
oneapi::mkl::sparse::set_csr_data(device_queue->val, (oneapi::mkl::sparse::matrix_handle_t) handle, num_rows, num_cols, convert(index), row_ptr, col_ind, val);
34723471
return 0;
34733472
}
34743473

34753474
extern "C" int onemklSsparse_set_csr_data_64(syclQueue_t device_queue, matrix_handle_t handle, int64_t num_rows, int64_t num_cols, onemklIndex index, int64_t *row_ptr, int64_t *col_ind, float *val) {
3476-
auto status = oneapi::mkl::sparse::set_csr_data(device_queue->val, (oneapi::mkl::sparse::matrix_handle_t) handle, num_rows, num_cols, convert(index), row_ptr, col_ind, val, {});
3477-
__FORCE_MKL_FLUSH__(status);
3475+
oneapi::mkl::sparse::set_csr_data(device_queue->val, (oneapi::mkl::sparse::matrix_handle_t) handle, num_rows, num_cols, convert(index), row_ptr, col_ind, val);
34783476
return 0;
34793477
}
34803478

34813479
extern "C" int onemklDsparse_set_csr_data(syclQueue_t device_queue, matrix_handle_t handle, int32_t num_rows, int32_t num_cols, onemklIndex index, int32_t *row_ptr, int32_t *col_ind, double *val) {
3482-
auto status = oneapi::mkl::sparse::set_csr_data(device_queue->val, (oneapi::mkl::sparse::matrix_handle_t) handle, num_rows, num_cols, convert(index), row_ptr, col_ind, val, {});
3483-
__FORCE_MKL_FLUSH__(status);
3480+
oneapi::mkl::sparse::set_csr_data(device_queue->val, (oneapi::mkl::sparse::matrix_handle_t) handle, num_rows, num_cols, convert(index), row_ptr, col_ind, val);
34843481
return 0;
34853482
}
34863483

34873484
extern "C" int onemklDsparse_set_csr_data_64(syclQueue_t device_queue, matrix_handle_t handle, int64_t num_rows, int64_t num_cols, onemklIndex index, int64_t *row_ptr, int64_t *col_ind, double *val) {
3488-
auto status = oneapi::mkl::sparse::set_csr_data(device_queue->val, (oneapi::mkl::sparse::matrix_handle_t) handle, num_rows, num_cols, convert(index), row_ptr, col_ind, val, {});
3489-
__FORCE_MKL_FLUSH__(status);
3485+
oneapi::mkl::sparse::set_csr_data(device_queue->val, (oneapi::mkl::sparse::matrix_handle_t) handle, num_rows, num_cols, convert(index), row_ptr, col_ind, val);
34903486
return 0;
34913487
}
34923488

34933489
extern "C" int onemklCsparse_set_csr_data(syclQueue_t device_queue, matrix_handle_t handle, int32_t num_rows, int32_t num_cols, onemklIndex index, int32_t *row_ptr, int32_t *col_ind, float _Complex *val) {
3494-
auto status = oneapi::mkl::sparse::set_csr_data(device_queue->val, (oneapi::mkl::sparse::matrix_handle_t) handle, num_rows, num_cols, convert(index), row_ptr, col_ind, reinterpret_cast<std::complex<float>*>(val), {});
3495-
__FORCE_MKL_FLUSH__(status);
3490+
oneapi::mkl::sparse::set_csr_data(device_queue->val, (oneapi::mkl::sparse::matrix_handle_t) handle, num_rows, num_cols, convert(index), row_ptr, col_ind, reinterpret_cast<std::complex<float>*>(val));
34963491
return 0;
34973492
}
34983493

34993494
extern "C" int onemklCsparse_set_csr_data_64(syclQueue_t device_queue, matrix_handle_t handle, int64_t num_rows, int64_t num_cols, onemklIndex index, int64_t *row_ptr, int64_t *col_ind, float _Complex *val) {
3500-
auto status = oneapi::mkl::sparse::set_csr_data(device_queue->val, (oneapi::mkl::sparse::matrix_handle_t) handle, num_rows, num_cols, convert(index), row_ptr, col_ind, reinterpret_cast<std::complex<float>*>(val), {});
3501-
__FORCE_MKL_FLUSH__(status);
3495+
oneapi::mkl::sparse::set_csr_data(device_queue->val, (oneapi::mkl::sparse::matrix_handle_t) handle, num_rows, num_cols, convert(index), row_ptr, col_ind, reinterpret_cast<std::complex<float>*>(val));
35023496
return 0;
35033497
}
35043498

35053499
extern "C" int onemklZsparse_set_csr_data(syclQueue_t device_queue, matrix_handle_t handle, int32_t num_rows, int32_t num_cols, onemklIndex index, int32_t *row_ptr, int32_t *col_ind, double _Complex *val) {
3506-
auto status = oneapi::mkl::sparse::set_csr_data(device_queue->val, (oneapi::mkl::sparse::matrix_handle_t) handle, num_rows, num_cols, convert(index), row_ptr, col_ind, reinterpret_cast<std::complex<double>*>(val), {});
3507-
__FORCE_MKL_FLUSH__(status);
3500+
oneapi::mkl::sparse::set_csr_data(device_queue->val, (oneapi::mkl::sparse::matrix_handle_t) handle, num_rows, num_cols, convert(index), row_ptr, col_ind, reinterpret_cast<std::complex<double>*>(val));
35083501
return 0;
35093502
}
35103503

35113504
extern "C" int onemklZsparse_set_csr_data_64(syclQueue_t device_queue, matrix_handle_t handle, int64_t num_rows, int64_t num_cols, onemklIndex index, int64_t *row_ptr, int64_t *col_ind, double _Complex *val) {
3512-
auto status = oneapi::mkl::sparse::set_csr_data(device_queue->val, (oneapi::mkl::sparse::matrix_handle_t) handle, num_rows, num_cols, convert(index), row_ptr, col_ind, reinterpret_cast<std::complex<double>*>(val), {});
3513-
__FORCE_MKL_FLUSH__(status);
3505+
oneapi::mkl::sparse::set_csr_data(device_queue->val, (oneapi::mkl::sparse::matrix_handle_t) handle, num_rows, num_cols, convert(index), row_ptr, col_ind, reinterpret_cast<std::complex<double>*>(val));
35143506
return 0;
35153507
}
35163508

0 commit comments

Comments
 (0)