Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/spblas/vendor/armpl/algorithms.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@

#include "multiply_impl.hpp"

#include <spblas/algorithms/triangular_solve_impl.hpp>
#include "triangular_solve_impl.hpp"
14 changes: 13 additions & 1 deletion include/spblas/vendor/armpl/detail/armpl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ inline constexpr auto create_spmat_csc<std::complex<double>> =
template <typename T>
armpl_status_t (*create_spmat_dense)(armpl_spmat_t*, enum armpl_dense_layout,
armpl_int_t, armpl_int_t, armpl_int_t,
const float*, armpl_int_t);
const T*, armpl_int_t);
template <>
inline constexpr auto create_spmat_dense<float> = &armpl_spmat_create_dense_s;
template <>
Expand Down Expand Up @@ -81,6 +81,18 @@ inline constexpr auto spmm_exec<std::complex<float>> = &armpl_spmm_exec_c;
template <>
inline constexpr auto spmm_exec<std::complex<double>> = &armpl_spmm_exec_z;

// Scalar-type dispatch for Arm PL's sparse triangular solve execution routine.
//
// The primary template only fixes the function-pointer shape; each explicit
// specialization below binds it to the matching armpl_spsv_exec_{s,d,c,z}
// entry point so callers can write sptrsv_exec<T>(...) generically.
//
// As used by triangular_solve, the arguments are, in order: the operation
// hint (e.g. ARMPL_SPARSE_OPERATION_NOTRANS), the matrix handle, the output
// vector data, the scaling factor alpha, and the right-hand-side data.
template <typename T>
armpl_status_t (*sptrsv_exec)(enum armpl_sparse_hint_value, armpl_spmat_t, T*,
T, const T*);
// Single-precision real.
template <>
inline constexpr auto sptrsv_exec<float> = &armpl_spsv_exec_s;
// Double-precision real.
template <>
inline constexpr auto sptrsv_exec<double> = &armpl_spsv_exec_d;
// Single-precision complex.
template <>
inline constexpr auto sptrsv_exec<std::complex<float>> = &armpl_spsv_exec_c;
// Double-precision complex.
template <>
inline constexpr auto sptrsv_exec<std::complex<double>> = &armpl_spsv_exec_z;

