Skip to content

Commit 6917a78

Browse files
upsj and BenBrock authored
Add transpose algorithm (#34)
* add transpose algorithm --------- Co-authored-by: Benjamin Brock <[email protected]>
1 parent 62fdee3 commit 6917a78

File tree

5 files changed

+151
-1
lines changed

5 files changed

+151
-1
lines changed

include/spblas/algorithms/algorithms.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,6 @@
1515

1616
#include <spblas/algorithms/scaled.hpp>
1717
#include <spblas/algorithms/scaled_impl.hpp>
18+
19+
#include <spblas/algorithms/transpose.hpp>
20+
#include <spblas/algorithms/transpose_impl.hpp>
Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,17 @@
11
#pragma once
22

33
#include <spblas/concepts.hpp>
4+
#include <spblas/detail/operation_info_t.hpp>
45

56
namespace spblas {
67

8+
// Inspection step for the transpose of `a` into `b`; returns the
// operation_info_t that is later handed to transpose(). Declaration only.
template <matrix A, matrix B>
9+
operation_info_t transpose_inspect(A&& a, B&& b);
10+
11+
// Transpose step taking the info object plus input `a` and output `b`.
// Declaration only.
// NOTE(review): this declaration carries no requires-clause, while the CSR
// definition shown in the implementation hunk does. In C++20 a differently
// constrained template is a separate overload, not a redeclaration, so this
// unconstrained template appears to have no definition — confirm intended.
template <matrix A, matrix B>
12+
void transpose(operation_info_t& info, A&& a, B&& b);
13+
714
// Declaration only; the return type is deduced. The definition is not shown
// in this hunk.
template <matrix M>
815
auto transposed(M&& m);
916

10-
}
17+
} // namespace spblas
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
#pragma once
2+
3+
#include <spblas/concepts.hpp>
4+
#include <spblas/detail/operation_info_t.hpp>
5+
#include <spblas/detail/view_inspectors.hpp>
6+
7+
namespace spblas {
8+
9+
template <matrix A, matrix B>
10+
operation_info_t transpose_inspect(A&& a, B&& b) {
11+
return {};
12+
}
13+
14+
template <matrix A, matrix B>
15+
requires(__detail::is_csr_view_v<A> && __detail::is_csr_view_v<B>)
16+
void transpose(operation_info_t& info, A&& a, B&& b) {
17+
if (__backend::shape(a)[0] != __backend::shape(b)[1] ||
18+
__backend::shape(a)[1] != __backend::shape(b)[0]) {
19+
throw std::invalid_argument(
20+
"transpose: matrix dimensions are incompatible.");
21+
}
22+
if (__backend::size(a) != __backend::size(b)) {
23+
throw std::invalid_argument("transpose: matrix nnz are incompatible.");
24+
}
25+
using O = tensor_offset_t<B>;
26+
27+
const auto b_base = __detail::get_ultimate_base(b);
28+
const auto b_rowptr = b_base.rowptr();
29+
const auto b_colind = b_base.colind();
30+
const auto b_values = b_base.values();
31+
32+
__ranges::fill(b_rowptr, 0);
33+
34+
for (auto&& [i, row] : __backend::rows(a)) {
35+
for (auto&& [j, _] : row) {
36+
b_rowptr[j + 1]++;
37+
}
38+
}
39+
40+
std::exclusive_scan(b_rowptr.begin(), b_rowptr.end(), b_rowptr.begin(), O{});
41+
42+
for (auto&& [i, row] : __backend::rows(a)) {
43+
for (auto&& [j, v] : row) {
44+
const auto out_idx = b_rowptr[j + 1];
45+
b_colind[out_idx] = i;
46+
b_values[out_idx] = v;
47+
b_rowptr[j + 1]++;
48+
}
49+
}
50+
}
51+
52+
} // namespace spblas

