Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions deps/generate_interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,7 @@ function generate_cpp(library::String, filename::Vector{String}, output::String;
else
if !(name ∈ void_output)
write(oneapi_cpp, " auto status = oneapi::mkl::$library::$variant$name($parameters, {});\n")
write(oneapi_cpp, " device_queue->val.wait_and_throw();\n")
else
write(oneapi_cpp, " oneapi::mkl::$library::$variant$name($parameters);\n")
end
Expand Down
1 change: 1 addition & 0 deletions deps/onemkl_epilogue.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
extern "C" int onemklXsparse_matmat(syclQueue_t device_queue, matrix_handle_t A, matrix_handle_t B, matrix_handle_t C, onemklMatmatRequest req, matmat_descr_t descr, int64_t *sizeTempBuffer, void *tempBuffer) {
auto status = oneapi::mkl::sparse::matmat(device_queue->val, (oneapi::mkl::sparse::matrix_handle_t) A, (oneapi::mkl::sparse::matrix_handle_t) B, (oneapi::mkl::sparse::matrix_handle_t) C, convert(req), (oneapi::mkl::sparse::matmat_descr_t) descr, sizeTempBuffer, tempBuffer, {});
device_queue->val.wait_and_throw();
return 0;
}

Expand Down
18 changes: 18 additions & 0 deletions deps/onemkl_prologue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -393,13 +393,15 @@ extern "C" int onemklHgemm_batch(syclQueue_t device_queue, onemklTranspose trans
int64_t *ldb, uint16_t *beta, short **c,
int64_t *ldc, int64_t group_count, int64_t *group_size) {
gemmBatchInfo gemmInfo(device_queue, group_count, transa, transb);
device_queue->val.wait_and_throw();
auto status = oneapi::mkl::blas::column_major::gemm_batch(device_queue->val,
&gemmInfo.m_transa[0], &gemmInfo.m_transb[0],
m, n, k, reinterpret_cast<sycl::half *>(alpha),
reinterpret_cast<const sycl::half **>(&a[0]), lda,
reinterpret_cast<const sycl::half **>(&b[0]), ldb,
reinterpret_cast<sycl::half *>(beta), reinterpret_cast<sycl::half **>(&c[0]),
ldc, group_count, group_size, {});
device_queue->val.wait_and_throw();
return 0;
}

Expand All @@ -410,13 +412,15 @@ extern "C" int onemklSgemm_batch(syclQueue_t device_queue, onemklTranspose trans
int64_t *ldb, float *beta, float **c,
int64_t *ldc, int64_t group_count, int64_t *group_size) {
gemmBatchInfo gemmInfo(device_queue, group_count, transa, transb);
device_queue->val.wait_and_throw();
auto status = oneapi::mkl::blas::column_major::gemm_batch(device_queue->val,
&gemmInfo.m_transa[0], &gemmInfo.m_transb[0],
m, n, k, alpha,
(const float **)&a[0], lda,
(const float **)&b[0], ldb,
beta, &c[0], ldc,
group_count, group_size, {});
device_queue->val.wait_and_throw();
return 0;
}

Expand All @@ -427,13 +431,15 @@ extern "C" int onemklDgemm_batch(syclQueue_t device_queue, onemklTranspose trans
int64_t *ldb, double *beta, double **c,
int64_t *ldc, int64_t group_count, int64_t *group_size) {
gemmBatchInfo gemmInfo(device_queue, group_count, transa, transb);
device_queue->val.wait_and_throw();
auto status = oneapi::mkl::blas::column_major::gemm_batch(device_queue->val,
&gemmInfo.m_transa[0], &gemmInfo.m_transb[0],
m, n, k, alpha,
(const double **)&a[0], lda,
(const double **)&b[0], ldb,
beta, &c[0], ldc,
group_count, group_size, {});
device_queue->val.wait_and_throw();
return 0;
}

Expand All @@ -445,6 +451,7 @@ extern "C" int onemklCgemm_batch(syclQueue_t device_queue, onemklTranspose trans
int64_t *ldb, float _Complex *beta, float _Complex **c,
int64_t *ldc, int64_t group_count, int64_t *group_size) {
gemmBatchInfo gemmInfo(device_queue, group_count, transa, transb);
device_queue->val.wait_and_throw();
auto status = oneapi::mkl::blas::column_major::gemm_batch(device_queue->val,
&gemmInfo.m_transa[0], &gemmInfo.m_transb[0],
m, n, k, reinterpret_cast<std::complex<float> *>(alpha),
Expand All @@ -455,6 +462,7 @@ extern "C" int onemklCgemm_batch(syclQueue_t device_queue, onemklTranspose trans
reinterpret_cast<std::complex<float> *>(beta),
reinterpret_cast<std::complex<float> **>(&c[0]), ldc,
group_count, group_size, {});
device_queue->val.wait_and_throw();
return 0;
}

Expand All @@ -467,6 +475,7 @@ extern "C" int onemklZgemm_batch(syclQueue_t device_queue, onemklTranspose trans
double _Complex **c,
int64_t *ldc, int64_t group_count, int64_t *group_size) {
gemmBatchInfo gemmInfo(device_queue, group_count, transa, transb);
device_queue->val.wait_and_throw();
auto status = oneapi::mkl::blas::column_major::gemm_batch(device_queue->val,
&gemmInfo.m_transa[0], &gemmInfo.m_transb[0],
m, n, k, reinterpret_cast<std::complex<double> *>(alpha),
Expand All @@ -477,6 +486,7 @@ extern "C" int onemklZgemm_batch(syclQueue_t device_queue, onemklTranspose trans
reinterpret_cast<std::complex<double> *>(beta),
reinterpret_cast<std::complex<double> **>(&c[0]), ldc,
group_count, group_size, {});
device_queue->val.wait_and_throw();
return 0;
}

Expand All @@ -487,12 +497,14 @@ extern "C" int onemklStrsm_batch(syclQueue_t device_queue, onemklSide left_right
int64_t group_count, int64_t *group_size) {
trsmBatchInfo trsmInfo(device_queue, left_right, upper_lower, transa,
unit_diag, group_count);
device_queue->val.wait_and_throw();

auto status = oneapi::mkl::blas::column_major::trsm_batch(device_queue->val,
&trsmInfo.m_leftright[0], &trsmInfo.m_upperlower[0],
&trsmInfo.m_transa[0], &trsmInfo.m_unitdiag[0],
m, n, alpha, (const float **)&a[0], lda,
&b[0], ldb, group_count, group_size, {});
device_queue->val.wait_and_throw();
return 0;
}

Expand All @@ -504,12 +516,14 @@ extern "C" int onemklDtrsm_batch(syclQueue_t device_queue, onemklSide left_right
int64_t *group_size) {
trsmBatchInfo trsmInfo(device_queue, left_right, upper_lower, transa,
unit_diag, group_count);
device_queue->val.wait_and_throw();

auto status = oneapi::mkl::blas::column_major::trsm_batch(device_queue->val,
&trsmInfo.m_leftright[0], &trsmInfo.m_upperlower[0],
&trsmInfo.m_transa[0], &trsmInfo.m_unitdiag[0],
m, n, alpha, (const double **)&a[0], lda, &b[0],
ldb, group_count, group_size, {});
device_queue->val.wait_and_throw();
return 0;
}

Expand All @@ -521,6 +535,7 @@ extern "C" int onemklCtrsm_batch(syclQueue_t device_queue, onemklSide left_right
int64_t group_count, int64_t *group_size) {
trsmBatchInfo trsmInfo(device_queue, left_right, upper_lower, transa,
unit_diag, group_count);
device_queue->val.wait_and_throw();

auto status = oneapi::mkl::blas::column_major::trsm_batch(device_queue->val,
&trsmInfo.m_leftright[0], &trsmInfo.m_upperlower[0],
Expand All @@ -529,6 +544,7 @@ extern "C" int onemklCtrsm_batch(syclQueue_t device_queue, onemklSide left_right
reinterpret_cast<const std::complex<float> **>(&a[0]),
lda, reinterpret_cast<std::complex<float> **>(&b[0]),
ldb, group_count, group_size, {});
device_queue->val.wait_and_throw();
return 0;
}

Expand All @@ -540,6 +556,7 @@ extern "C" int onemklZtrsm_batch(syclQueue_t device_queue, onemklSide left_right
int64_t group_count, int64_t *group_size) {
trsmBatchInfo trsmInfo(device_queue, left_right,
upper_lower, transa, unit_diag, group_count);
device_queue->val.wait_and_throw();

auto status = oneapi::mkl::blas::column_major::trsm_batch(device_queue->val,
&trsmInfo.m_leftright[0], &trsmInfo.m_upperlower[0],
Expand All @@ -548,5 +565,6 @@ extern "C" int onemklZtrsm_batch(syclQueue_t device_queue, onemklSide left_right
reinterpret_cast<const std::complex<double> **>(&a[0]),
lda, reinterpret_cast<std::complex<double> **>(&b[0]),
ldb, group_count, group_size, {});
device_queue->val.wait_and_throw();
return 0;
}
Loading