Skip to content

Commit 3a77817

Browse files
committed
Commit missing files
1 parent 4a2e3a1 commit 3a77817

File tree

6 files changed

+311
-43
lines changed

6 files changed

+311
-43
lines changed

include/spblas/algorithms/detail/spgemm/spgemm_gustavsons.hpp

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,4 +126,92 @@ operation_info_t multiply_compute(A&& a, B&& b, C&& c) {
126126
return info;
127127
}
128128

129+
// C = AB
130+
// CSR * CSR -> CSC
131+
// SpGEMM (Gustavson's Algorithm, scattered)
132+
template <matrix A, matrix B, matrix C>
133+
requires(__backend::row_iterable<A> && __backend::row_iterable<B> &&
134+
__detail::is_csc_view_v<C>)
135+
void multiply(A&& a, B&& b, C&& c) {
136+
log_trace("");
137+
if (__backend::shape(a)[0] != __backend::shape(c)[0] ||
138+
__backend::shape(b)[1] != __backend::shape(c)[1] ||
139+
__backend::shape(a)[1] != __backend::shape(b)[0]) {
140+
throw std::invalid_argument(
141+
"multiply: matrix dimensions are incompatible.");
142+
}
143+
144+
using T = tensor_scalar_t<C>;
145+
using I = tensor_index_t<C>;
146+
147+
__backend::spa_accumulator<T, I> c_row(__backend::shape(c)[1]);
148+
149+
std::vector<std::vector<std::pair<I, T>>> columns(__backend::shape(c)[1]);
150+
151+
for (auto&& [i, a_row] : __backend::rows(a)) {
152+
c_row.clear();
153+
for (auto&& [k, a_v] : a_row) {
154+
for (auto&& [j, b_v] : __backend::lookup_row(b, k)) {
155+
c_row[j] += a_v * b_v;
156+
}
157+
}
158+
for (auto&& [j, v] : c_row.get()) {
159+
columns[j].push_back({i, v});
160+
}
161+
}
162+
163+
__backend::csc_builder c_builder(c);
164+
165+
for (std::size_t j = 0; j < columns.size(); j++) {
166+
auto&& column = columns[j];
167+
std::sort(column.begin(), column.end(),
168+
[](auto&& a, auto&& b) { return a.first < b.first; });
169+
170+
try {
171+
c_builder.insert_column(j, column);
172+
} catch (...) {
173+
throw std::runtime_error("multiply: SpGEMM ran out of memory.");
174+
}
175+
}
176+
c.update(c.values(), c.colptr(), c.rowind(), c.shape(),
177+
c.colptr()[c.shape()[1]]);
178+
}
179+
180+
// C = AB
181+
// CSR * CSR -> CSC
182+
// SpGEMM (Gustavson's Algorithm, scattered)
183+
template <matrix A, matrix B, matrix C>
184+
requires(__backend::row_iterable<A> && __backend::row_iterable<B> &&
185+
__detail::is_csc_view_v<C>)
186+
operation_info_t multiply_compute(A&& a, B&& b, C&& c) {
187+
log_trace("");
188+
if (__backend::shape(a)[0] != __backend::shape(c)[0] ||
189+
__backend::shape(b)[1] != __backend::shape(c)[1] ||
190+
__backend::shape(a)[1] != __backend::shape(b)[0]) {
191+
throw std::invalid_argument(
192+
"multiply: matrix dimensions are incompatible.");
193+
}
194+
195+
using T = tensor_scalar_t<C>;
196+
using I = tensor_index_t<C>;
197+
using O = tensor_offset_t<C>;
198+
199+
O nnz = 0;
200+
__backend::spa_set<I> c_row(__backend::shape(c)[1]);
201+
202+
for (auto&& [i, a_row] : __backend::rows(a)) {
203+
c_row.clear();
204+
205+
for (auto&& [k, a_v] : a_row) {
206+
for (auto&& [j, b_v] : __backend::lookup_row(b, k)) {
207+
c_row.insert(j);
208+
}
209+
}
210+
211+
nnz += c_row.size();
212+
}
213+
214+
return operation_info_t{__backend::shape(c), nnz};
215+
}
216+
129217
} // namespace spblas

