Skip to content

Commit 543ed24

Browse files
committed
Support mixed CSR/CSC in MKL.
1 parent 53a59c0 commit 543ed24

File tree

7 files changed

+151
-36
lines changed

7 files changed

+151
-36
lines changed

CMakeLists.txt

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,8 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON)
66

77
set(CMAKE_CXX_FLAGS "-O3 -march=native")
88

9+
option(ENABLE_SANITIZERS "Enable Clang sanitizers" OFF)
10+
911
# Get includes, which declares the `spblas` library
1012
add_subdirectory(include)
1113

@@ -36,6 +38,13 @@ if (LOG_LEVEL)
3638
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DLOG_LEVEL=${LOG_LEVEL}") # SPBLAS_DEBUG | SPBLAS_WARNING | SPBLAS_TRACE | SPBLAS_INFO
3739
endif()
3840

41+
# Enable sanitizers
# When ENABLE_SANITIZERS is ON, -fsanitize=address,undefined is attached to
# the `spblas` target as INTERFACE compile and link options.
42+
if (ENABLE_SANITIZERS)
43+
set(SANITIZER_FLAGS "-fsanitize=address,undefined")
44+
# -g, -O1 and -fno-omit-frame-pointer are added to the compile options only.
# NOTE(review): CMAKE_CXX_FLAGS is set to "-O3 -march=native" earlier in this
# file; which -O level takes effect depends on final flag order — confirm.
target_compile_options(spblas INTERFACE ${SANITIZER_FLAGS} -g -O1 -fno-omit-frame-pointer)
45+
# The same sanitizer flags are also passed at link time.
target_link_options(spblas INTERFACE ${SANITIZER_FLAGS})
46+
endif()
47+
3948
# mdspan
4049
FetchContent_Declare(
4150
mdspan

include/spblas/algorithms/multiply_impl.hpp

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -212,6 +212,8 @@ std::optional<T> sparse_dot_product(A&& a, B&& b) {
212212
if (a_i == b_i) {
213213
sum += a_v * b_v;
214214
implicit_zero = false;
215+
++a_iter;
216+
++b_iter;
215217
} else if (a_i < b_i) {
216218
++a_iter;
217219
} else {
@@ -268,6 +270,7 @@ void multiply(A&& a, B&& b, C&& c) {
268270
}
269271
}
270272
}
273+
c_builder.finish();
271274
c.update(c.values(), c.rowptr(), c.colind(), c.shape(),
272275
c.rowptr()[c.shape()[0]]);
273276
}

include/spblas/algorithms/transpose_impl.hpp

