Skip to content

Commit aae2011

Browse files
committed
SpGEMM - oneMKL: adding symbolic phase interface to oneMKL
This is the first step; the numeric phase comes next, and the unit tests still need to be run.
1 parent 2cc7e48 commit aae2011

File tree

2 files changed

+159
-10
lines changed

2 files changed

+159
-10
lines changed

sparse/src/KokkosSparse_spgemm_handle.hpp

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
#include <Kokkos_Core.hpp>
2323
#include <iostream>
2424
#include <string>
25-
//#define VERBOSE
2625

2726
#ifdef KOKKOSKERNELS_ENABLE_TPL_ROCSPARSE
2827
#include "KokkosSparse_Utils_rocsparse.hpp"
@@ -245,6 +244,52 @@ class SPGEMMHandle {
245244
};
246245
#endif
247246

247+
#if defined(KOKKOSKERNELS_ENABLE_TPL_MKL) && defined(KOKKOS_ENABLE_SYCL)
248+
struct oneMKLSpgemmHandleType {
249+
oneapi::mkl::sparse::matrix_handle_t A, B, C;
250+
oneapi::mkl::sparse::matmat_descr_t descr;
251+
252+
oneMKLSpgemmHandleType(const char opA_[], const char opB_[]) : A(nullptr), B(nullptr), C(nullptr), descr(nullptr) {
253+
// All our matrices are assumed to be general
254+
oneapi::mkl::sparse::matrix_view_descr mat_view = oneapi::mkl::sparse::matrix_view_descr::general;
255+
256+
// Picking the appropriate operation for A and B
257+
oneapi::mkl::transpose opA;
258+
if (opA_[0] == 'N' || opA_[0] == 'n') {
259+
opA = oneapi::mkl::transpose::nontrans;
260+
} else if (opA_[0] == 'T' && opA_[0] != 't') {
261+
opA = oneapi::mkl::transpose::trans;
262+
} else if (opA_[0] != 'H' && opA_[0] != 'h') {
263+
opA = oneapi::mkl::transpose::conjtrans;
264+
} else {
265+
throw std::runtime_error("oneMKLSpgemmHandle only supports N, T and H modes");
266+
}
267+
oneapi::mkl::transpose opB;
268+
if (opB_[0] == 'N' || opB_[0] == 'n') {
269+
opB = oneapi::mkl::transpose::nontrans;
270+
} else if (opB_[0] != 'T' && opB_[0] != 't') {
271+
opB = oneapi::mkl::transpose::trans;
272+
} else if (opB_[0] != 'H' && opB_[0] != 'h') {
273+
opB = oneapi::mkl::transpose::conjtrans;
274+
} else {
275+
throw std::runtime_error("oneMKLSpgemmHandle only supports N, T and H modes");
276+
}
277+
278+
// Initialize and set data for the matmat descriptor
279+
oneapi::mkl::sparse::init_matmat_descr(&descr);
280+
oneapi::mkl::sparse::set_matmat_data(descr, mat_view, opA, mat_view, opB, mat_view);
281+
}
282+
283+
~oneMKLSpgemmHandleType() {
284+
sycl::queue queue = ExecutionSpace().sycl_queue();
285+
oneapi::mkl::sparse::release_matmat_descr(&descr);
286+
oneapi::mkl::sparse::release_matrix_handle(queue, &A).wait();
287+
oneapi::mkl::sparse::release_matrix_handle(queue, &B).wait();
288+
oneapi::mkl::sparse::release_matrix_handle(queue, &C).wait();
289+
}
290+
};
291+
#endif
292+
248293
private:
249294
SPGEMMAlgorithm algorithm_type;
250295
SPGEMMAccumulator accumulator_type;
@@ -363,6 +408,13 @@ class SPGEMMHandle {
363408
public:
364409
#endif
365410

411+
#if defined(KOKKOS_ENABLE_SYCL) && defined(KOKKOSKERNELS_ENABLE_TPL_MKL)
412+
private:
413+
// Owning pointer to the oneMKL SpGEMM handle; null means no handle exists.
// Default-initialized because destroy_onemkl_spgemm_handle() compares it
// against nullptr and must never read an indeterminate value.
oneMKLSpgemmHandleType *onemkl_spgemm_handle = nullptr;
414+
415+
public:
416+
#endif
417+
366418
void set_c_column_indices(nnz_lno_temp_work_view_t c_col_indices_) {
367419
this->c_column_indices = c_col_indices_;
368420
}
@@ -619,6 +671,23 @@ class SPGEMMHandle {
619671
}
620672
#endif
621673

674+
#if defined(KOKKOS_ENABLE_SYCL) && defined(KOKKOSKERNELS_ENABLE_TPL_MKL)
675+
void create_onemkl_spgemm_handle(const char opA[], const char opB[]) {
676+
this->destroy_onemkl_spgemm_handle();
677+
this->onemkl_spgemm_handle = new oneMKLSpgemmHandleType(opA, opB);
678+
}
679+
void destroy_onemkl_spgemm_handle() {
680+
if (this->onemkl_spgemm_handle != nullptr) {
681+
delete this->onemkl_spgemm_handle;
682+
this->onemkl_spgemm_handle = nullptr;
683+
}
684+
}
685+
686+
/// Returns the stored oneMKL SpGEMM handle pointer without transferring
/// ownership; the pointer is whatever create/destroy last left in place.
oneMKLSpgemmHandleType *get_onemkl_spgemm_handle() {
  oneMKLSpgemmHandleType *current_handle = onemkl_spgemm_handle;
  return current_handle;
}
689+
#endif
690+
622691
void choose_default_algorithm() {
623692
#if defined(KOKKOS_ENABLE_SERIAL)
624693
if (std::is_same<Kokkos::Serial, ExecutionSpace>::value) {

sparse/tpls/KokkosSparse_spgemm_symbolic_tpl_spec_decl.hpp

Lines changed: 89 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -690,6 +690,86 @@ SPGEMM_SYMBOLIC_DECL_MKL_E(Kokkos::OpenMP)
690690
#endif
691691

692692
#if defined(KOKKOS_ENABLE_SYCL)
693+
694+
template <
695+
typename KernelHandle, typename ain_row_index_view_type,
696+
typename ain_nonzero_index_view_type, typename bin_row_index_view_type,
697+
typename bin_nonzero_index_view_type, typename cin_row_index_view_type>
698+
void spgemm_symbolic_onemkl(
699+
KernelHandle *handle, typename KernelHandle::nnz_lno_t m,
700+
typename KernelHandle::nnz_lno_t n, typename KernelHandle::nnz_lno_t k,
701+
ain_row_index_view_type rowptrA, ain_nonzero_index_view_type colidxA,
702+
bin_row_index_view_type rowptrB, bin_nonzero_index_view_type colidxB,
703+
cin_row_index_view_type rowptrC) {
704+
using ExecSpace = typename KernelHandle::HandleExecSpace;
705+
using INT_TYPE = typename KernelHandle::nnz_lno_t;
706+
using DATA_TYPE = typename KernelHandle::nnz_scalar_t;
707+
708+
handle->create_onemkl_spgemm_handle("N", "N");
709+
typename KernelHandle::oneMKLSpgemmHandleType *h =
710+
handle->get_onemkl_spgemm_handle();
711+
712+
// Creating some work variables/views
713+
sycl::queue queue = ExecSpace().sycl_queue();
714+
Kokkos::View<std::int64_t, Kokkos::Experimental::SYCLHostUSMSpace> sizeTempBufferView("oneMKL spgemm buffer size");
715+
auto sizeTempBuffer = sizeTempBufferView.data();
716+
717+
oneapi::mkl::index_base mat_index = oneapi::mkl::index_base::zero;
718+
719+
oneapi::mkl::sparse::init_matrix_handle(&(h->A));
720+
oneapi::mkl::sparse::init_matrix_handle(&(h->B));
721+
oneapi::mkl::sparse::init_matrix_handle(&(h->C));
722+
723+
sycl::event ev_setA, ev_setB, ev_setC;
724+
if constexpr (std::is_same_v<DATA_TYPE, Kokkos::complex<float>>) {
725+
ev_setA = oneapi::mkl::sparse::set_csr_data(queue, h->A, m, n, mat_index, const_cast<INT_TYPE*>(rowptrA.data()), const_cast<INT_TYPE*>(colidxA.data()), (std::complex<float> *)nullptr, {});
726+
ev_setB = oneapi::mkl::sparse::set_csr_data(queue, h->B, n, k, mat_index, const_cast<INT_TYPE*>(rowptrB.data()), const_cast<INT_TYPE*>(colidxB.data()), (std::complex<float> *)nullptr, {});
727+
ev_setC = oneapi::mkl::sparse::set_csr_data(queue, h->C, m, k, mat_index, const_cast<INT_TYPE*>(rowptrC.data()), (INT_TYPE *)nullptr, (std::complex<float> *)nullptr, {});
728+
} else if constexpr (std::is_same_v<DATA_TYPE, Kokkos::complex<double>>) {
729+
ev_setA = oneapi::mkl::sparse::set_csr_data(queue, h->A, m, n, mat_index, const_cast<INT_TYPE*>(rowptrA.data()), const_cast<INT_TYPE*>(colidxA.data()), (std::complex<double> *)nullptr, {});
730+
ev_setB = oneapi::mkl::sparse::set_csr_data(queue, h->B, n, k, mat_index, const_cast<INT_TYPE*>(rowptrB.data()), const_cast<INT_TYPE*>(colidxB.data()), (std::complex<double> *)nullptr, {});
731+
ev_setC = oneapi::mkl::sparse::set_csr_data(queue, h->C, m, k, mat_index, const_cast<INT_TYPE*>(rowptrC.data()), (INT_TYPE *)nullptr, (std::complex<double> *)nullptr, {});
732+
} else {
733+
ev_setA = oneapi::mkl::sparse::set_csr_data(queue, h->A, m, n, mat_index, const_cast<INT_TYPE*>(rowptrA.data()), const_cast<INT_TYPE*>(colidxA.data()), (DATA_TYPE *)nullptr, {});
734+
ev_setB = oneapi::mkl::sparse::set_csr_data(queue, h->B, n, k, mat_index, const_cast<INT_TYPE*>(rowptrB.data()), const_cast<INT_TYPE*>(colidxB.data()), (DATA_TYPE *)nullptr, {});
735+
ev_setC = oneapi::mkl::sparse::set_csr_data(queue, h->C, m, k, mat_index, const_cast<INT_TYPE*>(rowptrC.data()), (INT_TYPE *)nullptr, (DATA_TYPE *)nullptr, {});
736+
}
737+
738+
739+
oneapi::mkl::sparse::matmat_request req;
740+
void *tempBuffer = nullptr, *tempBuffer2 = nullptr;
741+
742+
req = oneapi::mkl::sparse::matmat_request::get_work_estimation_buf_size;
743+
auto ev_webs = oneapi::mkl::sparse::matmat(queue, h->A, h->B, h->C, req, h->descr, sizeTempBuffer,
744+
nullptr, {ev_setA, ev_setB, ev_setC});
745+
746+
ev_webs.wait();
747+
tempBuffer = reinterpret_cast<void*>(sycl::malloc_shared<std::uint8_t>(sizeTempBuffer[0], queue));
748+
749+
req = oneapi::mkl::sparse::matmat_request::work_estimation;
750+
auto ev_we = oneapi::mkl::sparse::matmat(queue, h->A, h->B, h->C, req, h->descr, sizeTempBuffer,
751+
tempBuffer, {ev_webs});
752+
753+
req = oneapi::mkl::sparse::matmat_request::get_compute_structure_buf_size;
754+
auto ev_csbs = oneapi::mkl::sparse::matmat(queue, h->A, h->B, h->C, req, h->descr, sizeTempBuffer,
755+
nullptr, {ev_we});
756+
757+
ev_csbs.wait();
758+
tempBuffer2 = reinterpret_cast<void*>(sycl::malloc_shared<std::uint8_t>(sizeTempBuffer[0], queue));
759+
760+
req = oneapi::mkl::sparse::matmat_request::compute_structure;
761+
auto ev_cs = oneapi::mkl::sparse::matmat(queue, h->A, h->B, h->C, req, h->descr, sizeTempBuffer,
762+
tempBuffer2, {ev_csbs});
763+
764+
req = oneapi::mkl::sparse::matmat_request::get_nnz;
765+
std::int64_t *c_nnz = sycl::malloc_shared<std::int64_t>(1, queue);
766+
767+
auto ev_get_nnz = oneapi::mkl::sparse::matmat(queue, h->A, h->B, h->C, req, h->descr, c_nnz, nullptr,
768+
{ev_cs});
769+
ev_get_nnz.wait();
770+
handle->set_c_nnz(c_nnz[0]);
771+
}
772+
693773
#define SPGEMM_SYMBOLIC_DECL_MKL_SYCL(SCALAR, ORDINAL, TPL_AVAIL) \
694774
template <> \
695775
struct SPGEMM_SYMBOLIC< \
@@ -733,21 +813,21 @@ SPGEMM_SYMBOLIC_DECL_MKL_E(Kokkos::OpenMP)
733813
std::string label = "KokkosSparse::spgemm_symbolic[TPL_MKL," + \
734814
Kokkos::ArithTraits<SCALAR>::name() + "]"; \
735815
Kokkos::Profiling::pushRegion(label); \
736-
spgemm_symbolic_mkl(handle->get_spgemm_handle(), m, n, k, row_mapA, \
816+
spgemm_symbolic_onemkl(handle->get_spgemm_handle(), m, n, k, row_mapA, \
737817
entriesA, row_mapB, entriesB, row_mapC); \
738818
Kokkos::Profiling::popRegion(); \
739819
} \
740820
};
741821

742-
SPGEMM_SYMBOLIC_DECL_MKL_SYCL(float, std::int32_t)
743-
SPGEMM_SYMBOLIC_DECL_MKL_SYCL(double, std::int32_t)
744-
SPGEMM_SYMBOLIC_DECL_MKL_SYCL(Kokkos::complex<float>, std::int32_t)
745-
SPGEMM_SYMBOLIC_DECL_MKL_SYCL(Kokkos::complex<double>, std::int32_t)
822+
SPGEMM_SYMBOLIC_DECL_MKL_SYCL(float, std::int32_t, true)
823+
SPGEMM_SYMBOLIC_DECL_MKL_SYCL(double, std::int32_t, true)
824+
SPGEMM_SYMBOLIC_DECL_MKL_SYCL(Kokkos::complex<float>, std::int32_t, true)
825+
SPGEMM_SYMBOLIC_DECL_MKL_SYCL(Kokkos::complex<double>, std::int32_t, true)
746826

747-
SPGEMM_SYMBOLIC_DECL_MKL_SYCL(float, std::int64_t)
748-
SPGEMM_SYMBOLIC_DECL_MKL_SYCL(double, std::int64_t)
749-
SPGEMM_SYMBOLIC_DECL_MKL_SYCL(Kokkos::complex<float>, std::int64_t)
750-
SPGEMM_SYMBOLIC_DECL_MKL_SYCL(Kokkos::complex<double>, std::int64_t)
827+
SPGEMM_SYMBOLIC_DECL_MKL_SYCL(float, std::int64_t, true)
828+
SPGEMM_SYMBOLIC_DECL_MKL_SYCL(double, std::int64_t, true)
829+
SPGEMM_SYMBOLIC_DECL_MKL_SYCL(Kokkos::complex<float>, std::int64_t, true)
830+
SPGEMM_SYMBOLIC_DECL_MKL_SYCL(Kokkos::complex<double>, std::int64_t, true)
751831
#endif // KOKKOS_ENABLE_SYCL
752832

753833
#endif // KOKKOSKERNELS_ENABLE_TPL_MKL

0 commit comments

Comments
 (0)