include/spblas/vendor/armpl/detail/armpl.hpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,22 @@ template <>
124124
inline constexpr auto export_spmat_csr<std::complex<double>> =
125125
&armpl_spmat_export_csr_z;
126126

127+
// Variable template that selects the ArmPL CSC export routine for scalar
// type T. Specializations are provided for float, double,
// std::complex<float> and std::complex<double>, bound to
// armpl_spmat_export_csc_{s,d,c,z} respectively. Each specialization is
// declared with `auto`, so its type is that of the ArmPL function it points
// to, not the pointer type spelled in the primary template.
// NOTE(review): the primary template spells const-qualified output pointers,
// while export_matrix_handle passes non-const `armpl_int_t**` / `T**` and
// frees the results -- confirm the signature against armpl_sparse.h.
template <typename T>
128+
armpl_status_t (*export_spmat_csc)(armpl_const_spmat_t, armpl_int_t,
129+
armpl_int_t*, armpl_int_t*,
130+
const armpl_int_t**, const armpl_int_t**,
131+
const T**);
132+
template <>
133+
inline constexpr auto export_spmat_csc<float> = &armpl_spmat_export_csc_s;
134+
template <>
135+
inline constexpr auto export_spmat_csc<double> = &armpl_spmat_export_csc_d;
136+
template <>
137+
inline constexpr auto export_spmat_csc<std::complex<float>> =
138+
&armpl_spmat_export_csc_c;
139+
template <>
140+
inline constexpr auto export_spmat_csc<std::complex<double>> =
141+
&armpl_spmat_export_csc_z;
142+
127143
} // namespace __armpl
128144

129145
} // namespace spblas

include/spblas/vendor/armpl/detail/detail.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22

33
#include "armpl.hpp"
44
#include "create_matrix_handle.hpp"
5+
#include "export_matrix_handle.hpp"
include/spblas/vendor/armpl/detail/export_matrix_handle.hpp

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,54 @@
11
#pragma once
22

3+
#include <algorithm>
#include <cstdlib>
#include <stdexcept>

#include <spblas/detail/operation_info_t.hpp>

#include <armpl_sparse.h>
#include <spblas/detail/view_inspectors.hpp>
57

68
namespace spblas {
79

810
namespace __armpl {
911

10-
template <matrix M, typename O>
12+
template <matrix M>
1113
requires __detail::is_csr_view_v<M>
12-
void export_matrix_handle(operation_info_t& info, M&& m, armpl_spmat_t m_handle) {
14+
void export_matrix_handle(operation_info_t& info, M&& matrix,
15+
armpl_spmat_t matrix_handle) {
1316
auto nnz = info.result_nnz();
1417
armpl_int_t m, n;
1518
armpl_int_t *rowptr, *colind;
1619
tensor_scalar_t<M>* values;
17-
__armpl::export_spmat_csr<tensor_scalar_t<M>>(m_handle, 0, &m, &n, &rowptr,
18-
&colind, &values);
20+
__armpl::export_spmat_csr<tensor_scalar_t<M>>(matrix_handle, 0, &m, &n,
21+
&rowptr, &colind, &values);
1922

20-
std::copy(values, values + nnz, m.values().begin());
21-
std::copy(colind, colind + nnz, m.colind().begin());
22-
std::copy(rowptr, rowptr + m + 1, m.rowptr().begin());
23+
std::copy(values, values + nnz, matrix.values().begin());
24+
std::copy(colind, colind + nnz, matrix.colind().begin());
25+
std::copy(rowptr, rowptr + m + 1, matrix.rowptr().begin());
2326

2427
free(values);
2528
free(rowptr);
2629
free(colind);
2730
}
2831

32+
template <matrix M>
33+
requires __detail::is_csc_view_v<M>
34+
void export_matrix_handle(operation_info_t& info, M&& matrix,
35+
armpl_spmat_t matrix_handle) {
36+
auto nnz = info.result_nnz();
37+
armpl_int_t m, n;
38+
armpl_int_t *colptr, *rowind;
39+
tensor_scalar_t<M>* values;
40+
__armpl::export_spmat_csc<tensor_scalar_t<M>>(matrix_handle, 0, &m, &n,
41+
&rowind, &colptr, &values);
42+
43+
std::copy(values, values + nnz, matrix.values().begin());
44+
std::copy(rowind, rowind + nnz, matrix.rowind().begin());
45+
std::copy(colptr, colptr + n + 1, matrix.colptr().begin());
46+
47+
free(values);
48+
free(colptr);
49+
free(rowind);
2950
}
3051

31-
}
52+
} // namespace __armpl
53+
54+
} // namespace spblas

