Skip to content

Commit de12da2

Browse files
committed
Implement csc_view, implement transposed.
1 parent b8df6ba commit de12da2

File tree

14 files changed

+451
-4
lines changed

14 files changed

+451
-4
lines changed

examples/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,5 @@ add_example(simple_spmv)
77
add_example(simple_spmm)
88
add_example(simple_spgemm)
99
add_example(simple_sptrsv)
10-
add_example(matrix_opt_example)
10+
add_example(matrix_opt_example)
11+
add_example(spmm_csc)

examples/spmm_csc.cpp

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
#include <iostream>
2+
#include <spblas/spblas.hpp>
3+
4+
#include <fmt/core.h>
5+
#include <fmt/ranges.h>
6+
7+
int main(int argc, char** argv) {
8+
using namespace spblas;
9+
namespace md = spblas::__mdspan;
10+
11+
using T = float;
12+
13+
spblas::index_t m = 100;
14+
spblas::index_t n = 10;
15+
spblas::index_t k = 100;
16+
spblas::index_t nnz_in = 10;
17+
18+
fmt::print("\n\t###########################################################"
19+
"######################");
20+
fmt::print("\n\t### Running SpMM Example:");
21+
fmt::print("\n\t###");
22+
fmt::print("\n\t### Y = alpha * A * X");
23+
fmt::print("\n\t###");
24+
fmt::print("\n\t### with ");
25+
fmt::print("\n\t### A, in CSR format, of size ({}, {}) with nnz = {}", m, k,
26+
nnz_in);
27+
fmt::print("\n\t### x, a dense matrix, of size ({}, {})", k, n);
28+
fmt::print("\n\t### y, a dense vector, of size ({}, {})", m, n);
29+
fmt::print("\n\t### using float and spblas::index_t (size = {} bytes)",
30+
sizeof(spblas::index_t));
31+
fmt::print("\n\t###########################################################"
32+
"######################");
33+
fmt::print("\n");
34+
35+
auto&& [values, colptr, rowind, shape, nnz] = generate_csc<T>(m, k, nnz_in);
36+
37+
csc_view<T> a(values, colptr, rowind, shape, nnz);
38+
39+
std::vector<T> x_values(k * n, 1);
40+
std::vector<T> y_values(m * n, 0);
41+
42+
md::mdspan x(x_values.data(), k, n);
43+
md::mdspan y(y_values.data(), m, n);
44+
45+
// y = A * (alpha * x)
46+
multiply(a, scaled(2.f, x), y);
47+
48+
fmt::print("{}\n", spblas::__backend::values(y));
49+
50+
fmt::print("\tExample is completed!\n");
51+
52+
return 0;
53+
}

include/spblas/algorithms/multiply_impl.hpp

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <spblas/concepts.hpp>
55
#include <spblas/detail/log.hpp>
66

7+
#include <spblas/algorithms/transposed.hpp>
78
#include <spblas/backend/csr_builder.hpp>
89
#include <spblas/backend/spa_accumulator.hpp>
910
#include <spblas/detail/operation_info_t.hpp>
@@ -95,14 +96,27 @@ void multiply(A&& a, B&& b, C&& c) {
9596
try {
9697
c_builder.insert_row(i, c_row.get());
9798
} catch (...) {
98-
throw std::runtime_error("multiply: ran out of memory. CSR output view "
99-
"has insufficient memory.");
99+
throw std::runtime_error("multiply: SpGEMM ran out of memory.");
100100
}
101101
}
102102
c.update(c.values(), c.rowptr(), c.colind(), c.shape(),
103103
c.rowptr()[c.shape()[0]]);
104104
}
105105

