1111#include < spblas/detail/view_inspectors.hpp>
1212
1313#include " cuda_allocator.hpp"
14- #include " descriptor .hpp"
14+ #include " detail/cusparse_tensors .hpp"
1515#include " exception.hpp"
1616#include " types.hpp"
1717
@@ -83,17 +83,17 @@ class spgemm_state_t {
8383 __cusparse::throw_if_error (cusparseDestroySpMat (mat_a_));
8484 __cusparse::throw_if_error (cusparseDestroySpMat (mat_b_));
8585 __cusparse::throw_if_error (cusparseDestroySpMat (mat_c_));
86- mat_a_ = __cusparse::create_matrix_descr (a_base);
87- mat_b_ = __cusparse::create_matrix_descr (b_base);
88- mat_c_ = __cusparse::create_matrix_descr (c);
86+ mat_a_ = __cusparse::create_cusparse_handle (a_base);
87+ mat_b_ = __cusparse::create_cusparse_handle (b_base);
88+ mat_c_ = __cusparse::create_cusparse_handle (c);
8989
9090 // ask bufferSize1 bytes for external memory
9191 size_t buffer_size_1 = 0 ;
9292 __cusparse::throw_if_error (cusparseSpGEMM_workEstimation (
9393 handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
9494 CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a_, mat_b_, &beta, mat_c_,
95- to_cuda_datatype <value_type>() , CUSPARSE_SPGEMM_DEFAULT, this -> descr_ ,
96- &buffer_size_1, NULL ));
95+ detail::cuda_data_type_v <value_type>, CUSPARSE_SPGEMM_DEFAULT,
96+ this -> descr_ , &buffer_size_1, NULL ));
9797 if (buffer_size_1 > this ->buffer_size_1_ ) {
9898 this ->alloc_ .deallocate (this ->workspace_1_ , buffer_size_1_);
9999 this ->buffer_size_1_ = buffer_size_1;
@@ -104,16 +104,16 @@ class spgemm_state_t {
104104 __cusparse::throw_if_error (cusparseSpGEMM_workEstimation (
105105 handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
106106 CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a_, mat_b_, &beta, mat_c_,
107- to_cuda_datatype <value_type>() , CUSPARSE_SPGEMM_DEFAULT, this -> descr_ ,
108- &buffer_size_1, this ->workspace_1_ ));
107+ detail::cuda_data_type_v <value_type>, CUSPARSE_SPGEMM_DEFAULT,
108+ this -> descr_ , &buffer_size_1, this ->workspace_1_ ));
109109
110110 // ask buffer_size_2 bytes for external memory
111111 size_t buffer_size_2 = 0 ;
112112 __cusparse::throw_if_error (cusparseSpGEMM_compute (
113113 handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
114114 CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a_, mat_b_, &beta, mat_c_,
115- to_cuda_datatype <value_type>() , CUSPARSE_SPGEMM_DEFAULT, this -> descr_ ,
116- &buffer_size_2, NULL ));
115+ detail::cuda_data_type_v <value_type>, CUSPARSE_SPGEMM_DEFAULT,
116+ this -> descr_ , &buffer_size_2, NULL ));
117117 if (buffer_size_2 > this ->buffer_size_2_ ) {
118118 this ->alloc_ .deallocate (this ->workspace_2_ , buffer_size_2_);
119119 this ->buffer_size_2_ = buffer_size_2;
@@ -124,8 +124,8 @@ class spgemm_state_t {
124124 cusparseSpGEMM_compute (
125125 handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
126126 CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a_, mat_b_, &beta, mat_c_,
127- to_cuda_datatype <value_type>() , CUSPARSE_SPGEMM_DEFAULT, this -> descr_ ,
128- &buffer_size_2, this ->workspace_2_ );
127+ detail::cuda_data_type_v <value_type>, CUSPARSE_SPGEMM_DEFAULT,
128+ this -> descr_ , &buffer_size_2, this ->workspace_2_ );
129129 // get matrix C non-zero entries c_nnz
130130 int64_t c_num_rows, c_num_cols, c_nnz;
131131 cusparseSpMatGetSize (mat_c_, &c_num_rows, &c_num_cols, &c_nnz);
@@ -155,7 +155,8 @@ class spgemm_state_t {
155155 __cusparse::throw_if_error (cusparseSpGEMM_copy (
156156 handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
157157 CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a_, mat_b_, &beta, mat_c_,
158- to_cuda_datatype<value_type>(), CUSPARSE_SPGEMM_DEFAULT, this ->descr_ ));
158+ detail::cuda_data_type_v<value_type>, CUSPARSE_SPGEMM_DEFAULT,
159+ this ->descr_ ));
159160 }
160161
161162 template <matrix A, matrix B, matrix C>
@@ -177,9 +178,9 @@ class spgemm_state_t {
177178 __cusparse::throw_if_error (cusparseDestroySpMat (mat_a_));
178179 __cusparse::throw_if_error (cusparseDestroySpMat (mat_b_));
179180 __cusparse::throw_if_error (cusparseDestroySpMat (mat_c_));
180- mat_a_ = __cusparse::create_matrix_descr (a_base);
181- mat_b_ = __cusparse::create_matrix_descr (b_base);
182- mat_c_ = __cusparse::create_matrix_descr (c);
181+ mat_a_ = __cusparse::create_cusparse_handle (a_base);
182+ mat_b_ = __cusparse::create_cusparse_handle (b_base);
183+ mat_c_ = __cusparse::create_cusparse_handle (c);
183184
184185 // ask bufferSize1 bytes for external memory
185186 size_t buffer_size_1 = 0 ;
@@ -302,10 +303,11 @@ class spgemm_state_t {
302303 b_base.colind ().data (), b_base.values ().data ()));
303304 __cusparse::throw_if_error (cusparseCsrSetPointers (
304305 this ->mat_c_ , c.rowptr ().data (), c.colind ().data (), c.values ().data ()));
305- cusparseSpGEMMreuse_compute (
306- handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
307- CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a_, mat_b_, &beta, mat_c_,
308- to_cuda_datatype<value_type>(), CUSPARSE_SPGEMM_DEFAULT, this ->descr_ );
306+ cusparseSpGEMMreuse_compute (handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
307+ CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha,
308+ mat_a_, mat_b_, &beta, mat_c_,
309+ detail::cuda_data_type_v<value_type>,
310+ CUSPARSE_SPGEMM_DEFAULT, this ->descr_ );
309311 }
310312
311313private:
0 commit comments