Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/spblas/algorithms/algorithms.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,6 @@

#include <spblas/algorithms/scaled.hpp>
#include <spblas/algorithms/scaled_impl.hpp>

#include <spblas/algorithms/transpose.hpp>
#include <spblas/algorithms/transpose_impl.hpp>
9 changes: 8 additions & 1 deletion include/spblas/algorithms/transpose.hpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
#pragma once

#include <spblas/concepts.hpp>
#include <spblas/detail/operation_info_t.hpp>

namespace spblas {

template <matrix A, matrix B>
operation_info_t transpose_inspect(A&& a, B&& b);

template <matrix A, matrix B>
void transpose(operation_info_t& info, A&& a, B&& b);

template <matrix M>
auto transposed(M&& m);

}
} // namespace spblas
52 changes: 52 additions & 0 deletions include/spblas/algorithms/transpose_impl.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#pragma once

#include <spblas/concepts.hpp>
#include <spblas/detail/operation_info_t.hpp>
#include <spblas/detail/view_inspectors.hpp>

namespace spblas {

template <matrix A, matrix B>
operation_info_t transpose_inspect(A&& a, B&& b) {
return {};
}

template <matrix A, matrix B>
requires(__detail::is_csr_view_v<A> && __detail::is_csr_view_v<B>)
void transpose(operation_info_t& info, A&& a, B&& b) {
if (__backend::shape(a)[0] != __backend::shape(b)[1] ||
__backend::shape(a)[1] != __backend::shape(b)[0]) {
throw std::invalid_argument(
"transpose: matrix dimensions are incompatible.");
}
if (__backend::size(a) != __backend::size(b)) {
throw std::invalid_argument("transpose: matrix nnz are incompatible.");
}
using O = tensor_offset_t<B>;

const auto b_base = __detail::get_ultimate_base(b);
const auto b_rowptr = b_base.rowptr();
const auto b_colind = b_base.colind();
const auto b_values = b_base.values();

__ranges::fill(b_rowptr, 0);

for (auto&& [i, row] : __backend::rows(a)) {
for (auto&& [j, _] : row) {
b_rowptr[j + 1]++;
}
}

std::exclusive_scan(b_rowptr.begin(), b_rowptr.end(), b_rowptr.begin(), O{});

for (auto&& [i, row] : __backend::rows(a)) {
for (auto&& [j, v] : row) {
const auto out_idx = b_rowptr[j + 1];
b_colind[out_idx] = i;
b_values[out_idx] = v;
b_rowptr[j + 1]++;
}
}
}

} // namespace spblas
1 change: 1 addition & 0 deletions test/gtest/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ add_executable(
spmm_test.cpp
spgemm_test.cpp
add_test.cpp
transpose_test.cpp
triangular_solve_test.cpp
)

Expand Down
87 changes: 87 additions & 0 deletions test/gtest/transpose_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
#include <gtest/gtest.h>

#include "util.hpp"
#include <spblas/backend/spa_accumulator.hpp>
#include <spblas/spblas.hpp>

#include <fmt/core.h>
#include <fmt/ranges.h>

TEST(CsrView, Transpose) {
using T = float;
using I = spblas::index_t;
using O = spblas::offset_t;

for (auto&& [m, k, nnz] : util::dims) {
// Generate CSR Matrix A.
auto [a_values, a_rowptr, a_colind, a_shape, a_nnz] =
spblas::generate_csr<T, I>(m, k, nnz);

spblas::csr_view<T, I, O> a(a_values, a_rowptr, a_colind, a_shape, a_nnz);

// Transpose; B = A_T

spblas::index b_shape(a.shape()[1], a.shape()[0]);

std::vector<O> b_rowptr(b_shape[0] + 1);
std::vector<I> b_colind(a.size());
std::vector<T> b_values(a.size());

spblas::csr_view<T, I, O> b(b_values, b_rowptr, b_colind, b_shape,
a.size());

auto info = spblas::transpose_inspect(a, b);
spblas::transpose(info, a, b);

// Create transposed COO for reference.
std::vector<T> ref_values;
std::vector<I> ref_rowind;
std::vector<I> ref_colind;

for (auto&& [i, row] : spblas::__backend::rows(a)) {
for (auto&& [j, v] : row) {
ref_values.push_back(v);
ref_rowind.push_back(j);
ref_colind.push_back(i);
}
}

// Create COO from transposed matrix for test.
std::vector<T> test_values;
std::vector<T> test_rowind;
std::vector<T> test_colind;

for (auto&& [i, row] : spblas::__backend::rows(b)) {
for (auto&& [j, v] : row) {
test_values.push_back(v);
test_rowind.push_back(i);
test_colind.push_back(j);
}
}

// Ensure both COO matrices are sorted.
spblas::__ranges::sort(
spblas::__ranges::views::zip(ref_rowind, ref_colind, ref_values));
spblas::__ranges::sort(
spblas::__ranges::views::zip(test_rowind, test_colind, test_values));

EXPECT_EQ(ref_values.size(), test_values.size());
EXPECT_EQ(ref_rowind.size(), test_rowind.size());
EXPECT_EQ(ref_colind.size(), test_colind.size());

for (auto&& [a, b] :
spblas::__ranges::views::zip(ref_values, test_values)) {
EXPECT_EQ_(a, b);
}

for (auto&& [a, b] :
spblas::__ranges::views::zip(ref_rowind, test_rowind)) {
EXPECT_EQ(a, b);
}

for (auto&& [a, b] :
spblas::__ranges::views::zip(ref_colind, test_colind)) {
EXPECT_EQ(a, b);
}
}
}