Skip to content

Commit 1188be9

Browse files
committed
Use hash_accumulator for outer product, use sparse_intersection for
first phase of inner product.
1 parent 0ca3447 commit 1188be9

File tree

6 files changed

+117
-8
lines changed

6 files changed

+117
-8
lines changed

include/spblas/algorithms/detail/sparse_dot_product.hpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,23 @@ std::optional<T> sparse_dot_product(__backend::spa_accumulator<T, I>& acc,
3333
}
3434
}
3535

36+
template <typename Set, typename A, typename B>
37+
bool sparse_intersection(Set&& set, A&& a, B&& b) {
38+
set.clear();
39+
40+
for (auto&& [i, v] : a) {
41+
set.insert(i);
42+
}
43+
44+
for (auto&& [i, v] : b) {
45+
if (set.contains(i)) {
46+
return true;
47+
}
48+
}
49+
50+
return false;
51+
}
52+
3653
} // namespace __detail
3754

3855
} // namespace spblas

include/spblas/algorithms/detail/spgemm/spgemm_innerproduct.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,16 +83,16 @@ operation_info_t multiply_compute(A&& a, B&& b, C&& c) {
8383

8484
O nnz = 0;
8585

86-
__backend::spa_accumulator<T, I> dot_product_acc(__backend::shape(a)[1]);
86+
__backend::spa_set<I> dot_product_acc(__backend::shape(a)[1]);
8787

8888
for (auto&& [i, a_row] : __backend::rows(a)) {
8989
if (!__ranges::empty(a_row)) {
9090
for (auto&& [j, b_column] : __backend::columns(b)) {
9191
if (!__ranges::empty(b_column)) {
9292
auto v =
93-
__detail::sparse_dot_product<T>(dot_product_acc, a_row, b_column);
93+
__detail::sparse_intersection(dot_product_acc, a_row, b_column);
9494

95-
if (v.has_value()) {
95+
if (v) {
9696
nnz++;
9797
}
9898
}

include/spblas/algorithms/detail/spgemm/spgemm_outerproduct.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
#include <spblas/algorithms/detail/sparse_dot_product.hpp>
88
#include <spblas/algorithms/transposed.hpp>
99
#include <spblas/backend/csr_builder.hpp>
10-
#include <spblas/backend/spa_accumulator.hpp>
10+
#include <spblas/backend/hash_accumulator.hpp>
1111
#include <spblas/detail/operation_info_t.hpp>
1212

1313
namespace spblas {
@@ -30,7 +30,7 @@ void multiply(A&& a, B&& b, C&& c) {
3030
using T = tensor_scalar_t<C>;
3131
using I = tensor_index_t<C>;
3232

33-
std::vector<__backend::spa_accumulator<T, I>> row_accumulators;
33+
std::vector<__backend::hash_accumulator<T, I>> row_accumulators;
3434

3535
for (std::size_t i = 0; i < __backend::shape(c)[0]; i++) {
3636
row_accumulators.emplace_back(__backend::shape(c)[1]);
@@ -80,7 +80,7 @@ operation_info_t multiply_compute(A&& a, B&& b, C&& c) {
8080
using I = tensor_index_t<C>;
8181
using O = tensor_offset_t<C>;
8282

83-
std::vector<__backend::spa_accumulator<T, I>> row_accumulators;
83+
std::vector<__backend::hash_accumulator<T, I>> row_accumulators;
8484

8585
for (std::size_t i = 0; i < __backend::shape(c)[0]; i++) {
8686
row_accumulators.emplace_back(__backend::shape(c)[1]);
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
#pragma once
2+
3+
#include <functional>
4+
#include <span>
5+
#include <tuple>
6+
#include <unordered_map>
7+
#include <unordered_set>
8+
#include <vector>
9+
10+
#include <spblas/detail/ranges.hpp>
11+
12+
namespace spblas {
13+
14+
namespace __backend {
15+
16+
template <typename T, std::integral I>
17+
class hash_accumulator {
18+
public:
19+
hash_accumulator(I count) {}
20+
21+
T& operator[](I pos) {
22+
return hash_[pos];
23+
}
24+
25+
bool contains(I pos) {
26+
return hash_.contains(pos);
27+
}
28+
29+
void clear() {
30+
hash_.clear();
31+
}
32+
33+
I size() const {
34+
return hash_.size();
35+
}
36+
37+
bool empty() {
38+
return hash_.empty();
39+
}
40+
41+
void sort() {}
42+
43+
auto get() {
44+
std::vector<std::pair<I, T>> values(hash_.begin(), hash_.end());
45+
46+
std::sort(values.begin(), values.end(), [](auto&& a, auto&& b) {
47+
return std::get<0>(a) < std::get<0>(b);
48+
});
49+
50+
return values;
51+
}
52+
53+
private:
54+
std::unordered_map<I, T> hash_;
55+
};
56+
57+
template <std::integral T>
58+
class hash_set {
59+
public:
60+
hash_set(T count) {}
61+
62+
void insert(T key) {
63+
set_.insert(key);
64+
}
65+
66+
bool contains(T key) {
67+
return set_.contains(key);
68+
}
69+
70+
void clear() {
71+
set_.clear();
72+
}
73+
74+
T size() const {
75+
return set_.size();
76+
}
77+
78+
bool empty() {
79+
return set_.empty();
80+
}
81+
82+
auto get() const {
83+
return __ranges::views::all(set_);
84+
}
85+
86+
private:
87+
std::unordered_set<T> set_;
88+
};
89+
90+
} // namespace __backend
91+
92+
} // namespace spblas

include/spblas/backend/spa_accumulator.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
#include <functional>
44
#include <span>
5-
#include <spblas/detail/ranges.hpp>
65
#include <tuple>
76
#include <vector>
87

include/spblas/concepts.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,14 @@ namespace spblas {
1111
/*
1212
The following types fulfill the matrix concept:
1313
- Instantiations of csr_view<...>
14+
- Instantiations of csc_view<...>
1415
- Instantiations of mdspan<...> with rank 2
1516
- Instantiations of scaled_view<T> where M is a matrix
1617
*/
1718

1819
template <typename M>
1920
concept matrix =
20-
__detail::is_csr_view_v<M> ||
21+
__detail::is_csr_view_v<M> || __detail::is_csc_view_v<M> ||
2122
__detail::is_matrix_instantiation_of_mdspan_v<M> || __detail::matrix<M>;
2223

2324
/*

0 commit comments

Comments
 (0)