Skip to content

Commit 64d662f

Browse files
committed
add the test when changing the pointer
1 parent 083c367 commit 64d662f

File tree

2 files changed

+141
-0
lines changed

2 files changed

+141
-0
lines changed

include/spblas/vendor/rocsparse/multiply_spgemm.hpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,7 @@ class spgemm_state_t {
202202
void multiply_numeric(A&& a, B&& b, C&& c, D&& d) {
203203
auto a_base = __detail::get_ultimate_base(a);
204204
auto b_base = __detail::get_ultimate_base(b);
205+
auto d_base = __detail::get_ultimate_base(d);
205206
using matrix_type = decltype(a_base);
206207
using input_type = decltype(b_base);
207208
using output_type = std::remove_reference_t<decltype(c)>;
@@ -213,6 +214,22 @@ class spgemm_state_t {
213214
auto beta_optional = __detail::get_scaling_factor(d);
214215
value_type beta = beta_optional.value_or(1);
215216

217+
// Update the pointer from the matrix but they must contains the same
218+
// sparsity as the previous call.
219+
__rocsparse::throw_if_error(rocsparse_csr_set_pointers(
220+
this->mat_a_, a_base.rowptr().data(), a_base.colind().data(),
221+
a_base.values().data()));
222+
__rocsparse::throw_if_error(rocsparse_csr_set_pointers(
223+
this->mat_b_, b_base.rowptr().data(), b_base.colind().data(),
224+
b_base.values().data()));
225+
__rocsparse::throw_if_error(rocsparse_csr_set_pointers(
226+
this->mat_c_, c.rowptr().data(), c.colind().data(), c.values().data()));
227+
if (d_base.values().data()) {
228+
// when it is still a null matrix, we can not use set pointer function
229+
__rocsparse::throw_if_error(rocsparse_csr_set_pointers(
230+
this->mat_d_, d_base.rowptr().data(), d_base.colind().data(),
231+
d_base.values().data()));
232+
}
216233
__rocsparse::throw_if_error(rocsparse_spgemm(
217234
this->handle_.get(), rocsparse_operation_none, rocsparse_operation_none,
218235
&alpha, this->mat_a_, this->mat_b_, &beta, this->mat_d_, this->mat_c_,

test/gtest/rocsparse/spgemm_reuse_test.cpp

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,3 +322,127 @@ TEST(CsrView, SpGEMMReuse_BScaled) {
322322
}
323323
}
324324
}
325+
326+
TEST(CsrView, SpGEMMReuseAndChangePointer) {
327+
for (auto&& [m, k, nnz] : util::dims) {
328+
for (auto&& n : {m, k}) {
329+
auto [a_values, a_rowptr, a_colind, a_shape, a_nnz] =
330+
spblas::generate_csr<value_t, index_t, offset_t>(m, k, nnz);
331+
thrust::device_vector<value_t> d_a_values(a_values);
332+
thrust::device_vector<offset_t> d_a_rowptr(a_rowptr);
333+
thrust::device_vector<index_t> d_a_colind(a_colind);
334+
spblas::csr_view<value_t, index_t, offset_t> d_a(
335+
d_a_values.data().get(), d_a_rowptr.data().get(),
336+
d_a_colind.data().get(), a_shape, a_nnz);
337+
spblas::csr_view<value_t, index_t, offset_t> a(a_values, a_rowptr,
338+
a_colind, a_shape, a_nnz);
339+
340+
auto [b_values, b_rowptr, b_colind, b_shape, b_nnz] =
341+
spblas::generate_csr<value_t, index_t, offset_t>(k, n, nnz);
342+
thrust::device_vector<value_t> d_b_values(b_values);
343+
thrust::device_vector<offset_t> d_b_rowptr(b_rowptr);
344+
thrust::device_vector<index_t> d_b_colind(b_colind);
345+
spblas::csr_view<value_t, index_t, offset_t> d_b(
346+
d_b_values.data().get(), d_b_rowptr.data().get(),
347+
d_b_colind.data().get(), b_shape, b_nnz);
348+
spblas::csr_view<value_t, index_t, offset_t> b(b_values, b_rowptr,
349+
b_colind, b_shape, b_nnz);
350+
351+
thrust::device_vector<offset_t> d_c_rowptr(m + 1);
352+
353+
spblas::csr_view<value_t, index_t, offset_t> d_c(
354+
nullptr, d_c_rowptr.data().get(), nullptr, {m, n}, 0);
355+
356+
spblas::spgemm_state_t state;
357+
spblas::multiply_symbolic_compute(state, d_a, d_b, d_c);
358+
auto nnz = state.result_nnz();
359+
thrust::device_vector<value_t> d_c_values(nnz);
360+
thrust::device_vector<index_t> d_c_colind(nnz);
361+
std::span<value_t> d_c_values_span(d_c_values.data().get(), nnz);
362+
std::span<offset_t> d_c_rowptr_span(d_c_rowptr.data().get(), m + 1);
363+
std::span<index_t> d_c_colind_span(d_c_colind.data().get(), nnz);
364+
d_c.update(d_c_values_span, d_c_rowptr_span, d_c_colind_span, {m, n},
365+
nnz);
366+
367+
spblas::multiply_symbolic_fill(state, d_a, d_b, d_c);
368+
std::mt19937 g(0);
369+
for (int i = 0; i < 3; i++) {
370+
// regenerate value of a and b;
371+
std::uniform_real_distribution val_dist(0.0, 100.0);
372+
for (auto& v : a_values) {
373+
v = val_dist(g);
374+
}
375+
for (auto& v : b_values) {
376+
v = val_dist(g);
377+
}
378+
// create different pointers than the symbolic phase, but they still
379+
// hold the same sparsity
380+
thrust::device_vector<value_t> d_a_values_new(a_values);
381+
thrust::device_vector<index_t> d_a_colind_new(d_a_colind);
382+
thrust::device_vector<index_t> d_a_rowptr_new(d_a_rowptr);
383+
thrust::device_vector<value_t> d_b_values_new(b_values);
384+
thrust::device_vector<index_t> d_b_colind_new(d_b_colind);
385+
thrust::device_vector<index_t> d_b_rowptr_new(d_b_rowptr);
386+
thrust::device_vector<value_t> d_c_values_new(d_c_values);
387+
thrust::device_vector<index_t> d_c_colind_new(d_c_colind);
388+
thrust::device_vector<index_t> d_c_rowptr_new(d_c_rowptr);
389+
spblas::csr_view<value_t, index_t, offset_t> d_a(
390+
d_a_values_new.data().get(), d_a_rowptr_new.data().get(),
391+
d_a_colind_new.data().get(), a_shape, a_nnz);
392+
spblas::csr_view<value_t, index_t, offset_t> d_b(
393+
d_b_values_new.data().get(), d_b_rowptr_new.data().get(),
394+
d_b_colind_new.data().get(), b_shape, b_nnz);
395+
spblas::csr_view<value_t, index_t, offset_t> d_c(
396+
d_c_values_new.data().get(), d_c_rowptr_new.data().get(),
397+
d_c_colind_new.data().get(), {m, n}, nnz);
398+
// call numeric on new data
399+
spblas::multiply_numeric(state, d_a, d_b, d_c);
400+
// move c back to host memory
401+
std::vector<value_t> c_values(nnz);
402+
std::vector<offset_t> c_rowptr(m + 1);
403+
std::vector<index_t> c_colind(nnz);
404+
thrust::copy(d_c_values_new.begin(), d_c_values_new.end(),
405+
c_values.begin());
406+
thrust::copy(d_c_rowptr_new.begin(), d_c_rowptr_new.end(),
407+
c_rowptr.begin());
408+
thrust::copy(d_c_colind_new.begin(), d_c_colind_new.end(),
409+
c_colind.begin());
410+
411+
spblas::csr_view<value_t, index_t, offset_t> c(c_values, c_rowptr,
412+
c_colind, {m, n}, nnz);
413+
414+
spblas::__backend::spa_accumulator<value_t, index_t> c_row_ref(
415+
spblas::__backend::shape(c)[1]);
416+
417+
spblas::__backend::spa_accumulator<value_t, index_t> c_row_acc(
418+
spblas::__backend::shape(c)[1]);
419+
420+
for (auto&& [i, a_row] : spblas::__backend::rows(a)) {
421+
c_row_ref.clear();
422+
for (auto&& [k, a_v] : a_row) {
423+
auto&& b_row = spblas::__backend::lookup_row(b, k);
424+
425+
for (auto&& [j, b_v] : b_row) {
426+
c_row_ref[j] += a_v * b_v;
427+
}
428+
}
429+
430+
auto&& c_row = spblas::__backend::lookup_row(c, i);
431+
432+
// Accumulate output into `c_row_acc` so that we can allow
433+
// duplicate column indices.
434+
c_row_acc.clear();
435+
for (auto&& [j, c_v] : c_row) {
436+
c_row_acc[j] += c_v;
437+
}
438+
439+
for (auto&& [j, c_v] : c_row) {
440+
EXPECT_EQ_(c_row_ref[j], c_row_acc[j]);
441+
}
442+
443+
EXPECT_EQ(c_row_ref.size(), c_row_acc.size());
444+
}
445+
}
446+
}
447+
}
448+
}

0 commit comments

Comments
 (0)