|
11 | 11 | #include <spblas/detail/view_inspectors.hpp> |
12 | 12 |
|
13 | 13 | #include "cuda_allocator.hpp" |
| 14 | +#include "descriptor.hpp" |
14 | 15 | #include "exception.hpp" |
15 | 16 | #include "types.hpp" |
16 | 17 |
|
@@ -78,27 +79,12 @@ class spgemm_state_t { |
78 | 79 | value_type alpha = alpha_optional.value_or(1); |
79 | 80 | value_type beta = 1; |
80 | 81 | auto handle = this->handle_.get(); |
81 | | - // Create sparse matrix A in CSR format |
82 | | - __cusparse::throw_if_error(cusparseCreateCsr( |
83 | | - &mat_a_, __backend::shape(a_base)[0], __backend::shape(a_base)[1], |
84 | | - a_base.values().size(), a_base.rowptr().data(), a_base.colind().data(), |
85 | | - a_base.values().data(), |
86 | | - to_cusparse_indextype<typename matrix_type::offset_type>(), |
87 | | - to_cusparse_indextype<typename matrix_type::index_type>(), |
88 | | - CUSPARSE_INDEX_BASE_ZERO, to_cuda_datatype<value_type>())); |
89 | | - __cusparse::throw_if_error(cusparseCreateCsr( |
90 | | - &mat_b_, __backend::shape(b_base)[0], __backend::shape(b_base)[1], |
91 | | - b_base.values().size(), b_base.rowptr().data(), b_base.colind().data(), |
92 | | - b_base.values().data(), |
93 | | - to_cusparse_indextype<typename matrix_type::offset_type>(), |
94 | | - to_cusparse_indextype<typename matrix_type::index_type>(), |
95 | | - CUSPARSE_INDEX_BASE_ZERO, to_cuda_datatype<value_type>())); |
96 | | - __cusparse::throw_if_error(cusparseCreateCsr( |
97 | | - &mat_c_, __backend::shape(a)[0], __backend::shape(b)[1], 0, |
98 | | - c.rowptr().data(), NULL, NULL, |
99 | | - to_cusparse_indextype<typename matrix_type::offset_type>(), |
100 | | - to_cusparse_indextype<typename matrix_type::index_type>(), |
101 | | - CUSPARSE_INDEX_BASE_ZERO, to_cuda_datatype<value_type>())); |
| 82 | + __cusparse::throw_if_error(cusparseDestroySpMat(mat_a_)); |
| 83 | + __cusparse::throw_if_error(cusparseDestroySpMat(mat_b_)); |
| 84 | + __cusparse::throw_if_error(cusparseDestroySpMat(mat_c_)); |
| 85 | + mat_a_ = __cusparse::create_matrix_descr(a_base); |
| 86 | + mat_b_ = __cusparse::create_matrix_descr(b_base); |
| 87 | + mat_c_ = __cusparse::create_matrix_descr(c); |
102 | 88 |
|
103 | 89 | // ask bufferSize1 bytes for external memory |
104 | 90 | size_t buffer_size_1 = 0; |
@@ -183,9 +169,9 @@ class spgemm_state_t { |
183 | 169 | char* workspace_2_; |
184 | 170 | index<index_t> result_shape_; |
185 | 171 | index_t result_nnz_; |
186 | | - cusparseSpMatDescr_t mat_a_; |
187 | | - cusparseSpMatDescr_t mat_b_; |
188 | | - cusparseSpMatDescr_t mat_c_; |
| 172 | + cusparseSpMatDescr_t mat_a_ = nullptr; |
| 173 | + cusparseSpMatDescr_t mat_b_ = nullptr; |
| 174 | + cusparseSpMatDescr_t mat_c_ = nullptr; |
189 | 175 | cusparseSpGEMMDescr_t descr_; |
190 | 176 | }; |
191 | 177 |
|
|
0 commit comments