88#include < spblas/detail/operation_info_t.hpp>
99#include < spblas/detail/ranges.hpp>
1010#include < spblas/detail/view_inspectors.hpp>
11+ #include < spblas/vendor/onemkl_sycl/detail/create_matrix_handle.hpp>
1112
1213//
1314// Defines the following APIs for SpGEMM:
2324namespace spblas {
2425
2526template <matrix A, matrix B, matrix C>
26- requires __detail::has_csr_base<A> && __detail::has_csr_base<B> &&
27- __detail::is_csr_view_v<C>
27+ requires (__detail::has_csr_base<A> || __detail::has_csc_base<A>) &&
28+ (__detail::has_csr_base<B> || __detail::has_csc_base<B>) &&
29+ __detail::is_csr_view_v<C>
2830operation_info_t multiply_compute (A&& a, B&& b, C&& c) {
2931 log_trace (" " );
3032 auto a_base = __detail::get_ultimate_base (a);
@@ -34,51 +36,40 @@ operation_info_t multiply_compute(A&& a, B&& b, C&& c) {
3436 using oneapi::mkl::sparse::matmat_request;
3537 using oneapi::mkl::sparse::matrix_view_descr;
3638
37- oneapi::mkl::sparse::matmat_descr_t descr = nullptr ;
38-
3939 sycl::queue q (sycl::cpu_selector_v);
4040
41- oneapi::mkl::sparse::init_matmat_descr (&descr);
42-
43- oneapi::mkl::sparse::set_matmat_data (
44- descr, matrix_view_descr::general, transpose::nontrans, // view/op for A
45- matrix_view_descr::general, transpose::nontrans, // view/op for B
46- matrix_view_descr::general); // view for C
47-
48- oneapi::mkl::sparse::matrix_handle_t a_handle, b_handle, c_handle;
49- a_handle = b_handle = c_handle = nullptr ;
50-
51- oneapi::mkl::sparse::init_matrix_handle (&a_handle);
52- oneapi::mkl::sparse::init_matrix_handle (&b_handle);
53- oneapi::mkl::sparse::init_matrix_handle (&c_handle);
54-
55- oneapi::mkl::sparse::set_csr_data (
56- q, a_handle, __backend::shape (a_base)[0 ], __backend::shape (a_base)[1 ],
57- oneapi::mkl::index_base::zero, a_base.rowptr ().data (),
58- a_base.colind ().data (), a_base.values ().data ())
59- .wait ();
60-
61- oneapi::mkl::sparse::set_csr_data (
62- q, b_handle, __backend::shape (b_base)[0 ], __backend::shape (b_base)[1 ],
63- oneapi::mkl::index_base::zero, b_base.rowptr ().data (),
64- b_base.colind ().data (), b_base.values ().data ())
65- .wait ();
66-
6741 using T = tensor_scalar_t <C>;
6842 using I = tensor_index_t <C>;
6943
44+ oneapi::mkl::sparse::matrix_handle_t a_handle =
45+ __mkl::create_matrix_handle (q, a_base);
46+ oneapi::mkl::sparse::matrix_handle_t b_handle =
47+ __mkl::create_matrix_handle (q, b_base);
48+
7049 I* c_rowptr;
7150 if (c.rowptr ().size () >= __backend::shape (c)[0 ] + 1 ) {
7251 c_rowptr = c.rowptr ().data ();
7352 } else {
7453 c_rowptr = sycl::malloc_device<I>(__backend::shape (c)[0 ] + 1 , q);
7554 }
7655
56+ oneapi::mkl::sparse::matrix_handle_t c_handle = nullptr ;
57+ oneapi::mkl::sparse::init_matrix_handle (&c_handle);
58+
7759 oneapi::mkl::sparse::set_csr_data (
7860 q, c_handle, __backend::shape (c)[0 ], __backend::shape (c)[1 ],
7961 oneapi::mkl::index_base::zero, c_rowptr, (I*) nullptr , (T*) nullptr )
8062 .wait ();
8163
64+ oneapi::mkl::sparse::matmat_descr_t descr = nullptr ;
65+ oneapi::mkl::sparse::init_matmat_descr (&descr);
66+
67+ oneapi::mkl::sparse::set_matmat_data (
68+ descr, matrix_view_descr::general,
69+ __mkl::get_transpose (a), // view/op for A
70+ matrix_view_descr::general, __mkl::get_transpose (b), // view/op for B
71+ matrix_view_descr::general); // view for C
72+
8273 auto ev1 = oneapi::mkl::sparse::matmat (q, a_handle, b_handle, c_handle,
8374 matmat_request::work_estimation, descr,
8475 nullptr , nullptr , {});
@@ -113,8 +104,9 @@ operation_info_t multiply_compute(A&& a, B&& b, C&& c) {
113104}
114105
115106template <matrix A, matrix B, matrix C>
116- requires __detail::has_csr_base<A> && __detail::has_csr_base<B> &&
117- __detail::is_csr_view_v<C>
107+ requires (__detail::has_csr_base<A> || __detail::has_csc_base<A>) &&
108+ (__detail::has_csr_base<B> || __detail::has_csc_base<B>) &&
109+ __detail::is_csr_view_v<C>
118110void multiply_fill (operation_info_t & info, A&& a, B&& b, C&& c) {
119111
120112 log_trace (" " );
0 commit comments