 #include <spblas/concepts.hpp>
 #include <spblas/detail/log.hpp>
 
+#include <spblas/algorithms/detail/sparse_dot_product.hpp>
 #include <spblas/algorithms/transposed.hpp>
 #include <spblas/backend/csr_builder.hpp>
 #include <spblas/backend/spa_accumulator.hpp>
@@ -190,44 +191,6 @@ void multiply_compute(operation_info_t& info, A&& a, B&& b, C&& c) {
   info.update_impl_(new_info.result_shape(), new_info.result_nnz());
 }
 
-template <typename T, typename A, typename B>
-std::optional<T> sparse_dot_product(A&& a, B&& b) {
-  auto sort_by_index = [](auto&& a, auto&& b) {
-    auto&& [a_i, a_v] = a;
-    auto&& [b_i, b_v] = b;
-    return a_i < b_i;
-  };
-  std::sort(a.begin(), a.end(), sort_by_index);
-  std::sort(b.begin(), b.end(), sort_by_index);
-
-  auto a_iter = a.begin();
-  auto b_iter = b.begin();
-
-  T sum = 0;
-  bool implicit_zero = true;
-  for (; a_iter != a.end() && b_iter != b.end();) {
-    auto&& [a_i, a_v] = *a_iter;
-    auto&& [b_i, b_v] = *b_iter;
-
-    if (a_i == b_i) {
-      sum += a_v * b_v;
-      implicit_zero = false;
-      ++a_iter;
-      ++b_iter;
-    } else if (a_i < b_i) {
-      ++a_iter;
-    } else {
-      ++b_iter;
-    }
-  }
-
-  if (implicit_zero) {
-    return {};
-  } else {
-    return sum;
-  }
-}
-
 // C = AB
 // SpGEMM (Inner Product)
 template <matrix A, matrix B, matrix C>
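
For reference, the routine removed above sorts both operands by index, then merges them in lockstep, returning std::nullopt when the index sets never intersect (an implicit zero rather than a computed 0). A minimal standalone demo of those semantics; the entry alias and merge_dot_product name are illustrative, not spblas identifiers:

// Standalone sketch of the removed merge-based dot product.
#include <algorithm>
#include <iostream>
#include <optional>
#include <utility>
#include <vector>

using entry = std::pair<int, double>; // (index, value)

std::optional<double> merge_dot_product(std::vector<entry> a,
                                        std::vector<entry> b) {
  auto by_index = [](const entry& x, const entry& y) {
    return x.first < y.first;
  };
  std::sort(a.begin(), a.end(), by_index);
  std::sort(b.begin(), b.end(), by_index);

  double sum = 0;
  bool implicit_zero = true;
  // Walk both sorted ranges in lockstep, multiplying matching indices.
  for (auto ai = a.begin(), bi = b.begin(); ai != a.end() && bi != b.end();) {
    if (ai->first == bi->first) {
      sum += ai->second * bi->second;
      implicit_zero = false;
      ++ai;
      ++bi;
    } else if (ai->first < bi->first) {
      ++ai;
    } else {
      ++bi;
    }
  }
  // No overlapping indices: the result is an implicit zero, not 0.0.
  if (implicit_zero) {
    return std::nullopt;
  }
  return sum;
}

int main() {
  std::vector<entry> a{{3, 2.0}, {0, 1.0}};
  std::vector<entry> b{{3, 4.0}, {5, 7.0}};
  std::cout << merge_dot_product(a, b).value_or(0) << '\n';          // prints 8
  std::cout << merge_dot_product(a, {{1, 9.0}}).has_value() << '\n'; // prints 0
}
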
@@ -245,6 +208,7 @@ void multiply(A&& a, B&& b, C&& c) {
   using T = tensor_scalar_t<C>;
   using I = tensor_index_t<C>;
 
+  __backend::spa_accumulator<T, I> dot_product_acc(__backend::shape(c)[1]);
   __backend::spa_accumulator<T, I> c_row(__backend::shape(c)[1]);
   __backend::csr_builder c_builder(c);
 
@@ -254,7 +218,8 @@ void multiply(A&& a, B&& b, C&& c) {
     if (!__ranges::empty(a_row)) {
       for (auto&& [j, b_column] : __backend::columns(b)) {
         if (!__ranges::empty(b_column)) {
-          auto v = sparse_dot_product<T>(a_row, b_column);
+          auto v =
+              __detail::sparse_dot_product<T>(dot_product_acc, a_row, b_column);
 
           if (v.has_value()) {
             c_row[j] += v.value();
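
The body of the new <spblas/algorithms/detail/sparse_dot_product.hpp> is not shown in this diff. Below is a plausible sketch of the accumulator-based variant it might contain, assuming spa_accumulator exposes clear(), contains(), and operator[] (the actual backend interface may differ): scatter one operand into the pre-allocated accumulator, then gather along the other, trading the two sorts above for O(nnz(a) + nnz(b)) work per dot product.

// Hypothetical sketch, not the actual spblas implementation.
#include <optional>

template <typename T, typename Acc, typename A, typename B>
std::optional<T> sparse_dot_product(Acc&& acc, A&& a, B&& b) {
  acc.clear(); // reuse the caller's allocation across calls

  // Scatter a's values into the accumulator, keyed by index.
  for (auto&& [i, v] : a) {
    acc[i] += v;
  }

  T sum = 0;
  bool implicit_zero = true;
  // Gather: only indices present in both operands contribute.
  for (auto&& [i, v] : b) {
    if (acc.contains(i)) {
      sum += acc[i] * v;
      implicit_zero = false;
    }
  }

  if (implicit_zero) {
    return {};
  }
  return sum;
}
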
@@ -295,11 +260,14 @@ operation_info_t multiply_compute(A&& a, B&& b, C&& c) {
 
   O nnz = 0;
 
+  __backend::spa_accumulator<T, I> dot_product_acc(__backend::shape(c)[1]);
+
   for (auto&& [i, a_row] : __backend::rows(a)) {
     if (!__ranges::empty(a_row)) {
       for (auto&& [j, b_column] : __backend::columns(b)) {
         if (!__ranges::empty(b_column)) {
-          auto v = sparse_dot_product<T>(a_row, b_column);
+          auto v =
+              __detail::sparse_dot_product<T>(dot_product_acc, a_row, b_column);
 
           if (v.has_value()) {
             nnz++;
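
In both multiply and multiply_compute, dot_product_acc is constructed once, outside the row/column loops, and passed into every __detail::sparse_dot_product call, so its storage is reused across all inner-product evaluations instead of being set up per call; the old routine instead sorted both input ranges in place on every invocation.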