Skip to content

Commit c9cd31d

Browse files
committed
use the helper function to create descriptor
1 parent 7dc4ce8 commit c9cd31d

File tree

3 files changed

+59
-41
lines changed

3 files changed

+59
-41
lines changed
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
#pragma once
2+
3+
#include <type_traits>
4+
5+
#include <cusparse.h>
6+
7+
#include <spblas/concepts.hpp>
8+
#include <spblas/views/inspectors.hpp>
9+
10+
#include "exception.hpp"
11+
#include "types.hpp"
12+
13+
namespace spblas {
14+
namespace __cusparse {
15+
16+
// create matrix descriptor from spblas csr view
17+
template <matrix mat>
18+
requires __detail::is_csr_view_v<mat>
19+
cusparseSpMatDescr_t create_matrix_descr(mat&& a) {
20+
using matrix_type = std::remove_cvref_t<mat>;
21+
cusparseSpMatDescr_t descr;
22+
throw_if_error(cusparseCreateCsr(
23+
&descr, __backend::shape(a)[0], __backend::shape(a)[1], a.values().size(),
24+
a.rowptr().data(), a.colind().data(), a.values().data(),
25+
to_cusparse_indextype<typename matrix_type::offset_type>(),
26+
to_cusparse_indextype<typename matrix_type::index_type>(),
27+
CUSPARSE_INDEX_BASE_ZERO,
28+
to_cuda_datatype<typename matrix_type::scalar_type>()));
29+
return descr;
30+
}
31+
32+
// create dense vector from mdspan
33+
template <vector vec>
34+
requires __ranges::contiguous_range<vec>
35+
cusparseDnVecDescr_t create_vector_descr(vec&& v) {
36+
using vector_type = std::remove_cvref_t<vec>;
37+
cusparseDnVecDescr_t descr;
38+
throw_if_error(cusparseCreateDnVec(
39+
&descr, v.size(), v.data(),
40+
to_cuda_datatype<typename vector_type::value_type>()));
41+
return descr;
42+
}
43+
44+
} // namespace __cusparse
45+
} // namespace spblas

include/spblas/vendor/cusparse/multiply.hpp

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <spblas/detail/view_inspectors.hpp>
1212

1313
#include "cuda_allocator.hpp"
14+
#include "descriptor.hpp"
1415
#include "exception.hpp"
1516
#include "types.hpp"
1617

@@ -59,23 +60,9 @@ class spmv_state_t {
5960
tensor_scalar_t<A> alpha = alpha_optional.value_or(1);
6061
auto handle = this->handle_.get();
6162

62-
cusparseSpMatDescr_t mat;
63-
__cusparse::throw_if_error(cusparseCreateCsr(
64-
&mat, __backend::shape(a_base)[0], __backend::shape(a_base)[1],
65-
a_base.values().size(), a_base.rowptr().data(), a_base.colind().data(),
66-
a_base.values().data(),
67-
to_cusparse_indextype<typename matrix_type::offset_type>(),
68-
to_cusparse_indextype<typename matrix_type::index_type>(),
69-
CUSPARSE_INDEX_BASE_ZERO, to_cuda_datatype<value_type>()));
70-
71-
cusparseDnVecDescr_t vecb;
72-
cusparseDnVecDescr_t vecc;
73-
__cusparse::throw_if_error(cusparseCreateDnVec(
74-
&vecb, b_base.size(), b_base.data(),
75-
to_cuda_datatype<typename input_type::value_type>()));
76-
__cusparse::throw_if_error(cusparseCreateDnVec(
77-
&vecc, c.size(), c.data(),
78-
to_cuda_datatype<typename output_type::value_type>()));
63+
auto mat = __cusparse::create_matrix_descr(a_base);
64+
auto vecb = __cusparse::create_vector_descr(b_base);
65+
auto vecc = __cusparse::create_vector_descr(c);
7966

8067
value_type alpha_val = alpha;
8168
value_type beta = 0.0;

include/spblas/vendor/cusparse/multiply_spgemm.hpp

Lines changed: 10 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <spblas/detail/view_inspectors.hpp>
1212

1313
#include "cuda_allocator.hpp"
14+
#include "descriptor.hpp"
1415
#include "exception.hpp"
1516
#include "types.hpp"
1617

@@ -78,27 +79,12 @@ class spgemm_state_t {
7879
value_type alpha = alpha_optional.value_or(1);
7980
value_type beta = 1;
8081
auto handle = this->handle_.get();
81-
// Create sparse matrix A in CSR format
82-
__cusparse::throw_if_error(cusparseCreateCsr(
83-
&mat_a_, __backend::shape(a_base)[0], __backend::shape(a_base)[1],
84-
a_base.values().size(), a_base.rowptr().data(), a_base.colind().data(),
85-
a_base.values().data(),
86-
to_cusparse_indextype<typename matrix_type::offset_type>(),
87-
to_cusparse_indextype<typename matrix_type::index_type>(),
88-
CUSPARSE_INDEX_BASE_ZERO, to_cuda_datatype<value_type>()));
89-
__cusparse::throw_if_error(cusparseCreateCsr(
90-
&mat_b_, __backend::shape(b_base)[0], __backend::shape(b_base)[1],
91-
b_base.values().size(), b_base.rowptr().data(), b_base.colind().data(),
92-
b_base.values().data(),
93-
to_cusparse_indextype<typename matrix_type::offset_type>(),
94-
to_cusparse_indextype<typename matrix_type::index_type>(),
95-
CUSPARSE_INDEX_BASE_ZERO, to_cuda_datatype<value_type>()));
96-
__cusparse::throw_if_error(cusparseCreateCsr(
97-
&mat_c_, __backend::shape(a)[0], __backend::shape(b)[1], 0,
98-
c.rowptr().data(), NULL, NULL,
99-
to_cusparse_indextype<typename matrix_type::offset_type>(),
100-
to_cusparse_indextype<typename matrix_type::index_type>(),
101-
CUSPARSE_INDEX_BASE_ZERO, to_cuda_datatype<value_type>()));
82+
__cusparse::throw_if_error(cusparseDestroySpMat(mat_a_));
83+
__cusparse::throw_if_error(cusparseDestroySpMat(mat_b_));
84+
__cusparse::throw_if_error(cusparseDestroySpMat(mat_c_));
85+
mat_a_ = __cusparse::create_matrix_descr(a_base);
86+
mat_b_ = __cusparse::create_matrix_descr(b_base);
87+
mat_c_ = __cusparse::create_matrix_descr(c);
10288

10389
// ask bufferSize1 bytes for external memory
10490
size_t buffer_size_1 = 0;
@@ -183,9 +169,9 @@ class spgemm_state_t {
183169
char* workspace_2_;
184170
index<index_t> result_shape_;
185171
index_t result_nnz_;
186-
cusparseSpMatDescr_t mat_a_;
187-
cusparseSpMatDescr_t mat_b_;
188-
cusparseSpMatDescr_t mat_c_;
172+
cusparseSpMatDescr_t mat_a_ = nullptr;
173+
cusparseSpMatDescr_t mat_b_ = nullptr;
174+
cusparseSpMatDescr_t mat_c_ = nullptr;
189175
cusparseSpGEMMDescr_t descr_;
190176
};
191177

0 commit comments

Comments
 (0)