Skip to content

Commit 7845296

Browse files
committed
reuse the current changes on main
1 parent 144aa14 commit 7845296

File tree

2 files changed

+22
-65
lines changed

2 files changed

+22
-65
lines changed

include/spblas/vendor/cusparse/descriptor.hpp

Lines changed: 0 additions & 45 deletions
This file was deleted.

include/spblas/vendor/cusparse/multiply_spgemm.hpp

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
#include <spblas/detail/view_inspectors.hpp>
1212

1313
#include "cuda_allocator.hpp"
14-
#include "descriptor.hpp"
14+
#include "detail/cusparse_tensors.hpp"
1515
#include "exception.hpp"
1616
#include "types.hpp"
1717

@@ -83,17 +83,17 @@ class spgemm_state_t {
8383
__cusparse::throw_if_error(cusparseDestroySpMat(mat_a_));
8484
__cusparse::throw_if_error(cusparseDestroySpMat(mat_b_));
8585
__cusparse::throw_if_error(cusparseDestroySpMat(mat_c_));
86-
mat_a_ = __cusparse::create_matrix_descr(a_base);
87-
mat_b_ = __cusparse::create_matrix_descr(b_base);
88-
mat_c_ = __cusparse::create_matrix_descr(c);
86+
mat_a_ = __cusparse::create_cusparse_handle(a_base);
87+
mat_b_ = __cusparse::create_cusparse_handle(b_base);
88+
mat_c_ = __cusparse::create_cusparse_handle(c);
8989

9090
// ask bufferSize1 bytes for external memory
9191
size_t buffer_size_1 = 0;
9292
__cusparse::throw_if_error(cusparseSpGEMM_workEstimation(
9393
handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
9494
CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a_, mat_b_, &beta, mat_c_,
95-
to_cuda_datatype<value_type>(), CUSPARSE_SPGEMM_DEFAULT, this->descr_,
96-
&buffer_size_1, NULL));
95+
detail::cuda_data_type_v<value_type>, CUSPARSE_SPGEMM_DEFAULT,
96+
this->descr_, &buffer_size_1, NULL));
9797
if (buffer_size_1 > this->buffer_size_1_) {
9898
this->alloc_.deallocate(this->workspace_1_, buffer_size_1_);
9999
this->buffer_size_1_ = buffer_size_1;
@@ -104,16 +104,16 @@ class spgemm_state_t {
104104
__cusparse::throw_if_error(cusparseSpGEMM_workEstimation(
105105
handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
106106
CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a_, mat_b_, &beta, mat_c_,
107-
to_cuda_datatype<value_type>(), CUSPARSE_SPGEMM_DEFAULT, this->descr_,
108-
&buffer_size_1, this->workspace_1_));
107+
detail::cuda_data_type_v<value_type>, CUSPARSE_SPGEMM_DEFAULT,
108+
this->descr_, &buffer_size_1, this->workspace_1_));
109109

110110
// ask buffer_size_2 bytes for external memory
111111
size_t buffer_size_2 = 0;
112112
__cusparse::throw_if_error(cusparseSpGEMM_compute(
113113
handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
114114
CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a_, mat_b_, &beta, mat_c_,
115-
to_cuda_datatype<value_type>(), CUSPARSE_SPGEMM_DEFAULT, this->descr_,
116-
&buffer_size_2, NULL));
115+
detail::cuda_data_type_v<value_type>, CUSPARSE_SPGEMM_DEFAULT,
116+
this->descr_, &buffer_size_2, NULL));
117117
if (buffer_size_2 > this->buffer_size_2_) {
118118
this->alloc_.deallocate(this->workspace_2_, buffer_size_2_);
119119
this->buffer_size_2_ = buffer_size_2;
@@ -124,8 +124,8 @@ class spgemm_state_t {
124124
cusparseSpGEMM_compute(
125125
handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
126126
CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a_, mat_b_, &beta, mat_c_,
127-
to_cuda_datatype<value_type>(), CUSPARSE_SPGEMM_DEFAULT, this->descr_,
128-
&buffer_size_2, this->workspace_2_);
127+
detail::cuda_data_type_v<value_type>, CUSPARSE_SPGEMM_DEFAULT,
128+
this->descr_, &buffer_size_2, this->workspace_2_);
129129
// get matrix C non-zero entries c_nnz
130130
int64_t c_num_rows, c_num_cols, c_nnz;
131131
cusparseSpMatGetSize(mat_c_, &c_num_rows, &c_num_cols, &c_nnz);
@@ -155,7 +155,8 @@ class spgemm_state_t {
155155
__cusparse::throw_if_error(cusparseSpGEMM_copy(
156156
handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
157157
CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a_, mat_b_, &beta, mat_c_,
158-
to_cuda_datatype<value_type>(), CUSPARSE_SPGEMM_DEFAULT, this->descr_));
158+
detail::cuda_data_type_v<value_type>, CUSPARSE_SPGEMM_DEFAULT,
159+
this->descr_));
159160
}
160161

161162
template <matrix A, matrix B, matrix C>
@@ -177,9 +178,9 @@ class spgemm_state_t {
177178
__cusparse::throw_if_error(cusparseDestroySpMat(mat_a_));
178179
__cusparse::throw_if_error(cusparseDestroySpMat(mat_b_));
179180
__cusparse::throw_if_error(cusparseDestroySpMat(mat_c_));
180-
mat_a_ = __cusparse::create_matrix_descr(a_base);
181-
mat_b_ = __cusparse::create_matrix_descr(b_base);
182-
mat_c_ = __cusparse::create_matrix_descr(c);
181+
mat_a_ = __cusparse::create_cusparse_handle(a_base);
182+
mat_b_ = __cusparse::create_cusparse_handle(b_base);
183+
mat_c_ = __cusparse::create_cusparse_handle(c);
183184

184185
// ask bufferSize1 bytes for external memory
185186
size_t buffer_size_1 = 0;
@@ -302,10 +303,11 @@ class spgemm_state_t {
302303
b_base.colind().data(), b_base.values().data()));
303304
__cusparse::throw_if_error(cusparseCsrSetPointers(
304305
this->mat_c_, c.rowptr().data(), c.colind().data(), c.values().data()));
305-
cusparseSpGEMMreuse_compute(
306-
handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
307-
CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a_, mat_b_, &beta, mat_c_,
308-
to_cuda_datatype<value_type>(), CUSPARSE_SPGEMM_DEFAULT, this->descr_);
306+
cusparseSpGEMMreuse_compute(handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
307+
CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha,
308+
mat_a_, mat_b_, &beta, mat_c_,
309+
detail::cuda_data_type_v<value_type>,
310+
CUSPARSE_SPGEMM_DEFAULT, this->descr_);
309311
}
310312

311313
private:

0 commit comments

Comments
 (0)