106+
template <matrix A, matrix B, matrix C>
107+
requires(__backend::column_iterable<A> && __backend::column_iterable<B> &&
108+
__detail::is_csc_view_v<C>)
109+
void multiply(A&& a, B&& b, C&& c) {
110+
log_trace("");
111+
if (__backend::shape(a)[0] != __backend::shape(c)[0] ||
112+
__backend::shape(b)[1] != __backend::shape(c)[1] ||
113+
__backend::shape(a)[1] != __backend::shape(b)[0]) {
114+
throw std::invalid_argument(
115+
"multiply: matrix dimensions are incompatible.");
116+
}
117+
multiply(transposed(b), transposed(a), transposed(c));
118+
}
119+
106120
template <matrix A, matrix B, matrix C>
107121
operation_info_t multiply_inspect(A&& a, B&& b, C&& c) {
108122
return operation_info_t{};
@@ -147,6 +161,26 @@ operation_info_t multiply_compute(A&& a, B&& b, C&& c) {
147161
return operation_info_t{__backend::shape(c), nnz};
148162
}
149163

164+
// C = AB
165+
// SpGEMM (Gustavson's Algorithm, transposed)
166+
template <matrix A, matrix B, matrix C>
167+
requires(__backend::column_iterable<A> && __backend::column_iterable<B> &&
168+
__detail::is_csc_view_v<C>)
169+
operation_info_t multiply_compute(A&& a, B&& b, C&& c) {
170+
log_trace("");
171+
if (__backend::shape(a)[0] != __backend::shape(c)[0] ||
172+
__backend::shape(b)[1] != __backend::shape(c)[1] ||
173+
__backend::shape(a)[1] != __backend::shape(b)[0]) {
174+
throw std::invalid_argument(
175+
"multiply: matrix dimensions are incompatible.");
176+
}
177+
178+
auto info = multiply_compute(transposed(b), transposed(a), transposed(c));
179+
info.update_impl_({info.result_shape()[1], info.result_shape()[0]},
180+
info.result_nnz());
181+
return info;
182+
}
183+
150184
template <matrix A, matrix B, matrix C>
151185
requires(__backend::row_iterable<A> && __backend::row_iterable<B> &&
152186
__detail::is_csr_view_v<C>)
@@ -156,6 +190,15 @@ void multiply_compute(operation_info_t& info, A&& a, B&& b, C&& c) {
156190
info.update_impl_(new_info.result_shape(), new_info.result_nnz());
157191
}
158192

193+
template <matrix A, matrix B, matrix C>
194+
requires(__backend::column_iterable<A> && __backend::column_iterable<B> &&
195+
__detail::is_csc_view_v<C>)
196+
void multiply_compute(operation_info_t& info, A&& a, B&& b, C&& c) {
197+
auto new_info = multiply_compute(std::forward<A>(a), std::forward<B>(b),
198+
std::forward<C>(c));
199+
info.update_impl_(new_info.result_shape(), new_info.result_nnz());
200+
}
201+
159202
// C = AB
160203
template <matrix A, matrix B, matrix C>
161204
void multiply_fill(operation_info_t info, A&& a, B&& b, C&& c) {
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
#pragma once
2+
3+
#include <spblas/concepts.hpp>
4+
5+
namespace spblas {
6+
7+
template <matrix M>
8+
auto transposed(M&& m);
9+
10+
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#pragma once
2+
3+
#include <spblas/detail/view_inspectors.hpp>
4+
5+
namespace spblas {
6+
7+
template <matrix M>
8+
requires(__detail::is_csr_view_v<M>)
9+
auto transposed(M&& m) {
10+
return csc_view<tensor_scalar_t<M>, tensor_index_t<M>, tensor_offset_t<M>>(
11+
m.values(), m.rowptr(), m.colind(), {m.shape()[1], m.shape()[0]},
12+
m.size());
13+
}
14+
15+
template <matrix M>
16+
requires(__detail::is_csc_view_v<M>)
17+
auto transposed(M&& m) {
18+
return csr_view<tensor_scalar_t<M>, tensor_index_t<M>, tensor_offset_t<M>>(
19+
m.values(), m.colptr(), m.rowind(), {m.shape()[1], m.shape()[0]},
20+
m.size());
21+
}
22+
23+
} // namespace spblas

include/spblas/backend/algorithms.hpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,16 @@ void for_each(M&& m, F&& f) {
1818
}
1919
}
2020

21+
template <matrix M, typename F>
22+
requires(__backend::column_iterable<M>)
23+
void for_each(M&& m, F&& f) {
24+
for (auto&& [j, column] : __backend::columns(m)) {
25+
for (auto&& [i, v] : column) {
26+
f(std::make_tuple(std::tuple{i, j}, std::reference_wrapper(v)));
27+
}
28+
}
29+
}
30+
2131
template <vector V, typename F>
2232
requires(__backend::lookupable<V> && __ranges::random_access_range<V>)
2333
void for_each(V&& v, F&& f) {

include/spblas/backend/concepts.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,16 @@ namespace __backend {
1010
template <typename T>
1111
concept row_iterable = requires(T& t) { rows(t); };
1212

13+
template <typename T>
14+
concept column_iterable = requires(T& t) { columns(t); };
15+
1316
template <typename T>
1417
concept row_lookupable = requires(T& t) { lookup_row(t, tensor_index_t<T>{}); };
1518

19+
template <typename T>
20+
concept column_lookupable =
21+
requires(T& t) { lookup_column(t, tensor_index_t<T>{}); };
22+
1623
namespace {
1724

1825
template <typename T>

include/spblas/backend/cpos.hpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,16 @@ struct rows_fn_ {
4646

4747
inline constexpr auto rows = rows_fn_{};
4848

49+
struct columns_fn_ {
50+
template <typename T>
51+
requires(spblas::is_tag_invocable_v<columns_fn_, T>)
52+
constexpr auto operator()(T&& t) const {
53+
return spblas::tag_invoke(columns_fn_{}, std::forward<T>(t));
54+
}
55+
};
56+
57+
inline constexpr auto columns = columns_fn_{};
58+
4959
struct lookup_fn_ {
5060
template <typename T, typename... Args>
5161
requires(spblas::is_tag_invocable_v<lookup_fn_, T, Args...>)
@@ -70,6 +80,18 @@ struct lookup_row_fn_ {
7080

7181
inline constexpr auto lookup_row = lookup_row_fn_{};
7282

83+
struct lookup_column_fn_ {
84+
template <typename T, typename... Args>
85+
requires(spblas::is_tag_invocable_v<lookup_column_fn_, T, Args...>)
86+
constexpr tag_invoke_result_t<lookup_column_fn_, T, Args...>
87+
operator()(T&& t, Args&&... args) const {
88+
return spblas::tag_invoke(lookup_column_fn_{}, std::forward<T>(t),
89+
std::forward<Args>(args)...);
90+
}
91+
};
92+
93+
inline constexpr auto lookup_column = lookup_column_fn_{};
94+
7395
} // namespace __backend
7496

7597
} // namespace spblas

include/spblas/backend/view_customizations.hpp

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,26 +5,44 @@
55

66
namespace spblas {
77

8-
// Customization point implementations for csr_view.
8+
// Customization point implementations for csr_view and csc_view.
99

1010
template <typename M>
1111
requires(__detail::is_csr_view_v<M>)
1212
auto tag_invoke(__backend::size_fn_, M&& m) {
1313
return m.size();
1414
}
1515

16+
template <typename M>
17+
requires(__detail::is_csc_view_v<M>)
18+
auto tag_invoke(__backend::size_fn_, M&& m) {
19+
return m.size();
20+
}
21+
1622
template <typename M>
1723
requires(__detail::is_csr_view_v<M>)
1824
auto tag_invoke(__backend::shape_fn_, M&& m) {
1925
return m.shape();
2026
}
2127

28+
template <typename M>
29+
requires(__detail::is_csc_view_v<M>)
30+
auto tag_invoke(__backend::shape_fn_, M&& m) {
31+
return m.shape();
32+
}
33+
2234
template <typename M>
2335
requires(__detail::is_csr_view_v<M>)
2436
auto tag_invoke(__backend::values_fn_, M&& m) {
2537
return m.values();
2638
}
2739

40+
template <typename M>
41+
requires(__detail::is_csc_view_v<M>)
42+
auto tag_invoke(__backend::values_fn_, M&& m) {
43+
return m.values();
44+
}
45+
2846
namespace {
2947

3048
template <typename M>
@@ -47,6 +65,26 @@ auto row(M&& m, typename std::remove_cvref_t<M>::index_type row_index) {
4765
return __ranges::views::zip(column_indices, row_values);
4866
}
4967

68+
template <typename M>
69+
requires(__detail::is_csc_view_v<M>)
70+
auto column(M&& m, typename std::remove_cvref_t<M>::index_type column_index) {
71+
using O = typename std::remove_cvref_t<M>::offset_type;
72+
O first = m.colptr()[column_index];
73+
O last = m.colptr()[column_index + 1];
74+
75+
using row_iter_t = decltype(m.rowind().data());
76+
using value_iter_t = decltype(m.values().data());
77+
78+
__ranges::subrange<row_iter_t> column_indices(
79+
__ranges::next(m.rowind().data(), first),
80+
__ranges::next(m.rowind().data(), last));
81+
__ranges::subrange<value_iter_t> column_values(
82+
__ranges::next(m.values().data(), first),
83+
__ranges::next(m.values().data(), last));
84+
85+
return __ranges::views::zip(column_indices, column_values);
86+
}
87+
5088
} // namespace
5189

5290
template <typename M>
@@ -62,6 +100,20 @@ auto tag_invoke(__backend::rows_fn_, M&& m) {
62100
return __ranges::views::zip(row_indices, row_values);
63101
}
64102

103+
template <typename M>
104+
requires(__detail::is_csc_view_v<M>)
105+
auto tag_invoke(__backend::columns_fn_, M&& m) {
106+
using I = typename std::remove_cvref_t<M>::index_type;
107+
auto column_indices = __ranges::views::iota(I(0), I(m.shape()[1]));
108+
109+
auto column_values =
110+
column_indices | __ranges::views::transform([=](auto column_index) {
111+
return column(m, column_index);
112+
});
113+
114+
return __ranges::views::zip(column_indices, column_values);
115+
}
116+
65117
template <typename M>
66118
requires(__detail::is_csr_view_v<M>)
67119
auto tag_invoke(__backend::lookup_row_fn_, M&& m,
@@ -70,6 +122,14 @@ auto tag_invoke(__backend::lookup_row_fn_, M&& m,
70122
return row(m, row_index);
71123
}
72124

125+
template <typename M>
126+
requires(__detail::is_csc_view_v<M>)
127+
auto tag_invoke(__backend::lookup_column_fn_, M&& m,
128+
typename std::remove_cvref_t<M>::index_type column_index) {
129+
using I = typename std::remove_cvref_t<M>::index_type;
130+
return column(m, column_index);
131+
}
132+
73133
// Customization point implementations for vectors
74134

75135
template <__ranges::random_access_range V>

0 commit comments

Comments
 (0)