Lines changed: 12 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -13,14 +13,15 @@ operation_info_t transpose_inspect(A&& a, B&& b) {
1313

1414
template <matrix A, matrix B>
1515
requires(__detail::is_csr_view_v<A> && __detail::is_csr_view_v<B>)
16-
void transpose(operation_info_t& info, A&& a, B&& b) {
16+
void transpose(A&& a, B&& b) {
1717
if (__backend::shape(a)[0] != __backend::shape(b)[1] ||
1818
__backend::shape(a)[1] != __backend::shape(b)[0]) {
1919
throw std::invalid_argument(
2020
"transpose: matrix dimensions are incompatible.");
2121
}
22-
if (__backend::size(a) != __backend::size(b)) {
23-
throw std::invalid_argument("transpose: matrix nnz are incompatible.");
22+
if (b.values().size() < __backend::size(a) ||
23+
b.colind().size() < __backend::size(a)) {
24+
throw std::runtime_error("transpose: Transpose ran out of memory.");
2425
}
2526
using O = tensor_offset_t<B>;
2627

@@ -47,6 +48,14 @@ void transpose(operation_info_t& info, A&& a, B&& b) {
4748
b_rowptr[j + 1]++;
4849
}
4950
}
51+
52+
b.update(b.values(), b.rowptr(), b.colind(), b.shape(), a.size());
53+
}
54+
55+
// Overload of transpose that additionally accepts an operation_info_t.
// It forwards `a` and `b` unchanged to the two-argument transpose; `info`
// is neither read nor written in this body.
template <matrix A, matrix B>
56+
requires(__detail::is_csr_view_v<A> && __detail::is_csr_view_v<B>)
57+
void transpose(operation_info_t& info, A&& a, B&& b) {
58+
transpose(std::forward<A>(a), std::forward<B>(b));
5059
}
5160

5261
} // namespace spblas

include/spblas/backend/csr_builder.hpp

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -26,17 +26,25 @@ class csr_builder {
2626

2727
while (i_ < row_index) {
2828
view_.rowptr()[i_ + 1] = j_ptr_;
29+
i_++;
2930
}
3031

3132
for (auto&& [j, v] : row) {
3233
view_.values()[j_ptr_] = v;
3334
view_.colind()[j_ptr_] = j;
34-
++j_ptr_;
35+
j_ptr_++;
3536
}
3637
view_.rowptr()[i_ + 1] = j_ptr_;
3738
i_++;
3839
}
3940

41+
// Completes the row-pointer array after the last inserted row: for every
// row index from i_ up to view_.shape()[0], the end offset rowptr()[i_ + 1]
// is set to the running j_ptr_, and i_ is advanced.
void finish() {
42+
while (i_ < view_.shape()[0]) {
43+
// j_ptr_ is not advanced in this loop, so each of these rows receives the
// same end offset.
view_.rowptr()[i_ + 1] = j_ptr_;
44+
i_++;
45+
}
46+
}
47+
4048
private:
4149
csr_view<T, I, O> view_;
4250
O j_ptr_ = 0;

include/spblas/vendor/onemkl_sycl/spgemm_impl.hpp

Lines changed: 24 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
#include <spblas/detail/operation_info_t.hpp>
99
#include <spblas/detail/ranges.hpp>
1010
#include <spblas/detail/view_inspectors.hpp>
11+
#include <spblas/vendor/onemkl_sycl/detail/create_matrix_handle.hpp>
1112

1213
//
1314
// Defines the following APIs for SpGEMM:
@@ -23,8 +24,9 @@
2324
namespace spblas {
2425

2526
template <matrix A, matrix B, matrix C>
26-
requires __detail::has_csr_base<A> && __detail::has_csr_base<B> &&
27-
__detail::is_csr_view_v<C>
27+
requires(__detail::has_csr_base<A> || __detail::has_csc_base<A>) &&
28+
(__detail::has_csr_base<B> || __detail::has_csc_base<B>) &&
29+
__detail::is_csr_view_v<C>
2830
operation_info_t multiply_compute(A&& a, B&& b, C&& c) {
2931
log_trace("");
3032
auto a_base = __detail::get_ultimate_base(a);
@@ -34,51 +36,40 @@ operation_info_t multiply_compute(A&& a, B&& b, C&& c) {
3436
using oneapi::mkl::sparse::matmat_request;
3537
using oneapi::mkl::sparse::matrix_view_descr;
3638

37-
oneapi::mkl::sparse::matmat_descr_t descr = nullptr;
38-
3939
sycl::queue q(sycl::cpu_selector_v);
4040

41-
oneapi::mkl::sparse::init_matmat_descr(&descr);
42-
43-
oneapi::mkl::sparse::set_matmat_data(
44-
descr, matrix_view_descr::general, transpose::nontrans, // view/op for A
45-
matrix_view_descr::general, transpose::nontrans, // view/op for B
46-
matrix_view_descr::general); // view for C
47-
48-
oneapi::mkl::sparse::matrix_handle_t a_handle, b_handle, c_handle;
49-
a_handle = b_handle = c_handle = nullptr;
50-
51-
oneapi::mkl::sparse::init_matrix_handle(&a_handle);
52-
oneapi::mkl::sparse::init_matrix_handle(&b_handle);
53-
oneapi::mkl::sparse::init_matrix_handle(&c_handle);
54-
55-
oneapi::mkl::sparse::set_csr_data(
56-
q, a_handle, __backend::shape(a_base)[0], __backend::shape(a_base)[1],
57-
oneapi::mkl::index_base::zero, a_base.rowptr().data(),
58-
a_base.colind().data(), a_base.values().data())
59-
.wait();
60-
61-
oneapi::mkl::sparse::set_csr_data(
62-
q, b_handle, __backend::shape(b_base)[0], __backend::shape(b_base)[1],
63-
oneapi::mkl::index_base::zero, b_base.rowptr().data(),
64-
b_base.colind().data(), b_base.values().data())
65-
.wait();
66-
6741
using T = tensor_scalar_t<C>;
6842
using I = tensor_index_t<C>;
6943

44+
oneapi::mkl::sparse::matrix_handle_t a_handle =
45+
__mkl::create_matrix_handle(q, a_base);
46+
oneapi::mkl::sparse::matrix_handle_t b_handle =
47+
__mkl::create_matrix_handle(q, b_base);
48+
7049
I* c_rowptr;
7150
if (c.rowptr().size() >= __backend::shape(c)[0] + 1) {
7251
c_rowptr = c.rowptr().data();
7352
} else {
7453
c_rowptr = sycl::malloc_device<I>(__backend::shape(c)[0] + 1, q);
7554
}
7655

56+
oneapi::mkl::sparse::matrix_handle_t c_handle = nullptr;
57+
oneapi::mkl::sparse::init_matrix_handle(&c_handle);
58+
7759
oneapi::mkl::sparse::set_csr_data(
7860
q, c_handle, __backend::shape(c)[0], __backend::shape(c)[1],
7961
oneapi::mkl::index_base::zero, c_rowptr, (I*) nullptr, (T*) nullptr)
8062
.wait();
8163

64+
oneapi::mkl::sparse::matmat_descr_t descr = nullptr;
65+
oneapi::mkl::sparse::init_matmat_descr(&descr);
66+
67+
oneapi::mkl::sparse::set_matmat_data(
68+
descr, matrix_view_descr::general,
69+
__mkl::get_transpose(a), // view/op for A
70+
matrix_view_descr::general, __mkl::get_transpose(b), // view/op for B
71+
matrix_view_descr::general); // view for C
72+
8273
auto ev1 = oneapi::mkl::sparse::matmat(q, a_handle, b_handle, c_handle,
8374
matmat_request::work_estimation, descr,
8475
nullptr, nullptr, {});
@@ -113,8 +104,9 @@ operation_info_t multiply_compute(A&& a, B&& b, C&& c) {
113104
}
114105

115106
template <matrix A, matrix B, matrix C>
116-
requires __detail::has_csr_base<A> && __detail::has_csr_base<B> &&
117-
__detail::is_csr_view_v<C>
107+
requires(__detail::has_csr_base<A> || __detail::has_csc_base<A>) &&
108+
(__detail::has_csr_base<B> || __detail::has_csc_base<B>) &&
109+
__detail::is_csr_view_v<C>
118110
void multiply_fill(operation_info_t& info, A&& a, B&& b, C&& c) {
119111

120112
log_trace("");

test/gtest/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@ add_executable(
55
spmv_test.cpp
66
spmm_test.cpp
77
spgemm_test.cpp
8+
spgemm_csr_csc.cpp
89
add_test.cpp
910
transpose_test.cpp
1011
triangular_solve_test.cpp

test/gtest/spgemm_csr_csc.cpp

Lines changed: 93 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,93 @@
1+
#include <gtest/gtest.h>
2+
3+
#include "util.hpp"
4+
#include <spblas/backend/spa_accumulator.hpp>
5+
#include <spblas/spblas.hpp>
6+
7+
#include <fmt/core.h>
8+
#include <fmt/ranges.h>
9+
10+
// Checks SpGEMM with mixed operand formats: a CSR matrix `a` times a CSC
// view of `b` must give the same result, row by row, as `a` times the CSR
// view of `b`. Runs for every (m, k, nnz) in util::dims and for n in {m, k}.
TEST(MixedViews, SpGEMM_CsrCsc) {
11+
using T = float;
12+
using I = spblas::index_t;
13+
14+
for (auto&& [m, k, nnz] : util::dims) {
15+
for (auto&& n : {m, k}) {
16+
auto [a_values, a_rowptr, a_colind, a_shape, a_nnz] =
17+
spblas::generate_csr<T, I>(m, k, nnz);
18+
19+
auto [b_values, b_rowptr, b_colind, b_shape, b_nnz] =
20+
spblas::generate_csr<T, I>(k, n, nnz);
21+
22+
// We will be multiplying a times b.
23+
spblas::csr_view<T, I> a(a_values, a_rowptr, a_colind, a_shape, a_nnz);
24+
spblas::csr_view<T, I> b(b_values, b_rowptr, b_colind, b_shape, b_nnz);
25+
26+
// But we'd like the second operand to be a CSC matrix.
27+
// We first transpose b.
28+
29+
// Output storage for the transpose: b.size() values/column indices and
// one row pointer per column of b, plus one.
std::vector<T> b_t_values(b.size());
30+
std::vector<I> b_t_rowptr(b.shape()[1] + 1);
31+
std::vector<I> b_t_colind(b.size());
32+
33+
spblas::csr_view<T, I> b_t(b_t_values, b_t_rowptr, b_t_colind,
34+
{b.shape()[1], b.shape()[0]}, 0);
35+
36+
spblas::transpose(b, b_t);
37+
38+
// We then build a CSC representation of b from b_t.
// The CSR arrays of b_t are reused as-is; only the shape is swapped back.
39+
spblas::csc_view<T, I> b_csc(b_t.values(), b_t.rowptr(), b_t.colind(),
40+
{b_t.shape()[1], b_t.shape()[0]},
41+
b_t.size());
42+
43+
// Now let's multiply a * b_csc -> c.
44+
45+
// c starts with only a row-pointer array; values and column indices are
// allocated after multiply_compute reports info.result_nnz().
std::vector<I> c_rowptr(m + 1);
46+
spblas::csr_view<T, I> c(nullptr, c_rowptr.data(), nullptr, {m, n}, 0);
47+
48+
auto info = spblas::multiply_compute(a, b_csc, c);
49+
50+
std::vector<T> c_values(info.result_nnz());
51+
std::vector<I> c_colind(info.result_nnz());
52+
53+
c.update(c_values, c_rowptr, c_colind);
54+
55+
spblas::multiply_fill(info, a, b_csc, c);
56+
57+
// Now that we have c, let's compute a reference c_ref.
58+
// We perform a * b -> c_ref
59+
60+
std::vector<I> c_ref_rowptr(m + 1);
61+
62+
spblas::csr_view<T, I> c_ref(nullptr, c_ref_rowptr.data(), nullptr,
63+
{m, n}, 0);
64+
65+
info = spblas::multiply_compute(a, b, c_ref);
66+
67+
std::vector<T> c_ref_values(info.result_nnz());
68+
std::vector<I> c_ref_colind(info.result_nnz());
69+
70+
c_ref.update(c_ref_values, c_ref_rowptr, c_ref_colind);
71+
72+
spblas::multiply_fill(info, a, b, c_ref);
73+
74+
// Accumulator sized to the number of columns of c, reused for every row.
spblas::__backend::spa_accumulator<T, I> c_row_acc(c.shape()[1]);
75+
76+
// Compare c against c_ref one row at a time.
for (auto&& [i, c_row] : spblas::__backend::rows(c)) {
77+
c_row_acc.clear();
78+
79+
auto&& c_ref_row = spblas::__backend::lookup_row(c_ref, i);
80+
81+
// Both rows must hold the same number of stored entries.
EXPECT_EQ(c_row.size(), c_ref_row.size());
82+
83+
// Scatter the row of c into the accumulator, keyed by column index.
for (auto&& [j, v] : c_row) {
84+
c_row_acc[j] += v;
85+
}
86+
87+
// Each reference entry is compared with the accumulated value at the
// same column.
// NOTE(review): EXPECT_EQ_ is not defined in this file; presumably a
// comparison helper from util.hpp — confirm.
for (auto&& [j, v] : c_ref_row) {
88+
EXPECT_EQ_(v, c_row_acc[j]);
89+
}
90+
}
91+
}
92+
}
93+
}

0 commit comments

Comments
 (0)