Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
274 changes: 240 additions & 34 deletions batched/dense/impl/KokkosBatched_Gemm_TeamVector_Impl.hpp

Large diffs are not rendered by default.

119 changes: 89 additions & 30 deletions batched/dense/impl/KokkosBatched_Gemm_TeamVector_Internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,33 +17,38 @@
#define KOKKOSBATCHED_GEMM_TEAMVECTOR_INTERNAL_HPP

/// \author Kyungjoo Kim (kyukim@sandia.gov)
/// \author Yuuichi Asahi (yuuichi.asahi@cea.fr)

#include "KokkosBatched_Util.hpp"

#include "KokkosBlas1_set_impl.hpp"
#include "KokkosBlas1_team_scal_impl.hpp"

#include "KokkosBatched_InnerGemmFixC_Serial_Impl.hpp"

namespace KokkosBatched {
namespace Impl {

///
/// TeamVector Internal Impl
/// ====================
template <typename ArgAlgo, bool useConjA = false>
template <typename ArgAlgo>
struct TeamVectorGemmInternal {
template <typename MemberType, typename ScalarType, typename ValueType>
KOKKOS_INLINE_FUNCTION static int invoke(const MemberType &member, const int m, const int n, const int k,
const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const int as0,
const int as1, const ValueType *KOKKOS_RESTRICT B, const int bs0,
const int bs1, const ScalarType beta,
template <typename MemberType, typename OpA, typename OpB, typename ScalarType, typename ValueType>
KOKKOS_INLINE_FUNCTION static int invoke(const MemberType &member, const OpA opA, const OpB opB, const int m,
const int n, const int k, const ScalarType alpha,
const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1,
const ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1,
const ScalarType beta,
/**/ ValueType *KOKKOS_RESTRICT C, const int cs0, const int cs1);
};

template <>
template <typename MemberType, typename ScalarType, typename ValueType>
KOKKOS_INLINE_FUNCTION int TeamVectorGemmInternal<Algo::Gemm::Unblocked, false>::invoke(
const MemberType &member, const int m, const int n, const int k, const ScalarType alpha,
const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1, const ValueType *KOKKOS_RESTRICT B, const int bs0,
const int bs1, const ScalarType beta,
template <typename MemberType, typename OpA, typename OpB, typename ScalarType, typename ValueType>
KOKKOS_INLINE_FUNCTION int TeamVectorGemmInternal<Algo::Gemm::Unblocked>::invoke(
const MemberType &member, const OpA opA, const OpB opB, const int m, const int n, const int k,
const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1,
const ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1, const ScalarType beta,
/**/ ValueType *KOKKOS_RESTRICT C, const int cs0, const int cs1) {
// C = beta C + alpha A B
// C (m x n), A(m x k), B(k x n)
Expand All @@ -56,8 +61,6 @@ KOKKOS_INLINE_FUNCTION int TeamVectorGemmInternal<Algo::Gemm::Unblocked, false>:
KokkosBlas::Impl::TeamVectorScaleInternal::invoke(member, m, n, beta, C, cs0, cs1);

if (alpha != ScalarType(0.0)) {
if (m <= 0 || n <= 0 || k <= 0) return 0;

if (beta != one) member.team_barrier();

Kokkos::parallel_for(Kokkos::TeamThreadRange(member, m), [&](const int &i) {
Expand All @@ -66,7 +69,7 @@ KOKKOS_INLINE_FUNCTION int TeamVectorGemmInternal<Algo::Gemm::Unblocked, false>:
const ValueType *KOKKOS_RESTRICT pB = B + j * bs1;

ValueType c = ValueType(0);
for (int p = 0; p < k; ++p) c += pA[p * as1] * pB[p * bs0];
for (int p = 0; p < k; ++p) c += opA(pA[p * as1]) * opB(pB[p * bs0]);
C[i * cs0 + j * cs1] += alpha * c;
});
});
Expand All @@ -75,15 +78,18 @@ KOKKOS_INLINE_FUNCTION int TeamVectorGemmInternal<Algo::Gemm::Unblocked, false>:
}

template <>
template <typename MemberType, typename ScalarType, typename ValueType>
KOKKOS_INLINE_FUNCTION int TeamVectorGemmInternal<Algo::Gemm::Unblocked, true>::invoke(
const MemberType &member, const int m, const int n, const int k, const ScalarType alpha,
const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1, const ValueType *KOKKOS_RESTRICT B, const int bs0,
const int bs1, const ScalarType beta,
template <typename MemberType, typename OpA, typename OpB, typename ScalarType, typename ValueType>
KOKKOS_INLINE_FUNCTION int TeamVectorGemmInternal<Algo::Gemm::Blocked>::invoke(
const MemberType &member, const OpA opA, const OpB opB, const int m, const int n, const int k,
const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1,
const ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1, const ScalarType beta,
/**/ ValueType *KOKKOS_RESTRICT C, const int cs0, const int cs1) {
// C = beta C + alpha A B
// C (m x n), A(m x k), B(k x n)

constexpr int mbAlgo = Algo::Gemm::Blocked::mb();
constexpr int nbAlgo = Algo::Gemm::Blocked::mb();

const ScalarType one(1.0), zero(0.0);

if (beta == zero)
Expand All @@ -92,23 +98,76 @@ KOKKOS_INLINE_FUNCTION int TeamVectorGemmInternal<Algo::Gemm::Unblocked, true>::
KokkosBlas::Impl::TeamVectorScaleInternal::invoke(member, m, n, beta, C, cs0, cs1);

if (alpha != ScalarType(0.0)) {
if (m <= 0 || n <= 0 || k <= 0) return 0;

if (beta != one) member.team_barrier();

Kokkos::parallel_for(Kokkos::TeamThreadRange(member, m), [&](const int &i) {
const ValueType *KOKKOS_RESTRICT pA = A + i * as0;
Kokkos::parallel_for(Kokkos::ThreadVectorRange(member, n), [&](const int &j) {
const ValueType *KOKKOS_RESTRICT pB = B + j * bs1;

ValueType c = ValueType(0);
for (int p = 0; p < k; ++p) c += Kokkos::ArithTraits<ValueType>::conj(pA[p * as1]) * pB[p * bs0];
C[i * cs0 + j * cs1] += alpha * c;
///
/// GPU case: team size is large and blocksize (mb,nb) is small
InnerGemmFixC<mbAlgo, nbAlgo> inner(as0, as1, bs0, bs1, cs0, cs1);
auto gemm = [&](const int ib, const int jb, const int pb, const ValueType *KOKKOS_RESTRICT AA,
const ValueType *KOKKOS_RESTRICT BB,
/**/ ValueType *KOKKOS_RESTRICT CC) {
// Made this non-const in order to WORKAROUND issue #349
int mb = mbAlgo, mp = (ib % mb), mq = (ib / mb) + (mp > 0), nb = nbAlgo, np = (jb % nb),
nq = (jb / nb) + (np > 0);

// square tiling
Kokkos::parallel_for(Kokkos::TeamVectorRange(member, mq * nq), [&](const int &ij) {
int i, j;
// note: the condition is constexpr
if (KokkosKernels::Impl::is_gpu_exec_space_v<typename MemberType::execution_space>) {
i = ij % mq * mb;
j = ij / mq * nb;
} else {
i = ij / nq * mb;
j = ij % nq * nb;
}
inner.serial_invoke(opA, opB, alpha, AA + i * as0, BB + j * bs1, (i + mb) > ib ? mp : mb,
(j + nb) > jb ? np : nb, pb, CC + i * cs0 + j * cs1);
});
});
};

const bool is_small = true; //(m*n*k <= 64*64*64);
if (is_small) {
gemm(m, n, k, A, B, C);
} else {
// // cache blocking
// const int
// nc = nb*10, kc = mb*4, mc = mb*4;

// for (int jj=0;jj<n;jj+=nc) {
// const int tj = n-jj, jb = (tj < nc ? tj : nc);
// for (int pp=0;pp<k;pp+=kc) {
// const int tp = k-pp, pb = (tp < kc ? tp : kc);
// //const int pb = k, pp = 0;
// for (int ii=0;ii<m;ii+=mc) {
// const int ti = m-ii, ib = (ti < mc ? ti : mc);

// const ValueType *KOKKOS_RESTRICT AA = A+ii*as0+pp*as1;
// const ValueType *KOKKOS_RESTRICT BB = B+pp*bs0+jj*bs1;
// /**/ ValueType *KOKKOS_RESTRICT CC = C+ii*cs0+jj*cs1;

// gemm(ib, jb, pb, AA, BB, CC);
// } // for ii
// } // for pp
// } // for jj
}
}
return 0;
}
} // namespace Impl

/// \brief Deprecated forwarding wrapper kept in the KokkosBatched namespace.
///
/// Exposes an invoke() that takes no opA/opB arguments and forwards every
/// argument to Impl::TeamVectorGemmInternal<ArgAlgo>::invoke, supplying a
/// default-constructed KokkosBlas::Impl::OpID for both operand operators.
/// NOTE(review): OpID is presumably the identity (no conjugation) operator;
/// it is defined outside this file -- confirm.
///
/// \tparam ArgAlgo: Algorithm tag, passed unchanged to the Impl kernel
template <typename ArgAlgo>
struct [[deprecated("Use KokkosBatched::TeamVectorGemm instead")]] TeamVectorGemmInternal {
  /// \tparam MemberType: Type of the team member handle
  /// \tparam ScalarType: Type of alpha and beta
  /// \tparam ValueType: Element type of A, B and C
  ///
  /// \param member [in]: Team member handle, forwarded to the Impl kernel
  /// \param m, n, k [in]: Problem sizes; the Impl kernels describe them as C (m x n), A (m x k), B (k x n)
  /// \param alpha [in]: Scalar alpha
  /// \param A [in]: Pointer to A, with strides as0 and as1
  /// \param B [in]: Pointer to B, with strides bs0 and bs1
  /// \param beta [in]: Scalar beta
  /// \param C [inout]: Pointer to C, with strides cs0 and cs1
  ///
  /// \return The value returned by the Impl kernel
  template <typename MemberType, typename ScalarType, typename ValueType>
  KOKKOS_INLINE_FUNCTION static int invoke(const MemberType &member, const int m, const int n, const int k,
                                           const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const int as0,
                                           const int as1, const ValueType *KOKKOS_RESTRICT B, const int bs0,
                                           const int bs1, const ScalarType beta,
                                           /**/ ValueType *KOKKOS_RESTRICT C, const int cs0, const int cs1) {
    // Qualified with Impl:: on purpose: the unqualified name would refer to this deprecated wrapper.
    using Kernel   = Impl::TeamVectorGemmInternal<ArgAlgo>;
    using Identity = KokkosBlas::Impl::OpID;

    const int status =
        Kernel::invoke(member, Identity(), Identity(), m, n, k, alpha, A, as0, as1, B, bs0, bs1, beta, C, cs0, cs1);
    return status;
  }
};

} // namespace KokkosBatched

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,9 @@ struct TeamVectorSolveUTV_Internal {
/// T is matrix_rank x matrix_rank
/// V is matrix_rank x n
/// W = U^T B
TeamVectorGemmInternal<Algo::Gemm::Unblocked>::invoke(member, matrix_rank, nrhs, m, one, U, us1, us0, B, bs0, bs1,
zero, W, ws0, ws1);
Impl::TeamVectorGemmInternal<Algo::Gemm::Unblocked>::invoke(member, KokkosBlas::Impl::OpID(),
KokkosBlas::Impl::OpID(), matrix_rank, nrhs, m, one,
U, us1, us0, B, bs0, bs1, zero, W, ws0, ws1);
member.team_barrier();

/// W = T^{-1} W
Expand All @@ -101,13 +102,15 @@ struct TeamVectorSolveUTV_Internal {
member.team_barrier();

/// X = V^T W
TeamVectorGemmInternal<Algo::Gemm::Unblocked>::invoke(member, n, nrhs, matrix_rank, one, V, vs1, vs0, W, ws0, ws1,
zero, X, xs0, xs1);
Impl::TeamVectorGemmInternal<Algo::Gemm::Unblocked>::invoke(member, KokkosBlas::Impl::OpID(),
KokkosBlas::Impl::OpID(), n, nrhs, matrix_rank, one,
V, vs1, vs0, W, ws0, ws1, zero, X, xs0, xs1);
member.team_barrier();
} else {
/// W = U^T B
TeamVectorGemmInternal<Algo::Gemm::Unblocked>::invoke(member, matrix_rank, nrhs, m, one, U, us1, us0, B, bs0, bs1,
zero, X, xs0, xs1);
Impl::TeamVectorGemmInternal<Algo::Gemm::Unblocked>::invoke(member, KokkosBlas::Impl::OpID(),
KokkosBlas::Impl::OpID(), matrix_rank, nrhs, m, one,
U, us1, us0, B, bs0, bs1, zero, X, xs0, xs1);
member.team_barrier();

/// X = T^{-1} X
Expand Down
40 changes: 37 additions & 3 deletions batched/dense/src/KokkosBatched_Gemm_Decl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,46 @@ struct TeamGemm {
const BViewType &B, const ScalarType beta, const CViewType &C);
};

/// \brief TeamVector Batched Gemm:
///
/// TeamVector Gemm
///

/// performs one of the matrix-matrix operations
/// C := alpha*op( A )*op( B ) + beta*C,
/// where op( X ) is one of
/// op( X ) = X or op( X ) = X**T or op( X ) = X**H,
/// alpha and beta are scalars, and A, B and C are matrices, with op( A ) an m by k matrix,
/// op( B ) a k by n matrix and C an m by n matrix.
/// \tparam MemberType: Member type of the Kokkos team policy
/// \tparam ArgTransA: Type indicating whether the A (Trans::NoTranspose), or A**T (Trans::Transpose) or A**H
/// (Trans::ConjTranspose) is used.
/// \tparam ArgTransB: Type indicating whether the B (Trans::NoTranspose), or B**T (Trans::Transpose) or B**H
/// (Trans::ConjTranspose) is used.
/// \tparam ArgAlgo: Type indicating the blocked (KokkosBatched::Algo::Gemm::Blocked) or unblocked
/// (KokkosBatched::Algo::Gemm::Unblocked) algorithm to be used
template <typename MemberType, typename ArgTransA, typename ArgTransB, typename ArgAlgo>
struct TeamVectorGemm {
static_assert(KokkosBlas::is_trans_v<ArgTransA>,
"KokkosBatched::TeamVectorGemm: ArgTransA must be a KokkosBlas::Trans.");
static_assert(KokkosBlas::is_trans_v<ArgTransB>,
"KokkosBatched::TeamVectorGemm: ArgTransB must be a KokkosBlas::Trans.");
static_assert(std::is_same_v<ArgAlgo, Algo::Gemm::Unblocked> || std::is_same_v<ArgAlgo, Algo::Gemm::Blocked>,
"KokkosBatched::Gemm: Use Algo::Gemm::Unblocked or Algo::Gemm::Blocked");

/// \tparam ScalarType: Scalar type of alpha and beta
/// \tparam AViewType: Input type for the matrix A, needs to be a 0D-2D view
/// \tparam BViewType: Input type for the matrix B, needs to be a 0D-2D view
/// \tparam CViewType: Input/Output type for the matrix C, needs to be a 0D-2D view
///
/// \param member [in]: Team member handle of type MemberType
/// \param alpha [in]: Scalar alpha
/// \param A [in]: A is a dimension ( lda, ka ) matrix, where ka is k when ArgTransA = Trans::NoTranspose, and is m
/// otherwise.
/// \param B [in]: B is a dimension ( ldb, kb ) matrix, where kb is n when ArgTransB = Trans::NoTranspose, and is k
/// otherwise.
/// \param beta [in]: Scalar beta
/// \param C [inout]: C is a dimension ( ldc, n ) matrix. Before entry, the leading m by n part of the array C
/// must contain the matrix C, except when beta is zero, in which case C need not be set on entry. On exit, the array
/// C is overwritten by the m by n matrix ( alpha*op( A )*op( B ) + beta*C )
///
/// Team vector parallelization is used inside of the function.
template <typename ScalarType, typename AViewType, typename BViewType, typename CViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const MemberType &member, const ScalarType alpha, const AViewType &A,
const BViewType &B, const ScalarType beta, const CViewType &C);
Expand Down
1 change: 0 additions & 1 deletion batched/dense/unit_test/Test_Batched_Dense_GEMM.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
#include "Test_Batched_TeamGemm_Complex.hpp"

// TeamVector Kernels
#include "Test_Batched_TeamVectorGemm.hpp"
#include "Test_Batched_TeamVectorGemm_Real.hpp"
#include "Test_Batched_TeamVectorGemm_Complex.hpp"

Expand Down
22 changes: 16 additions & 6 deletions batched/dense/unit_test/Test_Batched_TeamGemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,20 @@
#include "gtest/gtest.h"
#include "Kokkos_Core.hpp"
#include "Kokkos_Random.hpp"
#include "KokkosBatched_Util.hpp"
#include "KokkosBatched_Gemm_Decl.hpp"
#include "KokkosBatched_Gemm_Serial_Impl.hpp"

#include "KokkosBatched_Gemm_Team_Impl.hpp"
#include "KokkosBatched_Gemm_TeamVector_Impl.hpp"

#include "KokkosKernels_TestUtils.hpp"

namespace Test {
namespace TeamGemm {

template <typename TA, typename TB>
template <typename Mode, typename TA, typename TB>
struct ParamTag {
using mode = Mode;
using transA = TA;
using transB = TB;
};
Expand All @@ -53,13 +56,20 @@ struct Functor_TestBatchedTeamGemm {
auto bb = Kokkos::subview(m_b, k, Kokkos::ALL(), Kokkos::ALL());
auto cc = Kokkos::subview(m_c, k, Kokkos::ALL(), Kokkos::ALL());

KokkosBatched::TeamGemm<MemberType, typename ParamTagType::transA, typename ParamTagType::transB,
AlgoTagType>::invoke(member, m_alpha, aa, bb, m_beta, cc);
if constexpr (std::is_same_v<typename ParamTagType::mode, KokkosBatched::Mode::Team>) {
KokkosBatched::TeamGemm<MemberType, typename ParamTagType::transA, typename ParamTagType::transB,
AlgoTagType>::invoke(member, m_alpha, aa, bb, m_beta, cc);
} else if constexpr (std::is_same_v<typename ParamTagType::mode, KokkosBatched::Mode::TeamVector>) {
KokkosBatched::TeamVectorGemm<MemberType, typename ParamTagType::transA, typename ParamTagType::transB,
AlgoTagType>::invoke(member, m_alpha, aa, bb, m_beta, cc);
}
}

inline void run() {
using value_type = typename ViewType::non_const_value_type;
std::string name_region("KokkosBatched::Test::TeamGemm");
using value_type = typename ViewType::non_const_value_type;
std::string name_region = std::is_same_v<typename ParamTagType::mode, KokkosBatched::Mode::Team>
? "KokkosBatched::Test::TeamGemm"
: "KokkosBatched::Test::TeamVectorGemm";
const std::string name_value_type = Test::value_type_name<value_type>();
std::string name = name_region + name_value_type;
Kokkos::Profiling::pushRegion(name.c_str());
Expand Down
Loading
Loading