@@ -690,6 +690,86 @@ SPGEMM_SYMBOLIC_DECL_MKL_E(Kokkos::OpenMP)
690690#endif
691691
692692#if defined(KOKKOS_ENABLE_SYCL)
693+
694+ template <
695+ typename KernelHandle, typename ain_row_index_view_type,
696+ typename ain_nonzero_index_view_type, typename bin_row_index_view_type,
697+ typename bin_nonzero_index_view_type, typename cin_row_index_view_type>
698+ void spgemm_symbolic_onemkl (
699+ KernelHandle *handle, typename KernelHandle::nnz_lno_t m,
700+ typename KernelHandle::nnz_lno_t n, typename KernelHandle::nnz_lno_t k,
701+ ain_row_index_view_type rowptrA, ain_nonzero_index_view_type colidxA,
702+ bin_row_index_view_type rowptrB, bin_nonzero_index_view_type colidxB,
703+ cin_row_index_view_type rowptrC) {
704+ using ExecSpace = typename KernelHandle::HandleExecSpace;
705+ using INT_TYPE = typename KernelHandle::nnz_lno_t ;
706+ using DATA_TYPE = typename KernelHandle::nnz_scalar_t ;
707+
708+ handle->create_onemkl_spgemm_handle (" N" , " N" );
709+ typename KernelHandle::oneMKLSpgemmHandleType *h =
710+ handle->get_onemkl_spgemm_handle ();
711+
712+ // Creating some work variables/views
713+ sycl::queue queue = ExecSpace ().sycl_queue ();
714+ Kokkos::View<std::int64_t , Kokkos::Experimental::SYCLHostUSMSpace> sizeTempBufferView (" oneMKL spgemm buffer size" );
715+ auto sizeTempBuffer = sizeTempBufferView.data ();
716+
717+ oneapi::mkl::index_base mat_index = oneapi::mkl::index_base::zero;
718+
719+ oneapi::mkl::sparse::init_matrix_handle (&(h->A ));
720+ oneapi::mkl::sparse::init_matrix_handle (&(h->B ));
721+ oneapi::mkl::sparse::init_matrix_handle (&(h->C ));
722+
723+ sycl::event ev_setA, ev_setB, ev_setC;
724+ if constexpr (std::is_same_v<DATA_TYPE, Kokkos::complex <float >>) {
725+ ev_setA = oneapi::mkl::sparse::set_csr_data (queue, h->A , m, n, mat_index, const_cast <INT_TYPE*>(rowptrA.data ()), const_cast <INT_TYPE*>(colidxA.data ()), (std::complex <float > *)nullptr , {});
726+ ev_setB = oneapi::mkl::sparse::set_csr_data (queue, h->B , n, k, mat_index, const_cast <INT_TYPE*>(rowptrB.data ()), const_cast <INT_TYPE*>(colidxB.data ()), (std::complex <float > *)nullptr , {});
727+ ev_setC = oneapi::mkl::sparse::set_csr_data (queue, h->C , m, k, mat_index, const_cast <INT_TYPE*>(rowptrC.data ()), (INT_TYPE *)nullptr , (std::complex <float > *)nullptr , {});
728+ } else if constexpr (std::is_same_v<DATA_TYPE, Kokkos::complex <double >>) {
729+ ev_setA = oneapi::mkl::sparse::set_csr_data (queue, h->A , m, n, mat_index, const_cast <INT_TYPE*>(rowptrA.data ()), const_cast <INT_TYPE*>(colidxA.data ()), (std::complex <double > *)nullptr , {});
730+ ev_setB = oneapi::mkl::sparse::set_csr_data (queue, h->B , n, k, mat_index, const_cast <INT_TYPE*>(rowptrB.data ()), const_cast <INT_TYPE*>(colidxB.data ()), (std::complex <double > *)nullptr , {});
731+ ev_setC = oneapi::mkl::sparse::set_csr_data (queue, h->C , m, k, mat_index, const_cast <INT_TYPE*>(rowptrC.data ()), (INT_TYPE *)nullptr , (std::complex <double > *)nullptr , {});
732+ } else {
733+ ev_setA = oneapi::mkl::sparse::set_csr_data (queue, h->A , m, n, mat_index, const_cast <INT_TYPE*>(rowptrA.data ()), const_cast <INT_TYPE*>(colidxA.data ()), (DATA_TYPE *)nullptr , {});
734+ ev_setB = oneapi::mkl::sparse::set_csr_data (queue, h->B , n, k, mat_index, const_cast <INT_TYPE*>(rowptrB.data ()), const_cast <INT_TYPE*>(colidxB.data ()), (DATA_TYPE *)nullptr , {});
735+ ev_setC = oneapi::mkl::sparse::set_csr_data (queue, h->C , m, k, mat_index, const_cast <INT_TYPE*>(rowptrC.data ()), (INT_TYPE *)nullptr , (DATA_TYPE *)nullptr , {});
736+ }
737+
738+
739+ oneapi::mkl::sparse::matmat_request req;
740+ void *tempBuffer = nullptr , *tempBuffer2 = nullptr ;
741+
742+ req = oneapi::mkl::sparse::matmat_request::get_work_estimation_buf_size;
743+ auto ev_webs = oneapi::mkl::sparse::matmat (queue, h->A , h->B , h->C , req, h->descr , sizeTempBuffer,
744+ nullptr , {ev_setA, ev_setB, ev_setC});
745+
746+ ev_webs.wait ();
747+ tempBuffer = reinterpret_cast <void *>(sycl::malloc_shared<std::uint8_t >(sizeTempBuffer[0 ], queue));
748+
749+ req = oneapi::mkl::sparse::matmat_request::work_estimation;
750+ auto ev_we = oneapi::mkl::sparse::matmat (queue, h->A , h->B , h->C , req, h->descr , sizeTempBuffer,
751+ tempBuffer, {ev_webs});
752+
753+ req = oneapi::mkl::sparse::matmat_request::get_compute_structure_buf_size;
754+ auto ev_csbs = oneapi::mkl::sparse::matmat (queue, h->A , h->B , h->C , req, h->descr , sizeTempBuffer,
755+ nullptr , {ev_we});
756+
757+ ev_csbs.wait ();
758+ tempBuffer2 = reinterpret_cast <void *>(sycl::malloc_shared<std::uint8_t >(sizeTempBuffer[0 ], queue));
759+
760+ req = oneapi::mkl::sparse::matmat_request::compute_structure;
761+ auto ev_cs = oneapi::mkl::sparse::matmat (queue, h->A , h->B , h->C , req, h->descr , sizeTempBuffer,
762+ tempBuffer2, {ev_csbs});
763+
764+ req = oneapi::mkl::sparse::matmat_request::get_nnz;
765+ std::int64_t *c_nnz = sycl::malloc_shared<std::int64_t >(1 , queue);
766+
767+ auto ev_get_nnz = oneapi::mkl::sparse::matmat (queue, h->A , h->B , h->C , req, h->descr , c_nnz, nullptr ,
768+ {ev_cs});
769+ ev_get_nnz.wait ();
770+ handle->set_c_nnz (c_nnz[0 ]);
771+ }
772+
693773#define SPGEMM_SYMBOLIC_DECL_MKL_SYCL (SCALAR, ORDINAL, TPL_AVAIL ) \
694774 template <> \
695775 struct SPGEMM_SYMBOLIC < \
@@ -733,21 +813,21 @@ SPGEMM_SYMBOLIC_DECL_MKL_E(Kokkos::OpenMP)
733813 std::string label = " KokkosSparse::spgemm_symbolic[TPL_MKL," + \
734814 Kokkos::ArithTraits<SCALAR>::name () + " ]" ; \
735815 Kokkos::Profiling::pushRegion (label); \
736- spgemm_symbolic_mkl (handle->get_spgemm_handle (), m, n, k, row_mapA, \
816+ spgemm_symbolic_onemkl (handle->get_spgemm_handle (), m, n, k, row_mapA, \
737817 entriesA, row_mapB, entriesB, row_mapC); \
738818 Kokkos::Profiling::popRegion (); \
739819 } \
740820 };
741821
742- SPGEMM_SYMBOLIC_DECL_MKL_SYCL (float , std::int32_t )
743- SPGEMM_SYMBOLIC_DECL_MKL_SYCL(double , std::int32_t )
744- SPGEMM_SYMBOLIC_DECL_MKL_SYCL(Kokkos::complex <float >, std::int32_t )
745- SPGEMM_SYMBOLIC_DECL_MKL_SYCL(Kokkos::complex <double >, std::int32_t )
822+ SPGEMM_SYMBOLIC_DECL_MKL_SYCL (float , std::int32_t , true )
823+ SPGEMM_SYMBOLIC_DECL_MKL_SYCL(double , std::int32_t , true )
824+ SPGEMM_SYMBOLIC_DECL_MKL_SYCL(Kokkos::complex <float >, std::int32_t , true )
825+ SPGEMM_SYMBOLIC_DECL_MKL_SYCL(Kokkos::complex <double >, std::int32_t , true )
746826
747- SPGEMM_SYMBOLIC_DECL_MKL_SYCL(float , std::int64_t )
748- SPGEMM_SYMBOLIC_DECL_MKL_SYCL(double , std::int64_t )
749- SPGEMM_SYMBOLIC_DECL_MKL_SYCL(Kokkos::complex <float >, std::int64_t )
750- SPGEMM_SYMBOLIC_DECL_MKL_SYCL(Kokkos::complex <double >, std::int64_t )
827+ SPGEMM_SYMBOLIC_DECL_MKL_SYCL(float , std::int64_t , true )
828+ SPGEMM_SYMBOLIC_DECL_MKL_SYCL(double , std::int64_t , true )
829+ SPGEMM_SYMBOLIC_DECL_MKL_SYCL(Kokkos::complex <float >, std::int64_t , true )
830+ SPGEMM_SYMBOLIC_DECL_MKL_SYCL(Kokkos::complex <double >, std::int64_t , true )
751831#endif // KOKKOS_ENABLE_SYCL
752832
753833#endif // KOKKOSKERNELS_ENABLE_TPL_MKL
0 commit comments