Skip to content

Commit 1fa7b3a

Browse files
authored
Build with sycl_thrust for device examples if building oneMKL. (#53)
* Build with `sycl_thrust` for device examples if building oneMKL.
* Support compiling both CPU and GPU tests and examples.
* Update cuSPARSE backend architecture to more closely resemble other backends.
* Update rocSPARSE backend architecture to more closely resemble other backends.
1 parent 8e4ad01 commit 1fa7b3a

34 files changed

+1086
-408
lines changed

CMakeLists.txt

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,25 @@ endif()
2020
# Download dependencies
2121
include(FetchContent)
2222

23+
set(SPBLAS_CPU_BACKEND OFF)
24+
set(SPBLAS_GPU_BACKEND OFF)
25+
2326
if (ENABLE_ONEMKL_SYCL)
27+
set(SPBLAS_CPU_BACKEND ON)
28+
set(SPBLAS_GPU_BACKEND ON)
2429
find_package(MKL REQUIRED)
2530
target_link_libraries(spblas INTERFACE MKL::MKL_SYCL) # SYCL APIs
2631
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DSPBLAS_ENABLE_ONEMKL_SYCL")
32+
33+
FetchContent_Declare(
34+
sycl_thrust
35+
GIT_REPOSITORY https://github.com/SparseBLAS/sycl-thrust.git
36+
GIT_TAG main)
37+
FetchContent_MakeAvailable(sycl_thrust)
2738
endif()
2839

2940
if (ENABLE_ARMPL)
41+
set(SPBLAS_CPU_BACKEND ON)
3042
if (NOT DEFINED ENV{ARMPL_DIR})
3143
message(FATAL_ERROR "Environment variable ARMPL_DIR must be set when the ArmPL is enabled.")
3244
endif()
@@ -36,6 +48,7 @@ if (ENABLE_ARMPL)
3648
endif()
3749

3850
if (ENABLE_AOCLSPARSE)
51+
set(SPBLAS_CPU_BACKEND ON)
3952
if (NOT DEFINED ENV{AOCLSPARSE_DIR})
4053
message(FATAL_ERROR "Environment variable AOCLSPARSE_DIR must be set when the AOCLSPARSE is enabled.")
4154
endif()
@@ -81,6 +94,15 @@ if (ENABLE_CUSPARSE)
8194
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DSPBLAS_ENABLE_CUSPARSE")
8295
endif()
8396

97+
# If no vendor backend is enabled, enable CPU backend for reference implementation
98+
if (NOT ENABLE_ONEMKL_SYCL AND
99+
NOT ENABLE_ARMPL AND
100+
NOT ENABLE_AOCLSPARSE AND
101+
NOT ENABLE_ROCSPARSE AND
102+
NOT ENABLE_CUSPARSE)
103+
set(SPBLAS_CPU_BACKEND ON)
104+
endif()
105+
84106
# turn on/off debug logging
85107
if (LOG_LEVEL)
86108
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DLOG_LEVEL=${LOG_LEVEL}") # SPBLAS_DEBUG | SPBLAS_WARNING | SPBLAS_TRACE | SPBLAS_INFO

examples/CMakeLists.txt

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,23 @@ function(add_example example_name)
33
target_link_libraries(${example_name} spblas fmt)
44
endfunction()
55

6-
if (NOT SPBLAS_GPU_BACKEND)
6+
# CPU examples
7+
if (SPBLAS_CPU_BACKEND)
78
add_example(simple_spmv)
89
add_example(simple_spmm)
910
add_example(simple_spgemm)
1011
add_example(simple_sptrsv)
11-
add_example(matrix_opt_example)
1212
add_example(spmm_csc)
13-
else()
14-
add_subdirectory(device)
13+
add_example(matrix_opt_example)
1514
endif()
1615

17-
if (ENABLE_ROCSPARSE)
18-
add_subdirectory(rocsparse)
19-
endif()
20-
if (ENABLE_CUSPARSE)
21-
add_subdirectory(cusparse)
16+
# GPU examples
17+
if (SPBLAS_GPU_BACKEND)
18+
add_subdirectory(device)
19+
if (ENABLE_CUSPARSE)
20+
add_subdirectory(cusparse)
21+
endif()
22+
if (ENABLE_ROCSPARSE)
23+
add_subdirectory(rocsparse)
24+
endif()
2225
endif()

examples/cusparse/cusparse_simple_spmv.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,8 @@ int main(int argc, char** argv) {
7676
std::span<value_t> y_span(d_y, m);
7777

7878
// y = A * x
79-
spblas::spmv_state_t state;
80-
spblas::multiply(state, a, x_span, y_span);
79+
spblas::operation_info_t info;
80+
spblas::multiply(info, a, x_span, y_span);
8181

8282
CUDA_CHECK(
8383
cudaMemcpy(y.data(), d_y, y.size() * sizeof(value_t), cudaMemcpyDefault));

examples/device/CMakeLists.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,15 @@ function(add_device_example example_name)
22
add_executable(${example_name} ${example_name}.cpp)
33
if (ENABLE_ROCSPARSE)
44
set_source_files_properties(${example_name}.cpp PROPERTIES LANGUAGE HIP)
5+
target_link_libraries(${example_name} roc::rocthrust)
56
elseif (ENABLE_CUSPARSE)
67
target_link_libraries(${example_name} Thrust)
8+
elseif (ENABLE_ONEMKL_SYCL)
9+
target_link_libraries(${example_name} sycl_thrust)
710
else()
811
message(FATAL_ERROR "Device backend not found.")
912
endif()
1013
target_link_libraries(${example_name} spblas fmt)
1114
endfunction()
1215

13-
add_device_example(simple_spmv)
16+
add_device_example(device_spmv)
Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,7 @@ int main(int argc, char** argv) {
5656
std::span<value_t> y_span(d_y.data().get(), m);
5757

5858
// y = A * x
59-
spblas::spmv_state_t state;
60-
spblas::multiply(state, a, x_span, y_span);
59+
spblas::multiply(a, x_span, y_span);
6160

6261
thrust::copy(d_y.begin(), d_y.end(), y.begin());
6362

examples/rocsparse/rocsparse_simple_spmv.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,8 @@ int main(int argc, char** argv) {
7676
std::span<value_t> y_span(d_y, m);
7777

7878
// y = A * x
79-
spblas::spmv_state_t state;
80-
spblas::multiply(state, a, x_span, y_span);
79+
spblas::operation_info_t info;
80+
spblas::multiply(info, a, x_span, y_span);
8181

8282
HIP_CHECK(
8383
hipMemcpy(y.data(), d_y, y.size() * sizeof(value_t), hipMemcpyDefault));

include/spblas/detail/operation_info_t.hpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,14 @@
1515
#include <spblas/vendor/aoclsparse/operation_state_t.hpp>
1616
#endif
1717

18+
#ifdef SPBLAS_ENABLE_CUSPARSE
19+
#include <spblas/vendor/cusparse/operation_state_t.hpp>
20+
#endif
21+
22+
#ifdef SPBLAS_ENABLE_ROCSPARSE
23+
#include <spblas/vendor/rocsparse/operation_state_t.hpp>
24+
#endif
25+
1826
namespace spblas {
1927

2028
class operation_info_t {
@@ -53,6 +61,13 @@ class operation_info_t {
5361
state_(std::move(state)) {}
5462
#endif
5563

64+
#ifdef SPBLAS_ENABLE_CUSPARSE
65+
operation_info_t(index<> result_shape, offset_t result_nnz,
66+
__cusparse::operation_state_t&& state)
67+
: result_shape_(result_shape), result_nnz_(result_nnz),
68+
state_(std::move(state)) {}
69+
#endif
70+
5671
void update_impl_(index<> result_shape, offset_t result_nnz) {
5772
result_shape_ = result_shape;
5873
result_nnz_ = result_nnz;
@@ -76,6 +91,16 @@ class operation_info_t {
7691
public:
7792
__aoclsparse::operation_state_t state_;
7893
#endif
94+
95+
#ifdef SPBLAS_ENABLE_CUSPARSE
96+
public:
97+
__cusparse::operation_state_t state_;
98+
#endif
99+
100+
#ifdef SPBLAS_ENABLE_ROCSPARSE
101+
public:
102+
__rocsparse::operation_state_t state_;
103+
#endif
79104
};
80105

81106
} // namespace spblas
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
#pragma once
2+
3+
#include <cusparse.h>
4+
#include <memory>
5+
6+
namespace spblas {
7+
namespace __cusparse {
8+
9+
class abstract_operation_state_t {
10+
public:
11+
// Common state that all operations need
12+
cusparseHandle_t handle() const {
13+
return handle_;
14+
}
15+
16+
// Make std::default_delete a friend so unique_ptr can delete us
17+
friend struct std::default_delete<abstract_operation_state_t>;
18+
19+
protected:
20+
abstract_operation_state_t() {
21+
cusparseCreate(&handle_);
22+
}
23+
24+
virtual ~abstract_operation_state_t() {
25+
if (handle_) {
26+
cusparseDestroy(handle_);
27+
}
28+
}
29+
30+
cusparseHandle_t handle_;
31+
};
32+
33+
} // namespace __cusparse
34+
} // namespace spblas
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
#pragma once
2+
3+
#include <cusparse.h>
4+
5+
#include <spblas/detail/types.hpp>
6+
#include <spblas/detail/view_inspectors.hpp>
7+
#include <spblas/vendor/cusparse/exception.hpp>
8+
#include <spblas/vendor/cusparse/types.hpp>
9+
10+
namespace spblas {
11+
12+
namespace __cusparse {
13+
14+
template <matrix M>
15+
requires __detail::is_csr_view_v<M>
16+
cusparseSpMatDescr_t create_cusparse_handle(M&& m) {
17+
cusparseSpMatDescr_t mat_descr;
18+
__cusparse::throw_if_error(cusparseCreateCsr(
19+
&mat_descr, __backend::shape(m)[0], __backend::shape(m)[1],
20+
m.values().size(), m.rowptr().data(), m.colind().data(),
21+
m.values().data(), detail::cusparse_index_type_v<tensor_offset_t<M>>,
22+
detail::cusparse_index_type_v<tensor_index_t<M>>,
23+
CUSPARSE_INDEX_BASE_ZERO, detail::cuda_data_type_v<tensor_scalar_t<M>>));
24+
25+
return mat_descr;
26+
}
27+
28+
template <vector V>
29+
requires __ranges::contiguous_range<V>
30+
cusparseDnVecDescr_t create_cusparse_handle(V&& v) {
31+
cusparseDnVecDescr_t vec_descr;
32+
__cusparse::throw_if_error(
33+
cusparseCreateDnVec(&vec_descr, __backend::shape(v), __ranges::data(v),
34+
detail::cuda_data_type_v<tensor_scalar_t<V>>));
35+
36+
return vec_descr;
37+
}
38+
39+
} // namespace __cusparse
40+
41+
} // namespace spblas
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#pragma once
2+
3+
#include <cusparse.h>
4+
#include <spblas/detail/view_inspectors.hpp>
5+
6+
namespace spblas {
7+
namespace __cusparse {
8+
9+
//
10+
// Takes in a CSR or CSR_transpose (aka CSC) or CSC or CSC_transpose
11+
// and returns the cusparseOperation_t value associated with it being
12+
// represented in the CSR format
13+
//
14+
// CSR = CSR + NON_TRANSPOSE
15+
// CSR_transpose = CSR + TRANSPOSE
16+
// CSC = CSR + TRANSPOSE
17+
// CSC_transpose = CSR + NON_TRANSPOSE
18+
//
19+
template <matrix M>
20+
cusparseOperation_t get_transpose(M&& m) {
21+
static_assert(__detail::has_csr_base<M> || __detail::has_csc_base<M>);
22+
if constexpr (__detail::has_base<M>) {
23+
return get_transpose(m.base());
24+
} else if constexpr (__detail::is_csr_view_v<M>) {
25+
return CUSPARSE_OPERATION_NON_TRANSPOSE;
26+
} else if constexpr (__detail::is_csc_view_v<M>) {
27+
return CUSPARSE_OPERATION_TRANSPOSE;
28+
}
29+
}
30+
31+
} // namespace __cusparse
32+
} // namespace spblas

0 commit comments

Comments
 (0)