template <typename T>
armpl_status_t (*export_spmat_dense)(armpl_const_spmat_t,
enum armpl_dense_layout, armpl_int_t*,
Expand Down
130 changes: 130 additions & 0 deletions include/spblas/vendor/armpl/triangular_solve_impl.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
#pragma once

#include <cassert>
#include <cstddef>
#include <type_traits>
#include <vector>

#include <spblas/vendor/armpl/detail/armpl.hpp>

#include <spblas/detail/log.hpp>
#include <spblas/detail/operation_info_t.hpp>
#include <spblas/detail/ranges.hpp>
#include <spblas/detail/view_inspectors.hpp>

#include <spblas/detail/triangular_types.hpp>

namespace spblas {

/// Sparse triangular solve backed by Arm Performance Libraries.
///
/// Solves for `x` using the triangle of the CSR matrix `a` selected by `uplo`,
/// with the right-hand side `b` multiplied by any scaling factor attached to
/// `a` or `b` (1 if none).
///
/// @param a     CSR matrix (or a view whose ultimate base is CSR).
/// @param uplo  `upper_triangle_t` or `lower_triangle_t`: which triangle of
///              `a` to use.
/// @param diag  `explicit_diagonal_t` to use the stored diagonal, or
///              `implicit_unit_diagonal_t` to treat the diagonal as all ones
///              and ignore any stored diagonal entries.
/// @param b     Right-hand-side vector with a contiguous base range.
/// @param x     Contiguous output vector receiving the solution.
///
/// Strategy: Arm PL's solver only accepts a matrix that is already triangular.
/// For an explicit diagonal the input is first handed to Arm PL as-is; if
/// that is rejected, the requested triangle is copied into temporary CSR
/// arrays and the solve is retried on the copy. For an implicit unit diagonal
/// the copy is always made, because solving the input directly would use the
/// stored diagonal values instead of ones.
///
/// NOTE(review): the no-copy attempt trusts that an already-triangular input
/// is the triangle named by `uplo`; it does not verify this.
///
/// Errors from Arm PL on the fallback path are reported via
/// `armpl_spmat_print_err` and `assert(false)`.
template <matrix A, class Triangle, class DiagonalStorage, vector B, vector X>
  requires __detail::has_csr_base<A> &&
           __detail::has_contiguous_range_base<B> &&
           __ranges::contiguous_range<X>
void triangular_solve(A&& a, Triangle uplo, DiagonalStorage diag, B&& b,
                      X&& x) {
  log_trace("");
  static_assert(std::is_same_v<Triangle, upper_triangle_t> ||
                std::is_same_v<Triangle, lower_triangle_t>);
  static_assert(std::is_same_v<DiagonalStorage, explicit_diagonal_t> ||
                std::is_same_v<DiagonalStorage, implicit_unit_diagonal_t>);

  // Both properties are fully determined by the tag types.
  constexpr bool is_upper = std::is_same_v<Triangle, upper_triangle_t>;
  constexpr bool is_unit =
      std::is_same_v<DiagonalStorage, implicit_unit_diagonal_t>;

  auto a_base = __detail::get_ultimate_base(a);
  auto b_base = __detail::get_ultimate_base(b);

  using T = tensor_scalar_t<A>;
  using I = tensor_index_t<A>;
  using O = tensor_offset_t<A>;

  auto m = __backend::shape(a_base)[0];
  auto n = __backend::shape(a_base)[1];

  auto alpha_optional = __detail::get_scaling_factor(a, b);
  T alpha = alpha_optional.value_or(1);

  if constexpr (!is_unit) {
    // Optimistically try the solve without a copy, in case the matrix is
    // already triangular. Skipped for an implicit unit diagonal: a stored
    // diagonal would be used by Arm PL instead of the implied ones.
    armpl_spmat_t a_handle = __armpl::create_matrix_handle(a_base);

    auto stat = __armpl::sptrsv_exec<T>(ARMPL_SPARSE_OPERATION_NOTRANS,
                                        a_handle, __ranges::data(x), alpha,
                                        __ranges::data(b_base));

    armpl_spmat_destroy(a_handle);

    if (stat == ARMPL_STATUS_SUCCESS) {
      return;
    }
  }

  // Arm PL needs a copy of the matrix corresponding to the specified
  // triangle with the diagonal set appropriately.

  auto colind = a_base.colind().data();
  auto rowptr = a_base.rowptr().data();
  auto values = a_base.values().data();

  // Row pointers hold offsets (O); column indices hold indices (I).
  std::vector<T> tmp_values;
  std::vector<O> tmp_rowptr(m + 1); // value-initialized, so tmp_rowptr[0] == 0
  std::vector<I> tmp_colind;

  // The input row pointers may start at a non-zero offset; the copy is
  // rebuilt with row pointers starting at zero.
  auto index_base = rowptr[0];

  // Upper bound on the entries kept: every stored entry, plus one synthetic
  // diagonal entry per row when the diagonal is implicit.
  const auto capacity = static_cast<std::size_t>(rowptr[m] - index_base) +
                        (is_unit ? static_cast<std::size_t>(m) : 0);
  tmp_values.reserve(capacity);
  tmp_colind.reserve(capacity);

  for (armpl_int_t r = 0; r < m; r++) {

    // Upper triangle: the unit diagonal precedes the strictly-upper entries
    // so that the diagonal entry comes first within the row.
    if constexpr (is_unit && is_upper) {
      tmp_colind.push_back(r);
      tmp_values.push_back(T(1));
    }

    for (auto i = rowptr[r] - index_base; i < rowptr[r + 1] - index_base;
         i++) {
      auto c = colind[i];

      // With an implicit unit diagonal, stored diagonal entries are dropped
      // (strict comparison); otherwise they are kept.
      bool keep = false;
      if constexpr (is_upper) {
        keep = is_unit ? (r < c) : (r <= c);
      } else {
        keep = is_unit ? (r > c) : (r >= c);
      }

      if (keep) {
        tmp_colind.push_back(c);
        tmp_values.push_back(values[i]);
      }
    }

    // Lower triangle: the unit diagonal follows the strictly-lower entries.
    if constexpr (is_unit && !is_upper) {
      tmp_colind.push_back(r);
      tmp_values.push_back(T(1));
    }

    // Derive the row end from the container size rather than a separate
    // `int` counter, which could overflow for large matrices.
    tmp_rowptr[r + 1] = static_cast<O>(tmp_colind.size());
  }

  // NOCOPY: the handle references the temporary arrays directly, so they must
  // stay alive until the handle is destroyed below.
  armpl_spmat_t tri_handle;
  auto create_stat = __armpl::create_spmat_csr<T>(
      &tri_handle, m, n, tmp_rowptr.data(), tmp_colind.data(),
      tmp_values.data(), ARMPL_SPARSE_CREATE_NOCOPY);
  if (create_stat != ARMPL_STATUS_SUCCESS) {
    // No valid handle to solve with; fail loudly in debug builds and bail out
    // instead of passing an uninitialized handle to the solver.
    assert(false);
    return;
  }

  auto stat = __armpl::sptrsv_exec<T>(ARMPL_SPARSE_OPERATION_NOTRANS,
                                      tri_handle, __ranges::data(x), alpha,
                                      __ranges::data(b_base));
  if (stat != ARMPL_STATUS_SUCCESS) {
    armpl_spmat_print_err(tri_handle);
    assert(false);
  }

  armpl_spmat_destroy(tri_handle);

} // triangular_solve

} // namespace spblas