Skip to content

Commit df63409

Browse files
author
Christopher Armstrong
committed
ArmPL SpGEMM: sort input row indices
1 parent 57d7096 commit df63409

File tree

3 files changed

+121
-4
lines changed

3 files changed

+121
-4
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ project(spblas)
44
set(CMAKE_CXX_STANDARD 23)
55
set(CMAKE_CXX_STANDARD_REQUIRED ON)
66

7-
set(CMAKE_CXX_FLAGS "-O3 -march=native")
7+
set(CMAKE_CXX_FLAGS "-O3")
88

99
# Get includes, which declares the `spblas` library
1010
add_subdirectory(include)

include/spblas/backend/view_customizations.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,8 +149,8 @@ template <typename T, typename Extents, typename LayoutPolicy,
149149
auto tag_invoke(__backend::rows_fn_,
150150
__mdspan::mdspan<T, Extents, LayoutPolicy, AccessorPolicy> m) {
151151
using index_type = tensor_index_t<decltype(m)>;
152-
using reference =
153-
__mdspan::mdspan<T, Extents, LayoutPolicy, AccessorPolicy>::reference;
152+
using reference = typename __mdspan::mdspan<T, Extents, LayoutPolicy,
153+
AccessorPolicy>::reference;
154154

155155
auto row_indices = __ranges::views::iota(index_type(0), m.extent(0));
156156

include/spblas/vendor/armpl/multiply_impl.hpp

Lines changed: 118 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
#include <spblas/detail/ranges.hpp>
88
#include <spblas/detail/view_inspectors.hpp>
99

10+
#include <fmt/printf.h>
11+
1012
namespace spblas {
1113

1214
template <matrix A, vector B, vector C>
@@ -99,15 +101,123 @@ operation_info_t multiply_compute(A&& a, B&& b, C&& c) {
99101

100102
armpl_spmat_t a_handle, b_handle, c_handle;
101103

104+
#if 0
105+
// ArmPL requires the column indices within each row to be sorted
106+
using T = tensor_scalar_t<A>;
107+
using I = tensor_index_t<A>;
108+
using O = tensor_offset_t<A>;
109+
110+
auto ma = __backend::shape(a_base)[0];
111+
auto na = __backend::shape(a_base)[1];
112+
auto nnzA = a_base.rowptr().data()[ma] - a_base.rowptr().data()[0];
113+
std::vector<T> tmp_values_A(nnzA);
114+
std::vector<I> tmp_colind_A(nnzA);
115+
116+
using T = tensor_scalar_t<B>;
117+
using I = tensor_index_t<B>;
118+
using O = tensor_offset_t<B>;
119+
120+
auto mb = __backend::shape(b_base)[0];
121+
auto nb = __backend::shape(b_base)[1];
122+
auto nnzB = b_base.rowptr().data()[mb] - b_base.rowptr().data()[0];
123+
std::vector<T> tmp_values_B(nnzB);
124+
std::vector<I> tmp_colind_B(nnzB);
125+
126+
auto get_permutation = [](I* v1, T* v2, O len) {
127+
std::vector<armpl_int_t> indices(len);
128+
129+
for (size_t i = 0; i < indices.size(); ++i)
130+
indices[i] = i;
131+
132+
std::sort(indices.begin(), indices.end(), [&](I i, I j) {
133+
return v1[i] < v1[j]; // Sorting based on v1
134+
});
135+
136+
return indices;
137+
};
138+
139+
// fmt::print("matrix A {}x{}\n", ma, na);
140+
auto rowptr_A = a_base.rowptr().data();
141+
auto colind_A = a_base.colind().data();
142+
auto values_A = a_base.values().data();
143+
auto index_base_A = rowptr_A[0];
144+
145+
for (armpl_int_t i = 0; i < ma; i++) {
146+
auto indices =
147+
get_permutation(&colind_A[rowptr_A[i]], &values_A[rowptr_A[i]],
148+
rowptr_A[i + 1] - rowptr_A[i]);
149+
150+
// std::vector<armpl_int_t> indices(rowptr_A[i + 1] - rowptr_A[i]);
151+
152+
// for (size_t i = 0; i < indices.size(); ++i)
153+
// indices[i] = i;
154+
155+
auto start = rowptr_A[i];
156+
for (size_t ii = 0; ii < indices.size(); ++ii) {
157+
tmp_values_A[start + ii] = values_A[start + indices[ii]];
158+
tmp_colind_A[start + ii] = colind_A[start + indices[ii]];
159+
}
160+
161+
// for (armpl_int_t j = rowptr_A[i] - index_base_A; j < rowptr_A[i + 1] -
162+
// index_base_A;
163+
// j++) {
164+
// fmt::print("row {} col {} val {}\n", i, tmp_colind_A[j],
165+
// tmp_values_A[j]);
166+
// }
167+
}
168+
169+
// fmt::print("matrix B {}x{}\n", mb, nb);
170+
171+
auto rowptr_B = b_base.rowptr().data();
172+
auto colind_B = b_base.colind().data();
173+
auto values_B = b_base.values().data();
174+
auto index_base_B = rowptr_B[0];
175+
176+
for (armpl_int_t i = 0; i < mb; i++) {
177+
auto indices =
178+
get_permutation(&colind_B[rowptr_B[i]], &values_B[rowptr_B[i]],
179+
rowptr_B[i + 1] - rowptr_B[i]);
180+
181+
// std::vector<armpl_int_t> indices(rowptr_B[i + 1] - rowptr_B[i]);
182+
183+
// for (size_t i = 0; i < indices.size(); ++i)
184+
// indices[i] = i;
185+
186+
auto start = rowptr_B[i];
187+
for (size_t ii = 0; ii < indices.size(); ++ii) {
188+
tmp_values_B[start + ii] = values_B[start + indices[ii]];
189+
tmp_colind_B[start + ii] = colind_B[start + indices[ii]];
190+
}
191+
192+
// for (armpl_int_t j = rowptr_B[i] - index_base_B; j < rowptr_B[i + 1] -
193+
// index_base_B;
194+
// j++) {
195+
// fmt::print("row {} col {} val {}\n", i, tmp_colind_B[j],
196+
// tmp_values_B[j]);
197+
// }
198+
}
199+
200+
__armpl::create_spmat_csr<tensor_scalar_t<A>>(
201+
&a_handle, __backend::shape(a_base)[0], __backend::shape(a_base)[1],
202+
a_base.rowptr().data(), tmp_colind_A.data(), tmp_values_A.data(),
203+
ARMPL_SPARSE_CREATE_NOCOPY);
204+
205+
__armpl::create_spmat_csr<tensor_scalar_t<B>>(
206+
&b_handle, __backend::shape(b_base)[0], __backend::shape(b_base)[1],
207+
b_base.rowptr().data(), tmp_colind_B.data(), tmp_values_B.data(),
208+
ARMPL_SPARSE_CREATE_NOCOPY);
209+
#else
210+
102211
__armpl::create_spmat_csr<tensor_scalar_t<A>>(
103212
&a_handle, __backend::shape(a_base)[0], __backend::shape(a_base)[1],
104213
a_base.rowptr().data(), a_base.colind().data(), a_base.values().data(),
105214
ARMPL_SPARSE_CREATE_NOCOPY);
106215

107216
__armpl::create_spmat_csr<tensor_scalar_t<B>>(
108217
&b_handle, __backend::shape(b_base)[0], __backend::shape(b_base)[1],
109-
b_base.rowptr().data(), b_base.colind().data(), a_base.values().data(),
218+
b_base.rowptr().data(), b_base.colind().data(), b_base.values().data(),
110219
ARMPL_SPARSE_CREATE_NOCOPY);
220+
#endif
111221

112222
c_handle =
113223
armpl_spmat_create_null(__backend::shape(c)[0], __backend::shape(c)[1]);
@@ -116,6 +226,13 @@ operation_info_t multiply_compute(A&& a, B&& b, C&& c) {
116226
ARMPL_SPARSE_OPERATION_NOTRANS, alpha,
117227
a_handle, b_handle, 0, c_handle);
118228

229+
/*
230+
armpl_spmm_optimize(ARMPL_SPARSE_OPERATION_NOTRANS,
231+
ARMPL_SPARSE_OPERATION_NOTRANS,
232+
ARMPL_SPARSE_SCALAR_ANY, a_handle, b_handle, ARMPL_SPARSE_SCALAR_ZERO,
233+
c_handle);
234+
*/
235+
119236
armpl_int_t index_base, m, n, nnz;
120237
armpl_spmat_query(c_handle, &index_base, &m, &n, &nnz);
121238

0 commit comments

Comments
 (0)