Skip to content

Commit 91c017e

Browse files
committed
Implement dot product with accumulator to avoid having to sort indices
of inputs.
1 parent 543ed24 commit 91c017e

File tree

3 files changed

+50
-40
lines changed

3 files changed

+50
-40
lines changed
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
#pragma once
2+
3+
#include <optional>
4+
5+
#include <spblas/backend/spa_accumulator.hpp>
6+
7+
namespace spblas {
8+
9+
namespace __detail {
10+
11+
// Computes the dot product of two sparse index/value ranges using a
// caller-provided accumulator as scratch space, so neither input has to be
// sorted by index.
//
// T   - scalar type of the result and of the accumulator's stored values.
// I   - index type of the accumulator.
// acc - scratch accumulator; it is cleared on entry and overwritten with the
//       entries of `a`. It is NOT cleared again before returning, so it still
//       holds `a`'s entries afterwards.
// a,b - ranges whose elements decompose into (index, value) pairs.
//
// Returns an empty optional when `a` and `b` share no index (the result is an
// implicit zero); otherwise returns the accumulated sum, even if that sum is
// numerically zero.
//
// NOTE(review): every index in `a` and `b` is used as a position in `acc`;
// presumably `acc` must be sized to cover all of them — confirm. The visible
// callers construct the accumulator with shape(c)[1]; verify that this covers
// the index range of a's rows / b's columns.
template <typename T, typename I, typename A, typename B>
12+
std::optional<T> sparse_dot_product(__backend::spa_accumulator<T, I>& acc,
13+
A&& a, B&& b) {
14+
// Drop whatever a previous call left in the scratch accumulator.
acc.clear();
15+
16+
// Scatter `a` into the accumulator (presumably operator[] also flags the
// position as set, which is what contains() tests below — confirm).
for (auto&& [i, v] : a) {
17+
acc[i] = v;
18+
}
19+
20+
T sum = 0;
21+
// Stays true until at least one index of `b` is also present in `a`.
bool implicit_zero = true;
22+
// Walk `b` and multiply only where `a` stored a value at the same index.
for (auto&& [i, v] : b) {
23+
if (acc.contains(i)) {
24+
sum += acc[i] * v;
25+
implicit_zero = false;
26+
}
27+
}
28+
29+
// No shared index: report "no stored value" rather than an explicit 0.
if (implicit_zero) {
30+
return {};
31+
} else {
32+
return sum;
33+
}
34+
}
35+
36+
} // namespace __detail
37+
38+
} // namespace spblas

include/spblas/algorithms/multiply_impl.hpp

Lines changed: 8 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <spblas/concepts.hpp>
55
#include <spblas/detail/log.hpp>
66

7+
#include <spblas/algorithms/detail/sparse_dot_product.hpp>
78
#include <spblas/algorithms/transposed.hpp>
89
#include <spblas/backend/csr_builder.hpp>
910
#include <spblas/backend/spa_accumulator.hpp>
@@ -190,44 +191,6 @@ void multiply_compute(operation_info_t& info, A&& a, B&& b, C&& c) {
190191
info.update_impl_(new_info.result_shape(), new_info.result_nnz());
191192
}
192193

193-
template <typename T, typename A, typename B>
194-
std::optional<T> sparse_dot_product(A&& a, B&& b) {
195-
auto sort_by_index = [](auto&& a, auto&& b) {
196-
auto&& [a_i, a_v] = a;
197-
auto&& [b_i, b_v] = b;
198-
return a_i < b_i;
199-
};
200-
std::sort(a.begin(), a.end(), sort_by_index);
201-
std::sort(b.begin(), b.end(), sort_by_index);
202-
203-
auto a_iter = a.begin();
204-
auto b_iter = b.begin();
205-
206-
T sum = 0;
207-
bool implicit_zero = true;
208-
for (; a_iter != a.end() && b_iter != b.end();) {
209-
auto&& [a_i, a_v] = *a_iter;
210-
auto&& [b_i, b_v] = *b_iter;
211-
212-
if (a_i == b_i) {
213-
sum += a_v * b_v;
214-
implicit_zero = false;
215-
++a_iter;
216-
++b_iter;
217-
} else if (a_i < b_i) {
218-
++a_iter;
219-
} else {
220-
++b_iter;
221-
}
222-
}
223-
224-
if (implicit_zero) {
225-
return {};
226-
} else {
227-
return sum;
228-
}
229-
}
230-
231194
// C = AB
232195
// SpGEMM (Inner Product)
233196
template <matrix A, matrix B, matrix C>
@@ -245,6 +208,7 @@ void multiply(A&& a, B&& b, C&& c) {
245208
using T = tensor_scalar_t<C>;
246209
using I = tensor_index_t<C>;
247210

211+
__backend::spa_accumulator<T, I> dot_product_acc(__backend::shape(c)[1]);
248212
__backend::spa_accumulator<T, I> c_row(__backend::shape(c)[1]);
249213
__backend::csr_builder c_builder(c);
250214

@@ -254,7 +218,8 @@ void multiply(A&& a, B&& b, C&& c) {
254218
if (!__ranges::empty(a_row)) {
255219
for (auto&& [j, b_column] : __backend::columns(b)) {
256220
if (!__ranges::empty(b_column)) {
257-
auto v = sparse_dot_product<T>(a_row, b_column);
221+
auto v =
222+
__detail::sparse_dot_product<T>(dot_product_acc, a_row, b_column);
258223

259224
if (v.has_value()) {
260225
c_row[j] += v.value();
@@ -295,11 +260,14 @@ operation_info_t multiply_compute(A&& a, B&& b, C&& c) {
295260

296261
O nnz = 0;
297262

263+
__backend::spa_accumulator<T, I> dot_product_acc(__backend::shape(c)[1]);
264+
298265
for (auto&& [i, a_row] : __backend::rows(a)) {
299266
if (!__ranges::empty(a_row)) {
300267
for (auto&& [j, b_column] : __backend::columns(b)) {
301268
if (!__ranges::empty(b_column)) {
302-
auto v = sparse_dot_product<T>(a_row, b_column);
269+
auto v =
270+
__detail::sparse_dot_product<T>(dot_product_acc, a_row, b_column);
303271

304272
if (v.has_value()) {
305273
nnz++;

include/spblas/backend/spa_accumulator.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ class spa_accumulator {
2525
return data_[pos];
2626
}
2727

28+
// Returns the flag stored in set_ for position `pos`, i.e. whether that
// position is currently marked as set in the accumulator.
// NOTE(review): `pos` is used to index set_ directly with no visible bounds
// check; callers presumably must pass a valid position — confirm.
// NOTE(review): nothing here appears to modify the object, so this could
// likely be a const member function — confirm against the full class.
bool contains(I pos) {
29+
return set_[pos];
30+
}
31+
2832
void clear() {
2933
for (auto&& pos : stored_) {
3034
set_[pos] = false;

0 commit comments

Comments
 (0)