test/gtest/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ add_executable(
66
spmm_test.cpp
77
spgemm_test.cpp
88
add_test.cpp
9+
transpose_test.cpp
910
triangular_solve_test.cpp
1011
)
1112

test/gtest/transpose_test.cpp

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
#include <gtest/gtest.h>
2+
3+
#include "util.hpp"
4+
#include <spblas/backend/spa_accumulator.hpp>
5+
#include <spblas/spblas.hpp>
6+
7+
#include <fmt/core.h>
8+
#include <fmt/ranges.h>
9+
10+
// Checks spblas::transpose on CSR views: for every (m, k, nnz) triple in
// util::dims a random CSR matrix A is generated, transposed into a
// preallocated CSR view B, and B is compared entry-by-entry (after sorting
// both sides as COO triples) against a reference built by swapping the row
// and column index of every stored entry of A.
TEST(CsrView, Transpose) {
  using T = float;
  using I = spblas::index_t;
  using O = spblas::offset_t;

  for (auto&& [m, k, nnz] : util::dims) {
    // Generate CSR Matrix A.
    auto [a_values, a_rowptr, a_colind, a_shape, a_nnz] =
        spblas::generate_csr<T, I>(m, k, nnz);

    spblas::csr_view<T, I, O> a(a_values, a_rowptr, a_colind, a_shape, a_nnz);

    // Transpose; B = A_T

    spblas::index b_shape(a.shape()[1], a.shape()[0]);

    std::vector<O> b_rowptr(b_shape[0] + 1);
    std::vector<I> b_colind(a.size());
    std::vector<T> b_values(a.size());

    spblas::csr_view<T, I, O> b(b_values, b_rowptr, b_colind, b_shape,
                                a.size());

    auto info = spblas::transpose_inspect(a, b);
    spblas::transpose(info, a, b);

    // Create transposed COO for reference.
    std::vector<T> ref_values;
    std::vector<I> ref_rowind;
    std::vector<I> ref_colind;

    for (auto&& [i, row] : spblas::__backend::rows(a)) {
      for (auto&& [j, v] : row) {
        ref_values.push_back(v);
        ref_rowind.push_back(j);
        ref_colind.push_back(i);
      }
    }

    // Create COO from transposed matrix for test.
    // Index vectors use the index type I, matching the reference vectors;
    // storing indices as T (float) would lose exactness for large indices.
    std::vector<T> test_values;
    std::vector<I> test_rowind;
    std::vector<I> test_colind;

    for (auto&& [i, row] : spblas::__backend::rows(b)) {
      for (auto&& [j, v] : row) {
        test_values.push_back(v);
        test_rowind.push_back(i);
        test_colind.push_back(j);
      }
    }

    // Ensure both COO matrices are sorted.
    spblas::__ranges::sort(
        spblas::__ranges::views::zip(ref_rowind, ref_colind, ref_values));
    spblas::__ranges::sort(
        spblas::__ranges::views::zip(test_rowind, test_colind, test_values));

    EXPECT_EQ(ref_values.size(), test_values.size());
    EXPECT_EQ(ref_rowind.size(), test_rowind.size());
    EXPECT_EQ(ref_colind.size(), test_colind.size());

    // The loop bindings below are deliberately not named a/b, to avoid
    // shadowing the CSR views declared above.
    // NOTE(review): EXPECT_EQ_ is not a GoogleTest macro; presumably it is a
    // value-comparison helper supplied by util.hpp - confirm.
    for (auto&& [expected, actual] :
         spblas::__ranges::views::zip(ref_values, test_values)) {
      EXPECT_EQ_(expected, actual);
    }

    for (auto&& [expected, actual] :
         spblas::__ranges::views::zip(ref_rowind, test_rowind)) {
      EXPECT_EQ(expected, actual);
    }

    for (auto&& [expected, actual] :
         spblas::__ranges::views::zip(ref_colind, test_colind)) {
      EXPECT_EQ(expected, actual);
    }
  }
}

0 commit comments

Comments
 (0)