@@ -322,3 +322,127 @@ TEST(CsrView, SpGEMMReuse_BScaled) {
322322 }
323323 }
324324}
325+
326+ TEST (CsrView, SpGEMMReuseAndChangePointer) {
327+ for (auto && [m, k, nnz] : util::dims) {
328+ for (auto && n : {m, k}) {
329+ auto [a_values, a_rowptr, a_colind, a_shape, a_nnz] =
330+ spblas::generate_csr<value_t , index_t , offset_t >(m, k, nnz);
331+ thrust::device_vector<value_t > d_a_values (a_values);
332+ thrust::device_vector<offset_t > d_a_rowptr (a_rowptr);
333+ thrust::device_vector<index_t > d_a_colind (a_colind);
334+ spblas::csr_view<value_t , index_t , offset_t > d_a (
335+ d_a_values.data ().get (), d_a_rowptr.data ().get (),
336+ d_a_colind.data ().get (), a_shape, a_nnz);
337+ spblas::csr_view<value_t , index_t , offset_t > a (a_values, a_rowptr,
338+ a_colind, a_shape, a_nnz);
339+
340+ auto [b_values, b_rowptr, b_colind, b_shape, b_nnz] =
341+ spblas::generate_csr<value_t , index_t , offset_t >(k, n, nnz);
342+ thrust::device_vector<value_t > d_b_values (b_values);
343+ thrust::device_vector<offset_t > d_b_rowptr (b_rowptr);
344+ thrust::device_vector<index_t > d_b_colind (b_colind);
345+ spblas::csr_view<value_t , index_t , offset_t > d_b (
346+ d_b_values.data ().get (), d_b_rowptr.data ().get (),
347+ d_b_colind.data ().get (), b_shape, b_nnz);
348+ spblas::csr_view<value_t , index_t , offset_t > b (b_values, b_rowptr,
349+ b_colind, b_shape, b_nnz);
350+
351+ thrust::device_vector<offset_t > d_c_rowptr (m + 1 );
352+
353+ spblas::csr_view<value_t , index_t , offset_t > d_c (
354+ nullptr , d_c_rowptr.data ().get (), nullptr , {m, n}, 0 );
355+
356+ spblas::spgemm_state_t state;
357+ spblas::multiply_symbolic_compute (state, d_a, d_b, d_c);
358+ auto nnz = state.result_nnz ();
359+ thrust::device_vector<value_t > d_c_values (nnz);
360+ thrust::device_vector<index_t > d_c_colind (nnz);
361+ std::span<value_t > d_c_values_span (d_c_values.data ().get (), nnz);
362+ std::span<offset_t > d_c_rowptr_span (d_c_rowptr.data ().get (), m + 1 );
363+ std::span<index_t > d_c_colind_span (d_c_colind.data ().get (), nnz);
364+ d_c.update (d_c_values_span, d_c_rowptr_span, d_c_colind_span, {m, n},
365+ nnz);
366+
367+ spblas::multiply_symbolic_fill (state, d_a, d_b, d_c);
368+ std::mt19937 g (0 );
369+ for (int i = 0 ; i < 3 ; i++) {
370+ // regenerate value of a and b;
371+ std::uniform_real_distribution val_dist (0.0 , 100.0 );
372+ for (auto & v : a_values) {
373+ v = val_dist (g);
374+ }
375+ for (auto & v : b_values) {
376+ v = val_dist (g);
377+ }
378+ // create different pointers than the symbolic phase, but they still
379+ // hold the same sparsity
380+ thrust::device_vector<value_t > d_a_values_new (a_values);
381+ thrust::device_vector<index_t > d_a_colind_new (d_a_colind);
382+ thrust::device_vector<index_t > d_a_rowptr_new (d_a_rowptr);
383+ thrust::device_vector<value_t > d_b_values_new (b_values);
384+ thrust::device_vector<index_t > d_b_colind_new (d_b_colind);
385+ thrust::device_vector<index_t > d_b_rowptr_new (d_b_rowptr);
386+ thrust::device_vector<value_t > d_c_values_new (d_c_values);
387+ thrust::device_vector<index_t > d_c_colind_new (d_c_colind);
388+ thrust::device_vector<index_t > d_c_rowptr_new (d_c_rowptr);
389+ spblas::csr_view<value_t , index_t , offset_t > d_a (
390+ d_a_values_new.data ().get (), d_a_rowptr_new.data ().get (),
391+ d_a_colind_new.data ().get (), a_shape, a_nnz);
392+ spblas::csr_view<value_t , index_t , offset_t > d_b (
393+ d_b_values_new.data ().get (), d_b_rowptr_new.data ().get (),
394+ d_b_colind_new.data ().get (), b_shape, b_nnz);
395+ spblas::csr_view<value_t , index_t , offset_t > d_c (
396+ d_c_values_new.data ().get (), d_c_rowptr_new.data ().get (),
397+ d_c_colind_new.data ().get (), {m, n}, nnz);
398+ // call numeric on new data
399+ spblas::multiply_numeric (state, d_a, d_b, d_c);
400+ // move c back to host memory
401+ std::vector<value_t > c_values (nnz);
402+ std::vector<offset_t > c_rowptr (m + 1 );
403+ std::vector<index_t > c_colind (nnz);
404+ thrust::copy (d_c_values_new.begin (), d_c_values_new.end (),
405+ c_values.begin ());
406+ thrust::copy (d_c_rowptr_new.begin (), d_c_rowptr_new.end (),
407+ c_rowptr.begin ());
408+ thrust::copy (d_c_colind_new.begin (), d_c_colind_new.end (),
409+ c_colind.begin ());
410+
411+ spblas::csr_view<value_t , index_t , offset_t > c (c_values, c_rowptr,
412+ c_colind, {m, n}, nnz);
413+
414+ spblas::__backend::spa_accumulator<value_t , index_t > c_row_ref (
415+ spblas::__backend::shape (c)[1 ]);
416+
417+ spblas::__backend::spa_accumulator<value_t , index_t > c_row_acc (
418+ spblas::__backend::shape (c)[1 ]);
419+
420+ for (auto && [i, a_row] : spblas::__backend::rows (a)) {
421+ c_row_ref.clear ();
422+ for (auto && [k, a_v] : a_row) {
423+ auto && b_row = spblas::__backend::lookup_row (b, k);
424+
425+ for (auto && [j, b_v] : b_row) {
426+ c_row_ref[j] += a_v * b_v;
427+ }
428+ }
429+
430+ auto && c_row = spblas::__backend::lookup_row (c, i);
431+
432+ // Accumulate output into `c_row_acc` so that we can allow
433+ // duplicate column indices.
434+ c_row_acc.clear ();
435+ for (auto && [j, c_v] : c_row) {
436+ c_row_acc[j] += c_v;
437+ }
438+
439+ for (auto && [j, c_v] : c_row) {
440+ EXPECT_EQ_ (c_row_ref[j], c_row_acc[j]);
441+ }
442+
443+ EXPECT_EQ (c_row_ref.size (), c_row_acc.size ());
444+ }
445+ }
446+ }
447+ }
448+ }
0 commit comments