Skip to content

Commit ec7eb47

Browse files
committed
implement cusparse spsv
1 parent 879a7d3 commit ec7eb47

File tree

3 files changed

+125
-1
lines changed

3 files changed

+125
-1
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
#pragma once
22

33
#include "multiply.hpp"
4+
#include "trisolve.hpp"
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
#pragma once
2+
3+
#include <cstdint>
4+
#include <functional>
5+
#include <memory>
6+
#include <type_traits>
7+
8+
#include <cuda_runtime.h>
9+
#include <cusparse.h>
10+
11+
#include <spblas/detail/ranges.hpp>
12+
#include <spblas/detail/triangular_types.hpp>
13+
#include <spblas/detail/view_inspectors.hpp>
14+
15+
#include "cuda_allocator.hpp"
16+
#include "detail/cusparse_tensors.hpp"
17+
#include "exception.hpp"
18+
#include "types.hpp"
19+
20+
namespace spblas {
21+
class triangular_solve_state_t {
22+
public:
23+
triangular_solve_state_t()
24+
: triangular_solve_state_t(cusparse::cuda_allocator<char>{}) {}
25+
26+
triangular_solve_state_t(cusparse::cuda_allocator<char> alloc)
27+
: alloc_(alloc), buffer_size_(0), workspace_(nullptr) {
28+
cusparseHandle_t handle;
29+
__cusparse::throw_if_error(cusparseCreate(&handle));
30+
if (auto stream = alloc.stream()) {
31+
cusparseSetStream(handle, stream);
32+
}
33+
handle_ = handle_manager(handle, [](cusparseHandle_t handle) {
34+
__cusparse::throw_if_error(cusparseDestroy(handle));
35+
});
36+
}
37+
38+
triangular_solve_state_t(cusparse::cuda_allocator<char> alloc,
39+
cusparseHandle_t handle)
40+
: alloc_(alloc), buffer_size_(0), workspace_(nullptr) {
41+
handle_ = handle_manager(handle, [](cusparseHandle_t handle) {
42+
// it is provided by user, we do not delete it at all.
43+
});
44+
}
45+
46+
~triangular_solve_state_t() {
47+
alloc_.deallocate(workspace_);
48+
}
49+
50+
template <matrix A, class Triangle, class DiagonalStorage, vector B, vector C>
51+
requires __detail::has_csr_base<A> &&
52+
__detail::has_contiguous_range_base<B> &&
53+
__ranges::contiguous_range<C>
54+
void triangular_solve(A&& a, Triangle uplo, DiagonalStorage diag, B&& b,
55+
C&& c) {
56+
auto a_base = __detail::get_ultimate_base(a);
57+
auto b_base = __detail::get_ultimate_base(b);
58+
using matrix_type = decltype(a_base);
59+
using value_type = typename matrix_type::scalar_type;
60+
// the following needs to be non-const because cusparseSpMatSetAttribute
61+
// only accept void*
62+
auto diag_type = std::is_same_v<DiagonalStorage, explicit_diagonal_t>
63+
? CUSPARSE_DIAG_TYPE_NON_UNIT
64+
: CUSPARSE_DIAG_TYPE_UNIT;
65+
auto fill_mode = std::is_same_v<Triangle, upper_triangle_t>
66+
? CUSPARSE_FILL_MODE_UPPER
67+
: CUSPARSE_FILL_MODE_LOWER;
68+
69+
auto a_descr = __cusparse::create_cusparse_handle(a_base);
70+
auto b_descr = __cusparse::create_cusparse_handle(b_base);
71+
auto c_descr = __cusparse::create_cusparse_handle(c);
72+
73+
cusparseSpMatSetAttribute(a_descr, CUSPARSE_SPMAT_FILL_MODE, &fill_mode,
74+
sizeof(fill_mode));
75+
cusparseSpMatSetAttribute(a_descr, CUSPARSE_SPMAT_DIAG_TYPE, &diag_type,
76+
sizeof(diag_type));
77+
value_type alpha = 1.0;
78+
size_t buffer_size = 0;
79+
auto handle = this->handle_.get();
80+
cusparseSpSVDescr_t descr;
81+
cusparseSpSV_createDescr(&descr);
82+
__cusparse::throw_if_error(cusparseSpSV_bufferSize(
83+
handle, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, a_descr, b_descr,
84+
c_descr, detail::cuda_data_type_v<value_type>,
85+
CUSPARSE_SPSV_ALG_DEFAULT, descr, &buffer_size));
86+
if (buffer_size > this->buffer_size_) {
87+
this->alloc_.deallocate(workspace_, this->buffer_size_);
88+
this->buffer_size_ = buffer_size;
89+
workspace_ = this->alloc_.allocate(buffer_size);
90+
}
91+
__cusparse::throw_if_error(cusparseSpSV_analysis(
92+
handle, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, a_descr, b_descr,
93+
c_descr, detail::cuda_data_type_v<value_type>,
94+
CUSPARSE_SPSV_ALG_DEFAULT, descr, this->workspace_));
95+
__cusparse::throw_if_error(cusparseSpSV_solve(
96+
handle, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, a_descr, b_descr,
97+
c_descr, detail::cuda_data_type_v<value_type>,
98+
CUSPARSE_SPSV_ALG_DEFAULT, descr));
99+
__cusparse::throw_if_error(cusparseDestroySpMat(a_descr));
100+
__cusparse::throw_if_error(cusparseDestroyDnVec(b_descr));
101+
__cusparse::throw_if_error(cusparseDestroyDnVec(c_descr));
102+
}
103+
104+
private:
105+
using handle_manager =
106+
std::unique_ptr<std::pointer_traits<cusparseHandle_t>::element_type,
107+
std::function<void(cusparseHandle_t)>>;
108+
handle_manager handle_;
109+
cusparse::cuda_allocator<char> alloc_;
110+
std::uint64_t buffer_size_;
111+
char* workspace_;
112+
};
113+
114+
template <matrix A, class Triangle, class DiagonalStorage, vector B, vector C>
115+
requires __detail::has_csr_base<A> &&
116+
__detail::has_contiguous_range_base<B> &&
117+
__ranges::contiguous_range<C>
118+
void triangular_solve(triangular_solve_state_t& trisolve_handle, A&& a,
119+
Triangle uplo, DiagonalStorage diag, B&& b, C&& c) {
120+
trisolve_handle.triangular_solve(a, uplo, diag, b, c);
121+
}
122+
123+
} // namespace spblas

test/gtest/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ if (SPBLAS_GPU_BACKEND)
2020
set(GPUTEST_SOURCES device/spmv_test.cpp device/spgemm_test.cpp device/spgemm_reuse_test.cpp device/rocsparse/spgemm_4args_test.cpp device/triangular_solve_test.cpp)
2121
set_source_files_properties(${GPUTEST_SOURCES} PROPERTIES LANGUAGE HIP)
2222
else ()
23-
set(GPUTEST_SOURCES device/spmv_test.cpp)
23+
set(GPUTEST_SOURCES device/spmv_test.cpp device/triangular_solve_test.cpp)
2424
endif ()
2525
list(APPEND TEST_SOURCES ${GPUTEST_SOURCES})
2626
endif()

0 commit comments

Comments
 (0)