diff --git a/include/spblas/vendor/rocsparse/exception.hpp b/include/spblas/vendor/rocsparse/exception.hpp index 6a2ed5c..cb836b5 100644 --- a/include/spblas/vendor/rocsparse/exception.hpp +++ b/include/spblas/vendor/rocsparse/exception.hpp @@ -10,7 +10,7 @@ namespace spblas { namespace __rocsparse { // Throw an exception if the hipError_t is not hipSuccess. -void throw_if_error(hipError_t error_code, std::string prefix = "") { +inline void throw_if_error(hipError_t error_code, std::string prefix = "") { if (error_code == hipSuccess) { return; } @@ -21,7 +21,7 @@ void throw_if_error(hipError_t error_code, std::string prefix = "") { } // Throw an exception if the rocsparse_status is not rocsparse_status_success. -void throw_if_error(rocsparse_status error_code) { +inline void throw_if_error(rocsparse_status error_code) { if (error_code == rocsparse_status_success) { return; } else if (error_code == rocsparse_status_invalid_handle) { diff --git a/include/spblas/vendor/rocsparse/multiply_spgemm.hpp b/include/spblas/vendor/rocsparse/multiply_spgemm.hpp new file mode 100644 index 0000000..0b58f37 --- /dev/null +++ b/include/spblas/vendor/rocsparse/multiply_spgemm.hpp @@ -0,0 +1,319 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include + +#include +#include + +#include "exception.hpp" +#include "hip_allocator.hpp" +#include "types.hpp" + +namespace spblas { +namespace __rocsparse { + +template +T create_null_matrix() { + return {nullptr, nullptr, nullptr, index{0, 0}, 0}; +} + +} // namespace __rocsparse + +class spgemm_state_t { +public: + spgemm_state_t() : spgemm_state_t(rocsparse::hip_allocator{}) {} + + spgemm_state_t(rocsparse::hip_allocator alloc) + : alloc_(alloc), buffer_size_(0), workspace_(nullptr), result_nnz_(0), + result_shape_(0, 0) { + rocsparse_handle handle; + __rocsparse::throw_if_error(rocsparse_create_handle(&handle)); + if (auto stream = alloc.stream()) { + rocsparse_set_stream(handle, stream); + } + handle_ = 
handle_manager(handle, [](rocsparse_handle handle) { + __rocsparse::throw_if_error(rocsparse_destroy_handle(handle)); + }); + } + + spgemm_state_t(rocsparse::hip_allocator alloc, rocsparse_handle handle) + : alloc_(alloc), buffer_size_(0), workspace_(nullptr), result_nnz_(0), + result_shape_(0, 0) { + handle_ = handle_manager(handle, [](rocsparse_handle handle) { + // it is provided by user, we do not delete it at all. + }); + } + + ~spgemm_state_t() { + alloc_.deallocate(this->workspace_, this->buffer_size_); + __rocsparse::throw_if_error(rocsparse_destroy_spmat_descr(this->mat_a_)); + __rocsparse::throw_if_error(rocsparse_destroy_spmat_descr(this->mat_b_)); + __rocsparse::throw_if_error(rocsparse_destroy_spmat_descr(this->mat_c_)); + __rocsparse::throw_if_error(rocsparse_destroy_spmat_descr(this->mat_d_)); + } + + auto result_shape() { + return this->result_shape_; + } + + auto result_nnz() { + return this->result_nnz_; + } + + template + requires __detail::has_csr_base && __detail::has_csr_base && + __detail::is_csr_view_v && __detail::has_csr_base + void multiply_compute(A&& a, B&& b, C&& c, D&& d) { + auto a_base = __detail::get_ultimate_base(a); + auto b_base = __detail::get_ultimate_base(b); + auto d_base = __detail::get_ultimate_base(d); + using matrix_type = decltype(a_base); + using input_type = decltype(b_base); + using output_type = std::remove_reference_t; + using value_type = typename matrix_type::scalar_type; + + size_t buffer_size = 0; + + auto alpha_optional = __detail::get_scaling_factor(a, b); + value_type alpha = alpha_optional.value_or(1); + auto beta_optional = __detail::get_scaling_factor(d); + value_type beta = beta_optional.value_or(1); + auto handle = this->handle_.get(); + // Create sparse matrix in CSR format + this->mat_a_ = __rocsparse::create_rocsparse_handle(a_base); + this->mat_b_ = __rocsparse::create_rocsparse_handle(b_base); + this->mat_c_ = __rocsparse::create_rocsparse_handle(c); + this->mat_d_ = 
__rocsparse::create_rocsparse_handle(d_base); + // ask buffer_size bytes for external memory + __rocsparse::throw_if_error(rocsparse_spgemm( + handle, rocsparse_operation_none, rocsparse_operation_none, &alpha, + this->mat_a_, this->mat_b_, &beta, this->mat_d_, this->mat_c_, + detail::rocsparse_data_type_v, rocsparse_spgemm_alg_default, + rocsparse_spgemm_stage_buffer_size, &buffer_size, nullptr)); + // allocate the new buffer if it requires more than what the buffer + // currently has. + if (buffer_size > this->buffer_size_) { + this->alloc_.deallocate(workspace_, this->buffer_size_); + this->buffer_size_ = buffer_size; + workspace_ = this->alloc_.allocate(buffer_size); + } + __rocsparse::throw_if_error(rocsparse_spgemm( + handle, rocsparse_operation_none, rocsparse_operation_none, &alpha, + this->mat_a_, this->mat_b_, &beta, this->mat_d_, this->mat_c_, + detail::rocsparse_data_type_v, rocsparse_spgemm_alg_default, + rocsparse_spgemm_stage_nnz, &this->buffer_size_, this->workspace_)); + // get matrix C non-zero entries and size + int64_t c_num_rows; + int64_t c_num_cols; + __rocsparse::throw_if_error(rocsparse_spmat_get_size( + this->mat_c_, &c_num_rows, &c_num_cols, &this->result_nnz_)); + // form a shape + this->result_shape_ = index(c_num_rows, c_num_cols); + } + + template + requires __detail::has_csr_base && __detail::has_csr_base && + __detail::is_csr_view_v && __detail::has_csr_base + void multiply_fill(A&& a, B&& b, C&& c, D&& d) { + auto a_base = __detail::get_ultimate_base(a); + auto b_base = __detail::get_ultimate_base(b); + using matrix_type = decltype(a_base); + using input_type = decltype(b_base); + using output_type = std::remove_reference_t; + using value_type = typename matrix_type::scalar_type; + + auto alpha_optional = __detail::get_scaling_factor(a, b); + tensor_scalar_t alpha = alpha_optional.value_or(1); + value_type alpha_val = alpha; + auto beta_optional = __detail::get_scaling_factor(d); + value_type beta = beta_optional.value_or(1); + + 
__rocsparse::throw_if_error(rocsparse_csr_set_pointers( + this->mat_c_, c.rowptr().data(), c.colind().data(), c.values().data())); + + __rocsparse::throw_if_error(rocsparse_spgemm( + handle_.get(), rocsparse_operation_none, rocsparse_operation_none, + &alpha, this->mat_a_, this->mat_b_, &beta, this->mat_d_, this->mat_c_, + detail::rocsparse_data_type_v, rocsparse_spgemm_alg_default, + rocsparse_spgemm_stage_compute, &this->buffer_size_, workspace_)); + } + + template + requires __detail::has_csr_base && __detail::has_csr_base && + __detail::is_csr_view_v && __detail::has_csr_base + void multiply_symbolic_fill(A&& a, B&& b, C&& c, D&& d) { + auto a_base = __detail::get_ultimate_base(a); + auto b_base = __detail::get_ultimate_base(b); + auto d_base = __detail::get_ultimate_base(d); + using matrix_type = decltype(a_base); + using input_type = decltype(b_base); + using output_type = std::remove_reference_t; + using value_type = typename matrix_type::scalar_type; + + auto alpha_optional = __detail::get_scaling_factor(a, b); + value_type alpha = alpha_optional.value_or(1); + auto beta_optional = __detail::get_scaling_factor(d); + value_type beta = beta_optional.value_or(1); + + __rocsparse::throw_if_error(rocsparse_csr_set_pointers( + this->mat_c_, c.rowptr().data(), c.colind().data(), c.values().data())); + + __rocsparse::throw_if_error(rocsparse_spgemm( + this->handle_.get(), rocsparse_operation_none, rocsparse_operation_none, + &alpha, this->mat_a_, this->mat_b_, &beta, this->mat_d_, this->mat_c_, + detail::rocsparse_data_type_v, rocsparse_spgemm_alg_default, + rocsparse_spgemm_stage_symbolic, &this->buffer_size_, + this->workspace_)); + } + + template + requires __detail::has_csr_base && __detail::has_csr_base && + __detail::is_csr_view_v && __detail::has_csr_base + void multiply_numeric(A&& a, B&& b, C&& c, D&& d) { + auto a_base = __detail::get_ultimate_base(a); + auto b_base = __detail::get_ultimate_base(b); + auto d_base = __detail::get_ultimate_base(d); + using 
matrix_type = decltype(a_base); + using input_type = decltype(b_base); + using output_type = std::remove_reference_t; + using value_type = typename matrix_type::scalar_type; + + auto alpha_optional = __detail::get_scaling_factor(a, b); + tensor_scalar_t alpha = alpha_optional.value_or(1); + value_type alpha_val = alpha; + auto beta_optional = __detail::get_scaling_factor(d); + value_type beta = beta_optional.value_or(1); + + // Update the pointers from the matrices, but they must contain the same + // sparsity as in the previous call. + __rocsparse::throw_if_error(rocsparse_csr_set_pointers( + this->mat_a_, a_base.rowptr().data(), a_base.colind().data(), + a_base.values().data())); + __rocsparse::throw_if_error(rocsparse_csr_set_pointers( + this->mat_b_, b_base.rowptr().data(), b_base.colind().data(), + b_base.values().data())); + __rocsparse::throw_if_error(rocsparse_csr_set_pointers( + this->mat_c_, c.rowptr().data(), c.colind().data(), c.values().data())); + if (d_base.values().data()) { + // when it is still a null matrix, we cannot use the set-pointers function + __rocsparse::throw_if_error(rocsparse_csr_set_pointers( + this->mat_d_, d_base.rowptr().data(), d_base.colind().data(), + d_base.values().data())); + } + __rocsparse::throw_if_error(rocsparse_spgemm( + this->handle_.get(), rocsparse_operation_none, rocsparse_operation_none, + &alpha, this->mat_a_, this->mat_b_, &beta, this->mat_d_, this->mat_c_, + detail::rocsparse_data_type_v, rocsparse_spgemm_alg_default, + rocsparse_spgemm_stage_numeric, &this->buffer_size_, this->workspace_)); + } + +private: + using handle_manager = + std::unique_ptr::element_type, + std::function>; + handle_manager handle_; + rocsparse::hip_allocator alloc_; + std::uint64_t buffer_size_; + char* workspace_; + index result_shape_; + std::int64_t result_nnz_; + rocsparse_spmat_descr mat_a_; + rocsparse_spmat_descr mat_b_; + rocsparse_spmat_descr mat_c_; + rocsparse_spmat_descr mat_d_; +}; + +template + requires __detail::has_csr_base && 
__detail::has_csr_base && + __detail::is_csr_view_v +void multiply_inspect(spgemm_state_t& spgemm_handle, A&& a, B&& b, C&& c) {} + +template + requires __detail::has_csr_base && __detail::has_csr_base && + __detail::is_csr_view_v && __detail::has_csr_base +void multiply_compute(spgemm_state_t& spgemm_handle, A&& a, B&& b, C&& c, + D&& d) { + spgemm_handle.multiply_compute(a, b, c, d); +} + +template + requires __detail::has_csr_base && __detail::has_csr_base && + __detail::is_csr_view_v && __detail::has_csr_base +void multiply_fill(spgemm_state_t& spgemm_handle, A&& a, B&& b, C&& c, D&& d) { + spgemm_handle.multiply_fill(a, b, c, d); +} + +template + requires __detail::has_csr_base && __detail::has_csr_base && + __detail::is_csr_view_v && __detail::has_csr_base +void multiply_symbolic_compute(spgemm_state_t& spgemm_handle, A&& a, B&& b, + C&& c, D&& d) { + spgemm_handle.multiply_compute(a, b, c, d); +} + +template + requires __detail::has_csr_base && __detail::has_csr_base && + __detail::is_csr_view_v && __detail::has_csr_base +void multiply_symbolic_fill(spgemm_state_t& spgemm_handle, A&& a, B&& b, C&& c, + D&& d) { + spgemm_handle.multiply_symbolic_fill(a, b, c, d); +} + +template + requires __detail::has_csr_base && __detail::has_csr_base && + __detail::is_csr_view_v && __detail::has_csr_base +void multiply_numeric(spgemm_state_t& spgemm_handle, A&& a, B&& b, C&& c, + D&& d) { + spgemm_handle.multiply_numeric(a, b, c, d); +} + +// The following overloads support C = A*B by passing a null D matrix. 
+template + requires __detail::has_csr_base && __detail::has_csr_base && + __detail::is_csr_view_v +void multiply_compute(spgemm_state_t& spgemm_handle, A&& a, B&& b, C&& c) { + auto d = __rocsparse::create_null_matrix>(); + spgemm_handle.multiply_compute(a, b, c, scaled(0.0, d)); +} + +template + requires __detail::has_csr_base && __detail::has_csr_base && + __detail::is_csr_view_v +void multiply_fill(spgemm_state_t& spgemm_handle, A&& a, B&& b, C&& c) { + auto d = __rocsparse::create_null_matrix>(); + spgemm_handle.multiply_fill(a, b, c, scaled(0.0, d)); +} + +template + requires __detail::has_csr_base && __detail::has_csr_base && + __detail::is_csr_view_v +void multiply_symbolic_compute(spgemm_state_t& spgemm_handle, A&& a, B&& b, + C&& c) { + auto d = __rocsparse::create_null_matrix>(); + spgemm_handle.multiply_compute(a, b, c, scaled(0.0, d)); +} + +template + requires __detail::has_csr_base && __detail::has_csr_base && + __detail::is_csr_view_v +void multiply_symbolic_fill(spgemm_state_t& spgemm_handle, A&& a, B&& b, + C&& c) { + auto d = __rocsparse::create_null_matrix>(); + spgemm_handle.multiply_symbolic_fill(a, b, c, scaled(0.0, d)); +} + +template + requires __detail::has_csr_base && __detail::has_csr_base && + __detail::is_csr_view_v +void multiply_numeric(spgemm_state_t& spgemm_handle, A&& a, B&& b, C&& c) { + auto d = __rocsparse::create_null_matrix>(); + spgemm_handle.multiply_numeric(a, b, c, scaled(0.0, d)); +} + +} // namespace spblas diff --git a/include/spblas/vendor/rocsparse/rocsparse.hpp b/include/spblas/vendor/rocsparse/rocsparse.hpp index 7caa698..014b2ba 100644 --- a/include/spblas/vendor/rocsparse/rocsparse.hpp +++ b/include/spblas/vendor/rocsparse/rocsparse.hpp @@ -1,3 +1,4 @@ #pragma once #include "multiply.hpp" +#include "multiply_spgemm.hpp" diff --git a/test/gtest/CMakeLists.txt b/test/gtest/CMakeLists.txt index 05e14d2..cd96de1 100644 --- a/test/gtest/CMakeLists.txt +++ b/test/gtest/CMakeLists.txt @@ -17,9 +17,12 @@ endif() # GPU 
tests if (SPBLAS_GPU_BACKEND) if (ENABLE_ROCSPARSE) - set_source_files_properties(device/spmv_test.cpp PROPERTIES LANGUAGE HIP) - endif() - list(APPEND TEST_SOURCES device/spmv_test.cpp) + set(GPUTEST_SOURCES device/spmv_test.cpp device/spgemm_test.cpp device/spgemm_reuse_test.cpp device/rocsparse/spgemm_4args_test.cpp) + set_source_files_properties(${GPUTEST_SOURCES} PROPERTIES LANGUAGE HIP) + else () + set(GPUTEST_SOURCES device/spmv_test.cpp) + endif () + list(APPEND TEST_SOURCES ${GPUTEST_SOURCES}) endif() add_executable(spblas-tests ${TEST_SOURCES}) diff --git a/test/gtest/device/rocsparse/spgemm_4args_test.cpp b/test/gtest/device/rocsparse/spgemm_4args_test.cpp new file mode 100644 index 0000000..1a8e75a --- /dev/null +++ b/test/gtest/device/rocsparse/spgemm_4args_test.cpp @@ -0,0 +1,421 @@ + +#include "../../util.hpp" +#include + +#include +#include + +using value_t = float; +using index_t = spblas::index_t; +using offset_t = spblas::offset_t; + +TEST(thrust_CsrView, SpGEMM_4Args) { + for (auto&& [m, k, nnz] : util::dims) { + for (auto&& n : {m, k}) { + auto [a_values, a_rowptr, a_colind, a_shape, a_nnz] = + spblas::generate_csr(m, k, nnz); + thrust::device_vector d_a_values(a_values); + thrust::device_vector d_a_rowptr(a_rowptr); + thrust::device_vector d_a_colind(a_colind); + spblas::csr_view d_a( + d_a_values.data().get(), d_a_rowptr.data().get(), + d_a_colind.data().get(), a_shape, a_nnz); + spblas::csr_view a(a_values, a_rowptr, + a_colind, a_shape, a_nnz); + + auto [b_values, b_rowptr, b_colind, b_shape, b_nnz] = + spblas::generate_csr(k, n, nnz); + thrust::device_vector d_b_values(b_values); + thrust::device_vector d_b_rowptr(b_rowptr); + thrust::device_vector d_b_colind(b_colind); + spblas::csr_view d_b( + d_b_values.data().get(), d_b_rowptr.data().get(), + d_b_colind.data().get(), b_shape, b_nnz); + spblas::csr_view b(b_values, b_rowptr, + b_colind, b_shape, b_nnz); + + auto [d_values, d_rowptr, d_colind, d_shape, d_nnz] = + spblas::generate_csr(m, 
n, nnz); + thrust::device_vector d_d_values(d_values); + thrust::device_vector d_d_rowptr(d_rowptr); + thrust::device_vector d_d_colind(d_colind); + spblas::csr_view d_d( + d_d_values.data().get(), d_d_rowptr.data().get(), + d_d_colind.data().get(), d_shape, d_nnz); + spblas::csr_view d(d_values, d_rowptr, + d_colind, d_shape, d_nnz); + + thrust::device_vector d_c_rowptr(m + 1); + + spblas::csr_view d_c( + nullptr, d_c_rowptr.data().get(), nullptr, {m, n}, 0); + + spblas::spgemm_state_t state; + spblas::multiply_compute(state, d_a, d_b, d_c, d_d); + auto nnz = state.result_nnz(); + thrust::device_vector d_c_values(nnz); + thrust::device_vector d_c_colind(nnz); + std::span d_c_values_span(d_c_values.data().get(), nnz); + std::span d_c_rowptr_span(d_c_rowptr.data().get(), m + 1); + std::span d_c_colind_span(d_c_colind.data().get(), nnz); + d_c.update(d_c_values_span, d_c_rowptr_span, d_c_colind_span, {m, n}, + nnz); + + spblas::multiply_fill(state, d_a, d_b, d_c, d_d); + + std::vector c_values(nnz); + std::vector c_rowptr(m + 1); + std::vector c_colind(nnz); + thrust::copy(d_c_values.begin(), d_c_values.end(), c_values.begin()); + thrust::copy(d_c_rowptr.begin(), d_c_rowptr.end(), c_rowptr.begin()); + thrust::copy(d_c_colind.begin(), d_c_colind.end(), c_colind.begin()); + spblas::csr_view c(c_values, c_rowptr, + c_colind, {m, n}, nnz); + + spblas::__backend::spa_accumulator c_row_ref( + spblas::__backend::shape(c)[1]); + + spblas::__backend::spa_accumulator c_row_acc( + spblas::__backend::shape(c)[1]); + + for (auto&& [i, a_row] : spblas::__backend::rows(a)) { + c_row_ref.clear(); + for (auto&& [k, a_v] : a_row) { + auto&& b_row = spblas::__backend::lookup_row(b, k); + + for (auto&& [j, b_v] : b_row) { + c_row_ref[j] += a_v * b_v; + } + } + auto&& d_row = spblas::__backend::lookup_row(d, i); + for (auto&& [k, d_v] : d_row) { + c_row_ref[k] += d_v; + } + + auto&& c_row = spblas::__backend::lookup_row(c, i); + + // Accumulate output into `c_row_acc` so that we can 
allow + // duplicate column indices. + c_row_acc.clear(); + for (auto&& [j, c_v] : c_row) { + c_row_acc[j] += c_v; + } + + for (auto&& [j, c_v] : c_row) { + EXPECT_EQ_(c_row_ref[j], c_row_acc[j]); + } + + EXPECT_EQ(c_row_ref.size(), c_row_acc.size()); + } + } + } +} + +TEST(thrust_CsrView, SpGEMM_4Args_AScaled) { + value_t alpha = 2.0f; + for (auto&& [m, k, nnz] : util::dims) { + for (auto&& n : {m, k}) { + auto [a_values, a_rowptr, a_colind, a_shape, a_nnz] = + spblas::generate_csr(m, k, nnz); + thrust::device_vector d_a_values(a_values); + thrust::device_vector d_a_rowptr(a_rowptr); + thrust::device_vector d_a_colind(a_colind); + spblas::csr_view d_a( + d_a_values.data().get(), d_a_rowptr.data().get(), + d_a_colind.data().get(), a_shape, a_nnz); + spblas::csr_view a(a_values, a_rowptr, + a_colind, a_shape, a_nnz); + + auto [b_values, b_rowptr, b_colind, b_shape, b_nnz] = + spblas::generate_csr(k, n, nnz); + thrust::device_vector d_b_values(b_values); + thrust::device_vector d_b_rowptr(b_rowptr); + thrust::device_vector d_b_colind(b_colind); + spblas::csr_view d_b( + d_b_values.data().get(), d_b_rowptr.data().get(), + d_b_colind.data().get(), b_shape, b_nnz); + spblas::csr_view b(b_values, b_rowptr, + b_colind, b_shape, b_nnz); + + auto [d_values, d_rowptr, d_colind, d_shape, d_nnz] = + spblas::generate_csr(m, n, nnz); + thrust::device_vector d_d_values(d_values); + thrust::device_vector d_d_rowptr(d_rowptr); + thrust::device_vector d_d_colind(d_colind); + spblas::csr_view d_d( + d_d_values.data().get(), d_d_rowptr.data().get(), + d_d_colind.data().get(), d_shape, d_nnz); + spblas::csr_view d(d_values, d_rowptr, + d_colind, d_shape, d_nnz); + + thrust::device_vector d_c_rowptr(m + 1); + + spblas::csr_view d_c( + nullptr, d_c_rowptr.data().get(), nullptr, {m, n}, 0); + + spblas::spgemm_state_t state; + spblas::multiply_compute(state, scaled(alpha, d_a), d_b, d_c, d_d); + auto nnz = state.result_nnz(); + thrust::device_vector d_c_values(nnz); + thrust::device_vector 
d_c_colind(nnz); + std::span d_c_values_span(d_c_values.data().get(), nnz); + std::span d_c_rowptr_span(d_c_rowptr.data().get(), m + 1); + std::span d_c_colind_span(d_c_colind.data().get(), nnz); + d_c.update(d_c_values_span, d_c_rowptr_span, d_c_colind_span, {m, n}, + nnz); + + spblas::multiply_fill(state, scaled(alpha, d_a), d_b, d_c, d_d); + + std::vector c_values(nnz); + std::vector c_rowptr(m + 1); + std::vector c_colind(nnz); + thrust::copy(d_c_values.begin(), d_c_values.end(), c_values.begin()); + thrust::copy(d_c_rowptr.begin(), d_c_rowptr.end(), c_rowptr.begin()); + thrust::copy(d_c_colind.begin(), d_c_colind.end(), c_colind.begin()); + spblas::csr_view c(c_values, c_rowptr, + c_colind, {m, n}, nnz); + + spblas::__backend::spa_accumulator c_row_ref( + spblas::__backend::shape(c)[1]); + + spblas::__backend::spa_accumulator c_row_acc( + spblas::__backend::shape(c)[1]); + + for (auto&& [i, a_row] : spblas::__backend::rows(a)) { + c_row_ref.clear(); + for (auto&& [k, a_v] : a_row) { + auto&& b_row = spblas::__backend::lookup_row(b, k); + + for (auto&& [j, b_v] : b_row) { + c_row_ref[j] += alpha * a_v * b_v; + } + } + auto&& d_row = spblas::__backend::lookup_row(d, i); + for (auto&& [k, d_v] : d_row) { + c_row_ref[k] += d_v; + } + + auto&& c_row = spblas::__backend::lookup_row(c, i); + + // Accumulate output into `c_row_acc` so that we can allow + // duplicate column indices. 
+ c_row_acc.clear(); + for (auto&& [j, c_v] : c_row) { + c_row_acc[j] += c_v; + } + + for (auto&& [j, c_v] : c_row) { + EXPECT_EQ_(c_row_ref[j], c_row_acc[j]); + } + + EXPECT_EQ(c_row_ref.size(), c_row_acc.size()); + } + } + } +} + +TEST(thrust_CsrView, SpGEMM_4Args_BScaled) { + value_t alpha = 2.0f; + for (auto&& [m, k, nnz] : util::dims) { + for (auto&& n : {m, k}) { + auto [a_values, a_rowptr, a_colind, a_shape, a_nnz] = + spblas::generate_csr(m, k, nnz); + thrust::device_vector d_a_values(a_values); + thrust::device_vector d_a_rowptr(a_rowptr); + thrust::device_vector d_a_colind(a_colind); + spblas::csr_view d_a( + d_a_values.data().get(), d_a_rowptr.data().get(), + d_a_colind.data().get(), a_shape, a_nnz); + spblas::csr_view a(a_values, a_rowptr, + a_colind, a_shape, a_nnz); + + auto [b_values, b_rowptr, b_colind, b_shape, b_nnz] = + spblas::generate_csr(k, n, nnz); + thrust::device_vector d_b_values(b_values); + thrust::device_vector d_b_rowptr(b_rowptr); + thrust::device_vector d_b_colind(b_colind); + spblas::csr_view d_b( + d_b_values.data().get(), d_b_rowptr.data().get(), + d_b_colind.data().get(), b_shape, b_nnz); + spblas::csr_view b(b_values, b_rowptr, + b_colind, b_shape, b_nnz); + + auto [d_values, d_rowptr, d_colind, d_shape, d_nnz] = + spblas::generate_csr(m, n, nnz); + thrust::device_vector d_d_values(d_values); + thrust::device_vector d_d_rowptr(d_rowptr); + thrust::device_vector d_d_colind(d_colind); + spblas::csr_view d_d( + d_d_values.data().get(), d_d_rowptr.data().get(), + d_d_colind.data().get(), d_shape, d_nnz); + spblas::csr_view d(d_values, d_rowptr, + d_colind, d_shape, d_nnz); + + thrust::device_vector d_c_rowptr(m + 1); + + spblas::csr_view d_c( + nullptr, d_c_rowptr.data().get(), nullptr, {m, n}, 0); + + spblas::spgemm_state_t state; + spblas::multiply_compute(state, d_a, scaled(alpha, d_b), d_c, d_d); + auto nnz = state.result_nnz(); + thrust::device_vector d_c_values(nnz); + thrust::device_vector d_c_colind(nnz); + std::span 
d_c_values_span(d_c_values.data().get(), nnz); + std::span d_c_rowptr_span(d_c_rowptr.data().get(), m + 1); + std::span d_c_colind_span(d_c_colind.data().get(), nnz); + d_c.update(d_c_values_span, d_c_rowptr_span, d_c_colind_span, {m, n}, + nnz); + + spblas::multiply_fill(state, d_a, scaled(alpha, d_b), d_c, d_d); + + std::vector c_values(nnz); + std::vector c_rowptr(m + 1); + std::vector c_colind(nnz); + thrust::copy(d_c_values.begin(), d_c_values.end(), c_values.begin()); + thrust::copy(d_c_rowptr.begin(), d_c_rowptr.end(), c_rowptr.begin()); + thrust::copy(d_c_colind.begin(), d_c_colind.end(), c_colind.begin()); + spblas::csr_view c(c_values, c_rowptr, + c_colind, {m, n}, nnz); + + spblas::__backend::spa_accumulator c_row_ref( + spblas::__backend::shape(c)[1]); + + spblas::__backend::spa_accumulator c_row_acc( + spblas::__backend::shape(c)[1]); + + for (auto&& [i, a_row] : spblas::__backend::rows(a)) { + c_row_ref.clear(); + for (auto&& [k, a_v] : a_row) { + auto&& b_row = spblas::__backend::lookup_row(b, k); + + for (auto&& [j, b_v] : b_row) { + c_row_ref[j] += alpha * a_v * b_v; + } + } + auto&& d_row = spblas::__backend::lookup_row(d, i); + for (auto&& [k, d_v] : d_row) { + c_row_ref[k] += d_v; + } + + auto&& c_row = spblas::__backend::lookup_row(c, i); + + // Accumulate output into `c_row_acc` so that we can allow + // duplicate column indices. 
+ c_row_acc.clear(); + for (auto&& [j, c_v] : c_row) { + c_row_acc[j] += c_v; + } + + for (auto&& [j, c_v] : c_row) { + EXPECT_EQ_(c_row_ref[j], c_row_acc[j]); + } + + EXPECT_EQ(c_row_ref.size(), c_row_acc.size()); + } + } + } +} + +TEST(thrust_CsrView, SpGEMM_4Args_DScaled) { + value_t alpha = 2.0f; + for (auto&& [m, k, nnz] : util::dims) { + for (auto&& n : {m, k}) { + auto [a_values, a_rowptr, a_colind, a_shape, a_nnz] = + spblas::generate_csr(m, k, nnz); + thrust::device_vector d_a_values(a_values); + thrust::device_vector d_a_rowptr(a_rowptr); + thrust::device_vector d_a_colind(a_colind); + spblas::csr_view d_a( + d_a_values.data().get(), d_a_rowptr.data().get(), + d_a_colind.data().get(), a_shape, a_nnz); + spblas::csr_view a(a_values, a_rowptr, + a_colind, a_shape, a_nnz); + + auto [b_values, b_rowptr, b_colind, b_shape, b_nnz] = + spblas::generate_csr(k, n, nnz); + thrust::device_vector d_b_values(b_values); + thrust::device_vector d_b_rowptr(b_rowptr); + thrust::device_vector d_b_colind(b_colind); + spblas::csr_view d_b( + d_b_values.data().get(), d_b_rowptr.data().get(), + d_b_colind.data().get(), b_shape, b_nnz); + spblas::csr_view b(b_values, b_rowptr, + b_colind, b_shape, b_nnz); + + auto [d_values, d_rowptr, d_colind, d_shape, d_nnz] = + spblas::generate_csr(m, n, nnz); + thrust::device_vector d_d_values(d_values); + thrust::device_vector d_d_rowptr(d_rowptr); + thrust::device_vector d_d_colind(d_colind); + spblas::csr_view d_d( + d_d_values.data().get(), d_d_rowptr.data().get(), + d_d_colind.data().get(), d_shape, d_nnz); + spblas::csr_view d(d_values, d_rowptr, + d_colind, d_shape, d_nnz); + + thrust::device_vector d_c_rowptr(m + 1); + + spblas::csr_view d_c( + nullptr, d_c_rowptr.data().get(), nullptr, {m, n}, 0); + + spblas::spgemm_state_t state; + spblas::multiply_compute(state, d_a, d_b, d_c, scaled(alpha, d_d)); + auto nnz = state.result_nnz(); + thrust::device_vector d_c_values(nnz); + thrust::device_vector d_c_colind(nnz); + std::span 
d_c_values_span(d_c_values.data().get(), nnz); + std::span d_c_rowptr_span(d_c_rowptr.data().get(), m + 1); + std::span d_c_colind_span(d_c_colind.data().get(), nnz); + d_c.update(d_c_values_span, d_c_rowptr_span, d_c_colind_span, {m, n}, + nnz); + + spblas::multiply_fill(state, d_a, d_b, d_c, scaled(alpha, d_d)); + + std::vector c_values(nnz); + std::vector c_rowptr(m + 1); + std::vector c_colind(nnz); + thrust::copy(d_c_values.begin(), d_c_values.end(), c_values.begin()); + thrust::copy(d_c_rowptr.begin(), d_c_rowptr.end(), c_rowptr.begin()); + thrust::copy(d_c_colind.begin(), d_c_colind.end(), c_colind.begin()); + spblas::csr_view c(c_values, c_rowptr, + c_colind, {m, n}, nnz); + + spblas::__backend::spa_accumulator c_row_ref( + spblas::__backend::shape(c)[1]); + + spblas::__backend::spa_accumulator c_row_acc( + spblas::__backend::shape(c)[1]); + + for (auto&& [i, a_row] : spblas::__backend::rows(a)) { + c_row_ref.clear(); + for (auto&& [k, a_v] : a_row) { + auto&& b_row = spblas::__backend::lookup_row(b, k); + + for (auto&& [j, b_v] : b_row) { + c_row_ref[j] += a_v * b_v; + } + } + auto&& d_row = spblas::__backend::lookup_row(d, i); + for (auto&& [k, d_v] : d_row) { + c_row_ref[k] += alpha * d_v; + } + + auto&& c_row = spblas::__backend::lookup_row(c, i); + + // Accumulate output into `c_row_acc` so that we can allow + // duplicate column indices. 
+ c_row_acc.clear(); + for (auto&& [j, c_v] : c_row) { + c_row_acc[j] += c_v; + } + + for (auto&& [j, c_v] : c_row) { + EXPECT_EQ_(c_row_ref[j], c_row_acc[j]); + } + + EXPECT_EQ(c_row_ref.size(), c_row_acc.size()); + } + } + } +} diff --git a/test/gtest/device/spgemm_reuse_test.cpp b/test/gtest/device/spgemm_reuse_test.cpp new file mode 100644 index 0000000..b798d6e --- /dev/null +++ b/test/gtest/device/spgemm_reuse_test.cpp @@ -0,0 +1,448 @@ + +#include "../util.hpp" +#include + +#include +#include + +using value_t = float; +using index_t = spblas::index_t; +using offset_t = spblas::offset_t; + +TEST(thrust_CsrView, SpGEMMReuse) { + for (auto&& [m, k, nnz] : util::dims) { + for (auto&& n : {m, k}) { + auto [a_values, a_rowptr, a_colind, a_shape, a_nnz] = + spblas::generate_csr(m, k, nnz); + thrust::device_vector d_a_values(a_values); + thrust::device_vector d_a_rowptr(a_rowptr); + thrust::device_vector d_a_colind(a_colind); + spblas::csr_view d_a( + d_a_values.data().get(), d_a_rowptr.data().get(), + d_a_colind.data().get(), a_shape, a_nnz); + spblas::csr_view a(a_values, a_rowptr, + a_colind, a_shape, a_nnz); + + auto [b_values, b_rowptr, b_colind, b_shape, b_nnz] = + spblas::generate_csr(k, n, nnz); + thrust::device_vector d_b_values(b_values); + thrust::device_vector d_b_rowptr(b_rowptr); + thrust::device_vector d_b_colind(b_colind); + spblas::csr_view d_b( + d_b_values.data().get(), d_b_rowptr.data().get(), + d_b_colind.data().get(), b_shape, b_nnz); + spblas::csr_view b(b_values, b_rowptr, + b_colind, b_shape, b_nnz); + + thrust::device_vector d_c_rowptr(m + 1); + + spblas::csr_view d_c( + nullptr, d_c_rowptr.data().get(), nullptr, {m, n}, 0); + + spblas::spgemm_state_t state; + spblas::multiply_symbolic_compute(state, d_a, d_b, d_c); + auto nnz = state.result_nnz(); + thrust::device_vector d_c_values(nnz); + thrust::device_vector d_c_colind(nnz); + std::span d_c_values_span(d_c_values.data().get(), nnz); + std::span d_c_rowptr_span(d_c_rowptr.data().get(), m 
+ 1); + std::span d_c_colind_span(d_c_colind.data().get(), nnz); + d_c.update(d_c_values_span, d_c_rowptr_span, d_c_colind_span, {m, n}, + nnz); + + spblas::multiply_symbolic_fill(state, d_a, d_b, d_c); + std::mt19937 g(0); + for (int i = 0; i < 3; i++) { + // we can change the value of a and b but only need to call + // multiply_numeric answer here. + if (i != 0) { + // regenerate value of a and b; + std::uniform_real_distribution val_dist(0.0, 100.0); + for (auto& v : a_values) { + v = val_dist(g); + } + for (auto& v : b_values) { + v = val_dist(g); + } + thrust::copy(a_values.begin(), a_values.end(), d_a_values.begin()); + thrust::copy(b_values.begin(), b_values.end(), d_b_values.begin()); + } + spblas::multiply_numeric(state, d_a, d_b, d_c); + std::vector c_values(nnz); + std::vector c_rowptr(m + 1); + std::vector c_colind(nnz); + thrust::copy(d_c_values.begin(), d_c_values.end(), c_values.begin()); + thrust::copy(d_c_rowptr.begin(), d_c_rowptr.end(), c_rowptr.begin()); + thrust::copy(d_c_colind.begin(), d_c_colind.end(), c_colind.begin()); + spblas::csr_view c(c_values, c_rowptr, + c_colind, {m, n}, nnz); + + spblas::__backend::spa_accumulator c_row_ref( + spblas::__backend::shape(c)[1]); + + spblas::__backend::spa_accumulator c_row_acc( + spblas::__backend::shape(c)[1]); + + for (auto&& [i, a_row] : spblas::__backend::rows(a)) { + c_row_ref.clear(); + for (auto&& [k, a_v] : a_row) { + auto&& b_row = spblas::__backend::lookup_row(b, k); + + for (auto&& [j, b_v] : b_row) { + c_row_ref[j] += a_v * b_v; + } + } + + auto&& c_row = spblas::__backend::lookup_row(c, i); + + // Accumulate output into `c_row_acc` so that we can allow + // duplicate column indices. 
+          c_row_acc.clear();
+          for (auto&& [j, c_v] : c_row) {
+            c_row_acc[j] += c_v;
+          }
+
+          for (auto&& [j, c_v] : c_row) {
+            EXPECT_EQ_(c_row_ref[j], c_row_acc[j]);
+          }
+
+          EXPECT_EQ(c_row_ref.size(), c_row_acc.size());
+        }
+      }
+    }
+  }
+}
+
+TEST(thrust_CsrView, SpGEMMReuse_AScaled) {  // C = (alpha * A) * B: symbolic phase once, numeric phase three times
+  value_t alpha = 2.0f;
+  for (auto&& [m, k, nnz] : util::dims) {
+    for (auto&& n : {m, k}) {
+      auto [a_values, a_rowptr, a_colind, a_shape, a_nnz] =
+          spblas::generate_csr(m, k, nnz);  // NOTE(review): no template arguments are visible on this call or on the containers below; possibly lost in extraction -- confirm
+      thrust::device_vector d_a_values(a_values);
+      thrust::device_vector d_a_rowptr(a_rowptr);
+      thrust::device_vector d_a_colind(a_colind);
+      spblas::csr_view d_a(
+          d_a_values.data().get(), d_a_rowptr.data().get(),
+          d_a_colind.data().get(), a_shape, a_nnz);
+      spblas::csr_view a(a_values, a_rowptr,  // view over the generated arrays; read by the reference loop below
+                         a_colind, a_shape, a_nnz);
+
+      auto [b_values, b_rowptr, b_colind, b_shape, b_nnz] =
+          spblas::generate_csr(k, n, nnz);
+      thrust::device_vector d_b_values(b_values);
+      thrust::device_vector d_b_rowptr(b_rowptr);
+      thrust::device_vector d_b_colind(b_colind);
+      spblas::csr_view d_b(
+          d_b_values.data().get(), d_b_rowptr.data().get(),
+          d_b_colind.data().get(), b_shape, b_nnz);
+      spblas::csr_view b(b_values, b_rowptr,
+                         b_colind, b_shape, b_nnz);
+
+      thrust::device_vector d_c_rowptr(m + 1);  // only the row pointers of C are allocated up front
+
+      spblas::csr_view d_c(
+          nullptr, d_c_rowptr.data().get(), nullptr, {m, n}, 0);  // no value/column storage yet, nnz = 0
+
+      spblas::spgemm_state_t state;
+      spblas::multiply_symbolic_compute(state, scaled(alpha, d_a), d_b, d_c);
+      auto nnz = state.result_nnz();  // shadows the loop's nnz with the count reported by the state
+      thrust::device_vector d_c_values(nnz);
+      thrust::device_vector d_c_colind(nnz);
+      std::span d_c_values_span(d_c_values.data().get(), nnz);
+      std::span d_c_rowptr_span(d_c_rowptr.data().get(), m + 1);
+      std::span d_c_colind_span(d_c_colind.data().get(), nnz);
+      d_c.update(d_c_values_span, d_c_rowptr_span, d_c_colind_span, {m, n},
+                 nnz);  // rebind d_c to the newly allocated buffers
+
+      spblas::multiply_symbolic_fill(state, scaled(alpha, d_a), d_b, d_c);
+      std::mt19937 g(0);  // fixed seed
+      for (int i = 0; i < 3; i++) {
+        // The values of a and b may change between iterations; only
+        // multiply_numeric needs to be called again here.
+        if (i != 0) {
+          // Regenerate the values of a and b.
+          std::uniform_real_distribution val_dist(0.0, 100.0);
+          for (auto& v : a_values) {
+            v = val_dist(g);
+          }
+          for (auto& v : b_values) {
+            v = val_dist(g);
+          }
+          thrust::copy(a_values.begin(), a_values.end(), d_a_values.begin());
+          thrust::copy(b_values.begin(), b_values.end(), d_b_values.begin());
+        }
+        spblas::multiply_numeric(state, scaled(alpha, d_a), d_b, d_c);
+        std::vector c_values(nnz);
+        std::vector c_rowptr(m + 1);
+        std::vector c_colind(nnz);
+        thrust::copy(d_c_values.begin(), d_c_values.end(), c_values.begin());
+        thrust::copy(d_c_rowptr.begin(), d_c_rowptr.end(), c_rowptr.begin());
+        thrust::copy(d_c_colind.begin(), d_c_colind.end(), c_colind.begin());
+        spblas::csr_view c(c_values, c_rowptr,
+                           c_colind, {m, n}, nnz);
+
+        spblas::__backend::spa_accumulator c_row_ref(
+            spblas::__backend::shape(c)[1]);
+
+        spblas::__backend::spa_accumulator c_row_acc(
+            spblas::__backend::shape(c)[1]);
+
+        for (auto&& [i, a_row] : spblas::__backend::rows(a)) {  // reference result, row by row (this i shadows the loop counter)
+          c_row_ref.clear();
+          for (auto&& [k, a_v] : a_row) {
+            auto&& b_row = spblas::__backend::lookup_row(b, k);
+
+            for (auto&& [j, b_v] : b_row) {
+              c_row_ref[j] += alpha * a_v * b_v;
+            }
+          }
+
+          auto&& c_row = spblas::__backend::lookup_row(c, i);
+
+          // Accumulate output into `c_row_acc` so that we can allow
+          // duplicate column indices.
+          c_row_acc.clear();
+          for (auto&& [j, c_v] : c_row) {
+            c_row_acc[j] += c_v;
+          }
+
+          for (auto&& [j, c_v] : c_row) {
+            EXPECT_EQ_(c_row_ref[j], c_row_acc[j]);  // NOTE(review): EXPECT_EQ_ is not defined in this chunk; presumably a project macro -- confirm
+          }
+
+          EXPECT_EQ(c_row_ref.size(), c_row_acc.size());
+        }
+      }
+    }
+  }
+}
+
+TEST(thrust_CsrView, SpGEMMReuse_BScaled) {  // C = A * (alpha * B): symbolic phase once, numeric phase three times
+  value_t alpha = 2.0f;
+  for (auto&& [m, k, nnz] : util::dims) {
+    for (auto&& n : {m, k}) {
+      auto [a_values, a_rowptr, a_colind, a_shape, a_nnz] =
+          spblas::generate_csr(m, k, nnz);
+      thrust::device_vector d_a_values(a_values);
+      thrust::device_vector d_a_rowptr(a_rowptr);
+      thrust::device_vector d_a_colind(a_colind);
+      spblas::csr_view d_a(
+          d_a_values.data().get(), d_a_rowptr.data().get(),
+          d_a_colind.data().get(), a_shape, a_nnz);
+      spblas::csr_view a(a_values, a_rowptr,  // view over the generated arrays; read by the reference loop below
+                         a_colind, a_shape, a_nnz);
+
+      auto [b_values, b_rowptr, b_colind, b_shape, b_nnz] =
+          spblas::generate_csr(k, n, nnz);
+      thrust::device_vector d_b_values(b_values);
+      thrust::device_vector d_b_rowptr(b_rowptr);
+      thrust::device_vector d_b_colind(b_colind);
+      spblas::csr_view d_b(
+          d_b_values.data().get(), d_b_rowptr.data().get(),
+          d_b_colind.data().get(), b_shape, b_nnz);
+      spblas::csr_view b(b_values, b_rowptr,
+                         b_colind, b_shape, b_nnz);
+
+      thrust::device_vector d_c_rowptr(m + 1);  // only the row pointers of C are allocated up front
+
+      spblas::csr_view d_c(
+          nullptr, d_c_rowptr.data().get(), nullptr, {m, n}, 0);  // no value/column storage yet, nnz = 0
+
+      spblas::spgemm_state_t state;
+      spblas::multiply_symbolic_compute(state, d_a, scaled(alpha, d_b), d_c);
+      auto nnz = state.result_nnz();  // shadows the loop's nnz with the count reported by the state
+      thrust::device_vector d_c_values(nnz);
+      thrust::device_vector d_c_colind(nnz);
+      std::span d_c_values_span(d_c_values.data().get(), nnz);
+      std::span d_c_rowptr_span(d_c_rowptr.data().get(), m + 1);
+      std::span d_c_colind_span(d_c_colind.data().get(), nnz);
+      d_c.update(d_c_values_span, d_c_rowptr_span, d_c_colind_span, {m, n},
+                 nnz);  // rebind d_c to the newly allocated buffers
+
+      spblas::multiply_symbolic_fill(state, d_a, scaled(alpha, d_b), d_c);
+      std::mt19937 g(0);  // fixed seed
+      for (int i = 0; i < 3; i++) {
+        // The values of a and b may change between iterations; only
+        // multiply_numeric needs to be called again here.
+        if (i != 0) {
+          // Regenerate the values of a and b.
+          std::uniform_real_distribution val_dist(0.0, 100.0);
+          for (auto& v : a_values) {
+            v = val_dist(g);
+          }
+          for (auto& v : b_values) {
+            v = val_dist(g);
+          }
+          thrust::copy(a_values.begin(), a_values.end(), d_a_values.begin());
+          thrust::copy(b_values.begin(), b_values.end(), d_b_values.begin());
+        }
+        spblas::multiply_numeric(state, d_a, scaled(alpha, d_b), d_c);
+        std::vector c_values(nnz);
+        std::vector c_rowptr(m + 1);
+        std::vector c_colind(nnz);
+        thrust::copy(d_c_values.begin(), d_c_values.end(), c_values.begin());
+        thrust::copy(d_c_rowptr.begin(), d_c_rowptr.end(), c_rowptr.begin());
+        thrust::copy(d_c_colind.begin(), d_c_colind.end(), c_colind.begin());
+        spblas::csr_view c(c_values, c_rowptr,
+                           c_colind, {m, n}, nnz);
+
+        spblas::__backend::spa_accumulator c_row_ref(
+            spblas::__backend::shape(c)[1]);
+
+        spblas::__backend::spa_accumulator c_row_acc(
+            spblas::__backend::shape(c)[1]);
+
+        for (auto&& [i, a_row] : spblas::__backend::rows(a)) {  // reference result, row by row (this i shadows the loop counter)
+          c_row_ref.clear();
+          for (auto&& [k, a_v] : a_row) {
+            auto&& b_row = spblas::__backend::lookup_row(b, k);
+
+            for (auto&& [j, b_v] : b_row) {
+              c_row_ref[j] += alpha * a_v * b_v;
+            }
+          }
+
+          auto&& c_row = spblas::__backend::lookup_row(c, i);
+
+          // Accumulate output into `c_row_acc` so that we can allow
+          // duplicate column indices.
+          c_row_acc.clear();
+          for (auto&& [j, c_v] : c_row) {
+            c_row_acc[j] += c_v;
+          }
+
+          for (auto&& [j, c_v] : c_row) {
+            EXPECT_EQ_(c_row_ref[j], c_row_acc[j]);
+          }
+
+          EXPECT_EQ(c_row_ref.size(), c_row_acc.size());
+        }
+      }
+    }
+  }
+}
+
+TEST(thrust_CsrView, SpGEMMReuseAndChangePointer) {  // numeric phase receives freshly allocated copies of A, B and C on every iteration
+  for (auto&& [m, k, nnz] : util::dims) {
+    for (auto&& n : {m, k}) {
+      auto [a_values, a_rowptr, a_colind, a_shape, a_nnz] =
+          spblas::generate_csr(m, k, nnz);
+      thrust::device_vector d_a_values(a_values);
+      thrust::device_vector d_a_rowptr(a_rowptr);
+      thrust::device_vector d_a_colind(a_colind);
+      spblas::csr_view d_a(
+          d_a_values.data().get(), d_a_rowptr.data().get(),
+          d_a_colind.data().get(), a_shape, a_nnz);
+      spblas::csr_view a(a_values, a_rowptr,  // view over the generated arrays; read by the reference loop below
+                         a_colind, a_shape, a_nnz);
+
+      auto [b_values, b_rowptr, b_colind, b_shape, b_nnz] =
+          spblas::generate_csr(k, n, nnz);
+      thrust::device_vector d_b_values(b_values);
+      thrust::device_vector d_b_rowptr(b_rowptr);
+      thrust::device_vector d_b_colind(b_colind);
+      spblas::csr_view d_b(
+          d_b_values.data().get(), d_b_rowptr.data().get(),
+          d_b_colind.data().get(), b_shape, b_nnz);
+      spblas::csr_view b(b_values, b_rowptr,
+                         b_colind, b_shape, b_nnz);
+
+      thrust::device_vector d_c_rowptr(m + 1);  // only the row pointers of C are allocated up front
+
+      spblas::csr_view d_c(
+          nullptr, d_c_rowptr.data().get(), nullptr, {m, n}, 0);  // no value/column storage yet, nnz = 0
+
+      spblas::spgemm_state_t state;
+      spblas::multiply_symbolic_compute(state, d_a, d_b, d_c);
+      auto nnz = state.result_nnz();  // shadows the loop's nnz with the count reported by the state
+      thrust::device_vector d_c_values(nnz);
+      thrust::device_vector d_c_colind(nnz);
+      std::span d_c_values_span(d_c_values.data().get(), nnz);
+      std::span d_c_rowptr_span(d_c_rowptr.data().get(), m + 1);
+      std::span d_c_colind_span(d_c_colind.data().get(), nnz);
+      d_c.update(d_c_values_span, d_c_rowptr_span, d_c_colind_span, {m, n},
+                 nnz);  // rebind d_c to the newly allocated buffers
+
+      spblas::multiply_symbolic_fill(state, d_a, d_b, d_c);
+      std::mt19937 g(0);  // fixed seed
+      for (int i = 0; i < 3; i++) {
+        // Regenerate the values of a and b (on every iteration in this test).
+        std::uniform_real_distribution val_dist(0.0, 100.0);
+        for (auto& v : a_values) {
+          v = val_dist(g);
+        }
+        for (auto& v : b_values) {
+          v = val_dist(g);
+        }
+        // Use different device pointers than in the symbolic phase; the new
+        // buffers are copies, so the sparsity pattern stays the same.
+        thrust::device_vector d_a_values_new(a_values);
+        thrust::device_vector d_a_colind_new(d_a_colind);
+        thrust::device_vector d_a_rowptr_new(d_a_rowptr);
+        thrust::device_vector d_b_values_new(b_values);
+        thrust::device_vector d_b_colind_new(d_b_colind);
+        thrust::device_vector d_b_rowptr_new(d_b_rowptr);
+        thrust::device_vector d_c_values_new(d_c_values);
+        thrust::device_vector d_c_colind_new(d_c_colind);
+        thrust::device_vector d_c_rowptr_new(d_c_rowptr);
+        spblas::csr_view d_a(  // these three views shadow the outer d_a / d_b / d_c
+            d_a_values_new.data().get(), d_a_rowptr_new.data().get(),
+            d_a_colind_new.data().get(), a_shape, a_nnz);
+        spblas::csr_view d_b(
+            d_b_values_new.data().get(), d_b_rowptr_new.data().get(),
+            d_b_colind_new.data().get(), b_shape, b_nnz);
+        spblas::csr_view d_c(
+            d_c_values_new.data().get(), d_c_rowptr_new.data().get(),
+            d_c_colind_new.data().get(), {m, n}, nnz);
+        // Run the numeric phase on the new buffers.
+        spblas::multiply_numeric(state, d_a, d_b, d_c);
+        // Copy c back into host-side vectors.
+        std::vector c_values(nnz);
+        std::vector c_rowptr(m + 1);
+        std::vector c_colind(nnz);
+        thrust::copy(d_c_values_new.begin(), d_c_values_new.end(),
+                     c_values.begin());
+        thrust::copy(d_c_rowptr_new.begin(), d_c_rowptr_new.end(),
+                     c_rowptr.begin());
+        thrust::copy(d_c_colind_new.begin(), d_c_colind_new.end(),
+                     c_colind.begin());
+
+        spblas::csr_view c(c_values, c_rowptr,
+                           c_colind, {m, n}, nnz);
+
+        spblas::__backend::spa_accumulator c_row_ref(
+            spblas::__backend::shape(c)[1]);
+
+        spblas::__backend::spa_accumulator c_row_acc(
+            spblas::__backend::shape(c)[1]);
+
+        for (auto&& [i, a_row] : spblas::__backend::rows(a)) {  // reference result, row by row (this i shadows the loop counter)
+          c_row_ref.clear();
+          for (auto&& [k, a_v] : a_row) {
+            auto&& b_row = spblas::__backend::lookup_row(b, k);
+
+            for (auto&& [j, b_v] : b_row) {
+              c_row_ref[j] += a_v * b_v;
+            }
+          }
+
+          auto&& c_row = spblas::__backend::lookup_row(c, i);
+
+          // Accumulate output into `c_row_acc` so that we can allow
+          // duplicate column indices.
+          c_row_acc.clear();
+          for (auto&& [j, c_v] : c_row) {
+            c_row_acc[j] += c_v;
+          }
+
+          for (auto&& [j, c_v] : c_row) {
+            EXPECT_EQ_(c_row_ref[j], c_row_acc[j]);
+          }
+
+          EXPECT_EQ(c_row_ref.size(), c_row_acc.size());
+        }
+      }
+    }
+  }
+}
diff --git a/test/gtest/device/spgemm_test.cpp b/test/gtest/device/spgemm_test.cpp
new file mode 100644
index 0000000..98e1aed
--- /dev/null
+++ b/test/gtest/device/spgemm_test.cpp
@@ -0,0 +1,273 @@
+
+#include "../util.hpp"
+#include  // NOTE(review): the include target is missing on this and the two includes below; likely stripped in extraction -- confirm upstream
+
+#include
+#include
+
+using value_t = float;
+using index_t = spblas::index_t;
+using offset_t = spblas::offset_t;  // NOTE(review): index_t/offset_t are not referenced in the visible test bodies; presumably they sat in stripped template argument lists -- confirm
+
+TEST(thrust_CsrView, SpGEMM) {  // C = A * B via multiply_compute followed by multiply_fill
+  for (auto&& [m, k, nnz] : util::dims) {
+    for (auto&& n : {m, k}) {
+      auto [a_values, a_rowptr, a_colind, a_shape, a_nnz] =
+          spblas::generate_csr(m, k, nnz);
+      thrust::device_vector d_a_values(a_values);
+      thrust::device_vector d_a_rowptr(a_rowptr);
+      thrust::device_vector d_a_colind(a_colind);
+      spblas::csr_view d_a(
+          d_a_values.data().get(), d_a_rowptr.data().get(),
+          d_a_colind.data().get(), a_shape, a_nnz);
+      spblas::csr_view a(a_values, a_rowptr,  // view over the generated arrays; read by the reference loop below
+                         a_colind, a_shape, a_nnz);
+
+      auto [b_values, b_rowptr, b_colind, b_shape, b_nnz] =
+          spblas::generate_csr(k, n, nnz);
+      thrust::device_vector d_b_values(b_values);
+      thrust::device_vector d_b_rowptr(b_rowptr);
+      thrust::device_vector d_b_colind(b_colind);
+      spblas::csr_view d_b(
+          d_b_values.data().get(), d_b_rowptr.data().get(),
+          d_b_colind.data().get(), b_shape, b_nnz);
+      spblas::csr_view b(b_values, b_rowptr,
+                         b_colind, b_shape, b_nnz);
+
+      thrust::device_vector d_c_rowptr(m + 1);  // only the row pointers of C are allocated up front
+
+      spblas::csr_view d_c(
+          nullptr, d_c_rowptr.data().get(), nullptr, {m, n}, 0);  // no value/column storage yet, nnz = 0
+
+      spblas::spgemm_state_t state;
+      spblas::multiply_compute(state, d_a, d_b, d_c);
+      auto nnz = state.result_nnz();  // shadows the loop's nnz with the count reported by the state
+      thrust::device_vector d_c_values(nnz);
+      thrust::device_vector d_c_colind(nnz);
+      std::span d_c_values_span(d_c_values.data().get(), nnz);
+      std::span d_c_rowptr_span(d_c_rowptr.data().get(), m + 1);
+      std::span d_c_colind_span(d_c_colind.data().get(), nnz);
+      d_c.update(d_c_values_span, d_c_rowptr_span, d_c_colind_span, {m, n},
+                 nnz);  // rebind d_c to the newly allocated buffers
+
+      spblas::multiply_fill(state, d_a, d_b, d_c);
+
+      std::vector c_values(nnz);
+      std::vector c_rowptr(m + 1);
+      std::vector c_colind(nnz);
+      thrust::copy(d_c_values.begin(), d_c_values.end(), c_values.begin());
+      thrust::copy(d_c_rowptr.begin(), d_c_rowptr.end(), c_rowptr.begin());
+      thrust::copy(d_c_colind.begin(), d_c_colind.end(), c_colind.begin());
+      spblas::csr_view c(c_values, c_rowptr,
+                         c_colind, {m, n}, nnz);
+
+      spblas::__backend::spa_accumulator c_row_ref(
+          spblas::__backend::shape(c)[1]);
+
+      spblas::__backend::spa_accumulator c_row_acc(
+          spblas::__backend::shape(c)[1]);
+
+      for (auto&& [i, a_row] : spblas::__backend::rows(a)) {  // reference result, row by row
+        c_row_ref.clear();
+        for (auto&& [k, a_v] : a_row) {
+          auto&& b_row = spblas::__backend::lookup_row(b, k);
+
+          for (auto&& [j, b_v] : b_row) {
+            c_row_ref[j] += a_v * b_v;
+          }
+        }
+
+        auto&& c_row = spblas::__backend::lookup_row(c, i);
+
+        // Accumulate output into `c_row_acc` so that we can allow
+        // duplicate column indices.
+        c_row_acc.clear();
+        for (auto&& [j, c_v] : c_row) {
+          c_row_acc[j] += c_v;
+        }
+
+        for (auto&& [j, c_v] : c_row) {
+          EXPECT_EQ_(c_row_ref[j], c_row_acc[j]);  // NOTE(review): EXPECT_EQ_ is not defined in this chunk; presumably a project macro -- confirm
+        }
+
+        EXPECT_EQ(c_row_ref.size(), c_row_acc.size());
+      }
+    }
+  }
+}
+
+TEST(thrust_CsrView, SpGEMM_AScaled) {  // C = (alpha * A) * B via multiply_compute followed by multiply_fill
+  value_t alpha = 2.0f;
+  for (auto&& [m, k, nnz] : util::dims) {
+    for (auto&& n : {m, k}) {
+      auto [a_values, a_rowptr, a_colind, a_shape, a_nnz] =
+          spblas::generate_csr(m, k, nnz);
+      thrust::device_vector d_a_values(a_values);
+      thrust::device_vector d_a_rowptr(a_rowptr);
+      thrust::device_vector d_a_colind(a_colind);
+      spblas::csr_view d_a(
+          d_a_values.data().get(), d_a_rowptr.data().get(),
+          d_a_colind.data().get(), a_shape, a_nnz);
+      spblas::csr_view a(a_values, a_rowptr,  // view over the generated arrays; read by the reference loop below
+                         a_colind, a_shape, a_nnz);
+
+      auto [b_values, b_rowptr, b_colind, b_shape, b_nnz] =
+          spblas::generate_csr(k, n, nnz);
+      thrust::device_vector d_b_values(b_values);
+      thrust::device_vector d_b_rowptr(b_rowptr);
+      thrust::device_vector d_b_colind(b_colind);
+      spblas::csr_view d_b(
+          d_b_values.data().get(), d_b_rowptr.data().get(),
+          d_b_colind.data().get(), b_shape, b_nnz);
+      spblas::csr_view b(b_values, b_rowptr,
+                         b_colind, b_shape, b_nnz);
+
+      thrust::device_vector d_c_rowptr(m + 1);  // only the row pointers of C are allocated up front
+
+      spblas::csr_view d_c(
+          nullptr, d_c_rowptr.data().get(), nullptr, {m, n}, 0);  // no value/column storage yet, nnz = 0
+
+      spblas::spgemm_state_t state;
+      spblas::multiply_compute(state, spblas::scaled(alpha, d_a), d_b, d_c);
+      auto nnz = state.result_nnz();  // shadows the loop's nnz with the count reported by the state
+      thrust::device_vector d_c_values(nnz);
+      thrust::device_vector d_c_colind(nnz);
+      std::span d_c_values_span(d_c_values.data().get(), nnz);
+      std::span d_c_rowptr_span(d_c_rowptr.data().get(), m + 1);
+      std::span d_c_colind_span(d_c_colind.data().get(), nnz);
+      d_c.update(d_c_values_span, d_c_rowptr_span, d_c_colind_span, {m, n},
+                 nnz);  // rebind d_c to the newly allocated buffers
+
+      spblas::multiply_fill(state, spblas::scaled(alpha, d_a), d_b, d_c);
+
+      std::vector c_values(nnz);
+      std::vector c_rowptr(m + 1);
+      std::vector c_colind(nnz);
+      thrust::copy(d_c_values.begin(), d_c_values.end(), c_values.begin());
+      thrust::copy(d_c_rowptr.begin(), d_c_rowptr.end(), c_rowptr.begin());
+      thrust::copy(d_c_colind.begin(), d_c_colind.end(), c_colind.begin());
+      spblas::csr_view c(c_values, c_rowptr,
+                         c_colind, {m, n}, nnz);
+
+      spblas::__backend::spa_accumulator c_row_ref(
+          spblas::__backend::shape(c)[1]);
+
+      spblas::__backend::spa_accumulator c_row_acc(
+          spblas::__backend::shape(c)[1]);
+
+      for (auto&& [i, a_row] : spblas::__backend::rows(a)) {  // reference result, row by row
+        c_row_ref.clear();
+        for (auto&& [k, a_v] : a_row) {
+          auto&& b_row = spblas::__backend::lookup_row(b, k);
+
+          for (auto&& [j, b_v] : b_row) {
+            c_row_ref[j] += alpha * a_v * b_v;
+          }
+        }
+
+        auto&& c_row = spblas::__backend::lookup_row(c, i);
+
+        // Accumulate output into `c_row_acc` so that we can allow
+        // duplicate column indices.
+        c_row_acc.clear();
+        for (auto&& [j, c_v] : c_row) {
+          c_row_acc[j] += c_v;
+        }
+
+        for (auto&& [j, c_v] : c_row) {
+          EXPECT_EQ_(c_row_ref[j], c_row_acc[j]);
+        }
+
+        EXPECT_EQ(c_row_ref.size(), c_row_acc.size());
+      }
+    }
+  }
+}
+
+TEST(thrust_CsrView, SpGEMM_BScaled) {  // C = A * (alpha * B) via multiply_compute followed by multiply_fill
+  value_t alpha = 2.0f;
+  for (auto&& [m, k, nnz] : util::dims) {
+    for (auto&& n : {m, k}) {
+      auto [a_values, a_rowptr, a_colind, a_shape, a_nnz] =
+          spblas::generate_csr(m, k, nnz);
+      thrust::device_vector d_a_values(a_values);
+      thrust::device_vector d_a_rowptr(a_rowptr);
+      thrust::device_vector d_a_colind(a_colind);
+      spblas::csr_view d_a(
+          d_a_values.data().get(), d_a_rowptr.data().get(),
+          d_a_colind.data().get(), a_shape, a_nnz);
+      spblas::csr_view a(a_values, a_rowptr,  // view over the generated arrays; read by the reference loop below
+                         a_colind, a_shape, a_nnz);
+
+      auto [b_values, b_rowptr, b_colind, b_shape, b_nnz] =
+          spblas::generate_csr(k, n, nnz);
+      thrust::device_vector d_b_values(b_values);
+      thrust::device_vector d_b_rowptr(b_rowptr);
+      thrust::device_vector d_b_colind(b_colind);
+      spblas::csr_view d_b(
+          d_b_values.data().get(), d_b_rowptr.data().get(),
+          d_b_colind.data().get(), b_shape, b_nnz);
+      spblas::csr_view b(b_values, b_rowptr,
+                         b_colind, b_shape, b_nnz);
+
+      thrust::device_vector d_c_rowptr(m + 1);  // only the row pointers of C are allocated up front
+
+      spblas::csr_view d_c(
+          nullptr, d_c_rowptr.data().get(), nullptr, {m, n}, 0);  // no value/column storage yet, nnz = 0
+
+      spblas::spgemm_state_t state;
+      spblas::multiply_compute(state, d_a, spblas::scaled(alpha, d_b), d_c);
+      auto nnz = state.result_nnz();  // shadows the loop's nnz with the count reported by the state
+      thrust::device_vector d_c_values(nnz);
+      thrust::device_vector d_c_colind(nnz);
+      std::span d_c_values_span(d_c_values.data().get(), nnz);
+      std::span d_c_rowptr_span(d_c_rowptr.data().get(), m + 1);
+      std::span d_c_colind_span(d_c_colind.data().get(), nnz);
+      d_c.update(d_c_values_span, d_c_rowptr_span, d_c_colind_span, {m, n},
+                 nnz);  // rebind d_c to the newly allocated buffers
+
+      spblas::multiply_fill(state, d_a, spblas::scaled(alpha, d_b), d_c);
+
+      std::vector c_values(nnz);
+      std::vector c_rowptr(m + 1);
+      std::vector c_colind(nnz);
+      thrust::copy(d_c_values.begin(), d_c_values.end(), c_values.begin());
+      thrust::copy(d_c_rowptr.begin(), d_c_rowptr.end(), c_rowptr.begin());
+      thrust::copy(d_c_colind.begin(), d_c_colind.end(), c_colind.begin());
+      spblas::csr_view c(c_values, c_rowptr,
+                         c_colind, {m, n}, nnz);
+
+      spblas::__backend::spa_accumulator c_row_ref(
+          spblas::__backend::shape(c)[1]);
+
+      spblas::__backend::spa_accumulator c_row_acc(
+          spblas::__backend::shape(c)[1]);
+
+      for (auto&& [i, a_row] : spblas::__backend::rows(a)) {  // reference result, row by row
+        c_row_ref.clear();
+        for (auto&& [k, a_v] : a_row) {
+          auto&& b_row = spblas::__backend::lookup_row(b, k);
+
+          for (auto&& [j, b_v] : b_row) {
+            c_row_ref[j] += alpha * a_v * b_v;
+          }
+        }
+
+        auto&& c_row = spblas::__backend::lookup_row(c, i);
+
+        // Accumulate output into `c_row_acc` so that we can allow
+        // duplicate column indices.
+        c_row_acc.clear();
+        for (auto&& [j, c_v] : c_row) {
+          c_row_acc[j] += c_v;
+        }
+
+        for (auto&& [j, c_v] : c_row) {
+          EXPECT_EQ_(c_row_ref[j], c_row_acc[j]);
+        }
+
+        EXPECT_EQ(c_row_ref.size(), c_row_acc.size());
+      }
+    }
+  }
+}