Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions deps/generate_interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -447,13 +447,17 @@ function generate_cpp(library::String, filename::Vector{String}, output::String;
write(oneapi_cpp, "extern \"C\" $header {\n")
if template
type = version_types[version]
!occursin("scratchpad_size", name) && write(oneapi_cpp, " auto status = oneapi::mkl::$library::$variant$name<$type>($parameters, {});\n")
occursin("scratchpad_size", name) && write(oneapi_cpp, " int64_t scratchpad_size = oneapi::mkl::$library::$variant$name<$type>($parameters);\n")
!occursin("scratchpad_size", name) && write(oneapi_cpp, " auto status = oneapi::mkl::$library::$variant$name<$type>($parameters, {});\n device_queue->val.wait_and_throw();\n")
occursin("scratchpad_size", name) && write(oneapi_cpp, " int64_t scratchpad_size = oneapi::mkl::$library::$variant$name<$type>($parameters);\n device_queue->val.wait_and_throw();\n")
# !occursin("scratchpad_size", name) && write(oneapi_cpp, " auto status = oneapi::mkl::$library::$variant$name<$type>($parameters, {});\n")
# occursin("scratchpad_size", name) && write(oneapi_cpp, " int64_t scratchpad_size = oneapi::mkl::$library::$variant$name<$type>($parameters);\n")
else
if !(name ∈ void_output)
write(oneapi_cpp, " auto status = oneapi::mkl::$library::$variant$name($parameters, {});\n")
occursin("device_queue", parameters) && write(oneapi_cpp, " device_queue->val.wait_and_throw();\n")
else
write(oneapi_cpp, " oneapi::mkl::$library::$variant$name($parameters);\n")
occursin("device_queue", parameters) && write(oneapi_cpp, " device_queue->val.wait_and_throw();\n")
end
end
if occursin("scratchpad_size", name)
Expand Down
1 change: 1 addition & 0 deletions deps/onemkl_epilogue.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
extern "C" int onemklXsparse_matmat(syclQueue_t device_queue, matrix_handle_t A, matrix_handle_t B, matrix_handle_t C, onemklMatmatRequest req, matmat_descr_t descr, int64_t *sizeTempBuffer, void *tempBuffer) {
auto status = oneapi::mkl::sparse::matmat(device_queue->val, (oneapi::mkl::sparse::matrix_handle_t) A, (oneapi::mkl::sparse::matrix_handle_t) B, (oneapi::mkl::sparse::matrix_handle_t) C, convert(req), (oneapi::mkl::sparse::matmat_descr_t) descr, sizeTempBuffer, tempBuffer, {});
device_queue->val.wait_and_throw();
return 0;
}

Expand Down
18 changes: 18 additions & 0 deletions deps/onemkl_prologue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -393,13 +393,15 @@ extern "C" int onemklHgemm_batch(syclQueue_t device_queue, onemklTranspose trans
int64_t *ldb, uint16_t *beta, short **c,
int64_t *ldc, int64_t group_count, int64_t *group_size) {
gemmBatchInfo gemmInfo(device_queue, group_count, transa, transb);
device_queue->val.wait_and_throw();
auto status = oneapi::mkl::blas::column_major::gemm_batch(device_queue->val,
&gemmInfo.m_transa[0], &gemmInfo.m_transb[0],
m, n, k, reinterpret_cast<sycl::half *>(alpha),
reinterpret_cast<const sycl::half **>(&a[0]), lda,
reinterpret_cast<const sycl::half **>(&b[0]), ldb,
reinterpret_cast<sycl::half *>(beta), reinterpret_cast<sycl::half **>(&c[0]),
ldc, group_count, group_size, {});
device_queue->val.wait_and_throw();
return 0;
}

Expand All @@ -410,13 +412,15 @@ extern "C" int onemklSgemm_batch(syclQueue_t device_queue, onemklTranspose trans
int64_t *ldb, float *beta, float **c,
int64_t *ldc, int64_t group_count, int64_t *group_size) {
gemmBatchInfo gemmInfo(device_queue, group_count, transa, transb);
device_queue->val.wait_and_throw();
auto status = oneapi::mkl::blas::column_major::gemm_batch(device_queue->val,
&gemmInfo.m_transa[0], &gemmInfo.m_transb[0],
m, n, k, alpha,
(const float **)&a[0], lda,
(const float **)&b[0], ldb,
beta, &c[0], ldc,
group_count, group_size, {});
device_queue->val.wait_and_throw();
return 0;
}

Expand All @@ -427,13 +431,15 @@ extern "C" int onemklDgemm_batch(syclQueue_t device_queue, onemklTranspose trans
int64_t *ldb, double *beta, double **c,
int64_t *ldc, int64_t group_count, int64_t *group_size) {
gemmBatchInfo gemmInfo(device_queue, group_count, transa, transb);
device_queue->val.wait_and_throw();
auto status = oneapi::mkl::blas::column_major::gemm_batch(device_queue->val,
&gemmInfo.m_transa[0], &gemmInfo.m_transb[0],
m, n, k, alpha,
(const double **)&a[0], lda,
(const double **)&b[0], ldb,
beta, &c[0], ldc,
group_count, group_size, {});
device_queue->val.wait_and_throw();
return 0;
}

Expand All @@ -445,6 +451,7 @@ extern "C" int onemklCgemm_batch(syclQueue_t device_queue, onemklTranspose trans
int64_t *ldb, float _Complex *beta, float _Complex **c,
int64_t *ldc, int64_t group_count, int64_t *group_size) {
gemmBatchInfo gemmInfo(device_queue, group_count, transa, transb);
device_queue->val.wait_and_throw();
auto status = oneapi::mkl::blas::column_major::gemm_batch(device_queue->val,
&gemmInfo.m_transa[0], &gemmInfo.m_transb[0],
m, n, k, reinterpret_cast<std::complex<float> *>(alpha),
Expand All @@ -455,6 +462,7 @@ extern "C" int onemklCgemm_batch(syclQueue_t device_queue, onemklTranspose trans
reinterpret_cast<std::complex<float> *>(beta),
reinterpret_cast<std::complex<float> **>(&c[0]), ldc,
group_count, group_size, {});
device_queue->val.wait_and_throw();
return 0;
}

Expand All @@ -467,6 +475,7 @@ extern "C" int onemklZgemm_batch(syclQueue_t device_queue, onemklTranspose trans
double _Complex **c,
int64_t *ldc, int64_t group_count, int64_t *group_size) {
gemmBatchInfo gemmInfo(device_queue, group_count, transa, transb);
device_queue->val.wait_and_throw();
auto status = oneapi::mkl::blas::column_major::gemm_batch(device_queue->val,
&gemmInfo.m_transa[0], &gemmInfo.m_transb[0],
m, n, k, reinterpret_cast<std::complex<double> *>(alpha),
Expand All @@ -477,6 +486,7 @@ extern "C" int onemklZgemm_batch(syclQueue_t device_queue, onemklTranspose trans
reinterpret_cast<std::complex<double> *>(beta),
reinterpret_cast<std::complex<double> **>(&c[0]), ldc,
group_count, group_size, {});
device_queue->val.wait_and_throw();
return 0;
}

Expand All @@ -487,12 +497,14 @@ extern "C" int onemklStrsm_batch(syclQueue_t device_queue, onemklSide left_right
int64_t group_count, int64_t *group_size) {
trsmBatchInfo trsmInfo(device_queue, left_right, upper_lower, transa,
unit_diag, group_count);
device_queue->val.wait_and_throw();

auto status = oneapi::mkl::blas::column_major::trsm_batch(device_queue->val,
&trsmInfo.m_leftright[0], &trsmInfo.m_upperlower[0],
&trsmInfo.m_transa[0], &trsmInfo.m_unitdiag[0],
m, n, alpha, (const float **)&a[0], lda,
&b[0], ldb, group_count, group_size, {});
device_queue->val.wait_and_throw();
return 0;
}

Expand All @@ -504,12 +516,14 @@ extern "C" int onemklDtrsm_batch(syclQueue_t device_queue, onemklSide left_right
int64_t *group_size) {
trsmBatchInfo trsmInfo(device_queue, left_right, upper_lower, transa,
unit_diag, group_count);
device_queue->val.wait_and_throw();

auto status = oneapi::mkl::blas::column_major::trsm_batch(device_queue->val,
&trsmInfo.m_leftright[0], &trsmInfo.m_upperlower[0],
&trsmInfo.m_transa[0], &trsmInfo.m_unitdiag[0],
m, n, alpha, (const double **)&a[0], lda, &b[0],
ldb, group_count, group_size, {});
device_queue->val.wait_and_throw();
return 0;
}

Expand All @@ -521,6 +535,7 @@ extern "C" int onemklCtrsm_batch(syclQueue_t device_queue, onemklSide left_right
int64_t group_count, int64_t *group_size) {
trsmBatchInfo trsmInfo(device_queue, left_right, upper_lower, transa,
unit_diag, group_count);
device_queue->val.wait_and_throw();

auto status = oneapi::mkl::blas::column_major::trsm_batch(device_queue->val,
&trsmInfo.m_leftright[0], &trsmInfo.m_upperlower[0],
Expand All @@ -529,6 +544,7 @@ extern "C" int onemklCtrsm_batch(syclQueue_t device_queue, onemklSide left_right
reinterpret_cast<const std::complex<float> **>(&a[0]),
lda, reinterpret_cast<std::complex<float> **>(&b[0]),
ldb, group_count, group_size, {});
device_queue->val.wait_and_throw();
return 0;
}

Expand All @@ -540,6 +556,7 @@ extern "C" int onemklZtrsm_batch(syclQueue_t device_queue, onemklSide left_right
int64_t group_count, int64_t *group_size) {
trsmBatchInfo trsmInfo(device_queue, left_right,
upper_lower, transa, unit_diag, group_count);
device_queue->val.wait_and_throw();

auto status = oneapi::mkl::blas::column_major::trsm_batch(device_queue->val,
&trsmInfo.m_leftright[0], &trsmInfo.m_upperlower[0],
Expand All @@ -548,5 +565,6 @@ extern "C" int onemklZtrsm_batch(syclQueue_t device_queue, onemklSide left_right
reinterpret_cast<const std::complex<double> **>(&a[0]),
lda, reinterpret_cast<std::complex<double> **>(&b[0]),
ldb, group_count, group_size, {});
device_queue->val.wait_and_throw();
return 0;
}
Loading