include/spblas/vendor/armpl/multiply_impl.hpp

Lines changed: 3 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ void multiply(A&& a, B&& b, C&& c) {
8787
template <matrix A, matrix B, matrix C>
8888
requires((__detail::has_csr_base<A> || __detail::has_csc_base<A>) &&
8989
(__detail::has_csr_base<B> || __detail::has_csc_base<B>) &&
90-
__detail::is_csr_view_v<C>)
90+
(__detail::is_csr_view_v<C> || __detail::is_csc_view_v<C>) )
9191
operation_info_t multiply_compute(A&& a, B&& b, C&& c) {
9292
log_trace("");
9393
auto a_base = __detail::get_ultimate_base(a);
@@ -117,42 +117,12 @@ operation_info_t multiply_compute(A&& a, B&& b, C&& c) {
117117
template <matrix A, matrix B, matrix C>
118118
requires((__detail::has_csr_base<A> || __detail::has_csc_base<A>) &&
119119
(__detail::has_csr_base<B> || __detail::has_csc_base<B>) &&
120-
__detail::is_csr_view_v<C>)
120+
(__detail::is_csr_view_v<C> || __detail::is_csc_view_v<C>) )
121121
void multiply_fill(operation_info_t& info, A&& a, B&& b, C&& c) {
122122
log_trace("");
123-
auto a_handle = info.state_.a_handle;
124-
auto b_handle = info.state_.b_handle;
125123
auto c_handle = info.state_.c_handle;
126124

127-
armpl_int_t m, n;
128-
auto nnz = info.result_nnz();
129-
armpl_int_t *rowptr, *colind;
130-
tensor_scalar_t<C>* values;
131-
__armpl::export_spmat_csr<tensor_scalar_t<C>>(c_handle, 0, &m, &n, &rowptr,
132-
&colind, &values);
133-
134-
std::copy(values, values + nnz, c.values().begin());
135-
std::copy(colind, colind + nnz, c.colind().begin());
136-
std::copy(rowptr, rowptr + m + 1, c.rowptr().begin());
137-
138-
free(values);
139-
free(rowptr);
140-
free(colind);
141-
}
142-
143-
template <matrix A, matrix B, matrix C>
144-
requires __detail::has_csc_base<A> && __detail::has_csc_base<B> &&
145-
__detail::is_csc_view_v<C>
146-
operation_info_t multiply_compute(A&& a, B&& b, C&& c) {
147-
return multiply_compute(transposed(b), transposed(a), transposed(c));
148-
}
149-
150-
template <matrix A, matrix B, matrix C>
151-
requires((__detail::has_csr_base<A> || __detail::has_csc_base<A>) &&
152-
(__detail::has_csr_base<B> || __detail::has_csc_base<B>) &&
153-
__detail::is_csc_view_v<C>)
154-
void multiply_fill(operation_info_t& info, A&& a, B&& b, C&& c) {
155-
multiply_fill(info, transposed(b), transposed(a), transposed(c));
125+
__armpl::export_matrix_handle(info, c, c_handle);
156126
}
157127

158128
} // namespace spblas

0 commit comments

Comments
 (0)