@@ -14,6 +14,160 @@

namespace ck_tile::core::arch::mma {

/**---------------------------------------------------
* Meaning of amdgcn_mma layout parameters (general)
* ---------------------------------------------------
*
* The fragment (MmaTile) sizes and layout constants in the amdgcn_mma struct describe the mapping
* between intrinsic input / output matrix elements and vector registers (lane x vector_item space).
* Note that there is a separate mapping for each of A, B and C, although those for A and B are
* usually similar if not identical. Every mapping can be described as an unmerge operation on one
* of the matrix dims (K for A/B, M for C), followed by a remerge of the resulting sub-dims and the
* remaining raw dim into the Lane and Vector_item dimensions. When unmerging a dimension K, we
* label the resulting sub-dimensions K0, K1 and K2, where K0 is the fastest changing one. K0 is
* also referred to as "the size of the first unmerge" and K1 as "the size of the second unmerge".
* There are never more than 2 unmerge operations, and an unmerge may be trivial (unmerge size of
* 1). Example double unmerge with sizes {K1, K0} = {3, 2} of a K dimension of size 12:
*
* K K2 K1 K0
* 0 0 0 0
* 1 0 0 1
* 2 0 1 0
* 3 0 1 1
* 4 0 2 0
* 5 0 2 1
* 6 1 0 0
* 7 1 0 1
* 8 1 1 0
* 9 1 1 1
* 10 1 2 0
* 11 1 2 1
*
* Note that K0 = 2 (first unmerge size, fastest changing), K1 = 3 (second unmerge size,
* second-fastest changing), and K2 = 12 / 2 / 3 = 2 (outermost dimension, whatever is left).
*
* If we were to use this unmerge op to describe an A matrix layout in registers, we might for
* example have that L (the lane dim) is composed of K1 and M, and V (the vector_item dim) is
* composed of K2 and K0. Compactly described this is K{3, 2} L{K1M} V{K2K0}, where within L and V
* the rightmost sub-dim changes fastest. With an M dimension of size 2 we would have the following
* layout (6 lanes, 4 vector items each):
*
* | V0 | V1 | V2 | V3 |
* L0 | M=0 K=0 | M=0 K=1 | M=0 K=6 | M=0 K=7 |
* L1 | M=1 K=0 | M=1 K=1 | M=1 K=6 | M=1 K=7 |
* L2 | M=0 K=2 | M=0 K=3 | M=0 K=8 | M=0 K=9 |
* L3 | M=1 K=2 | M=1 K=3 | M=1 K=8 | M=1 K=9 |
* L4 | M=0 K=4 | M=0 K=5 | M=0 K=10 | M=0 K=11 |
* L5 | M=1 K=4 | M=1 K=5 | M=1 K=10 | M=1 K=11 |
*
* Note that every A matrix element is now placed at a unique (lane, vector_item) location. When a
* Repeat dimension is used, every matrix element is instead mapped to multiple (lane, vector_item)
* locations, usually repeated along the Lane dimension.
*
* Check out TileDistrEncRegMap which can print full forward and backward mapping tables for any
* register mapping (expressed as a tile distribution encoding).
*
* ------------------------------------------
* Individual amdgcn_mma layout parameters
* ------------------------------------------
*
* -- ABKPerLane --
* The number of K dim elements in each lane. Always the same for A and B, even when they have
* different layouts. In terms of unmerge sizes, it is equal to K0 * K2, i.e. the product of the
* sizes of the innermost and outermost dimensions after a double K unmerge.
*
* -- A / B NumAccess --
* These two variables give the size of the outermost sub-dimension when two unmerge operations are
* required for K (i.e. K2). Alternatively, they can be described as the number of sets into which
* the vector dimension, which houses a number of K indices, is split. We may be able to remove the
* A / B NumAccess from the amdgcn struct, but that depends on how load and store tile work and
* whether we want the mid-level code to always have to know about this. There are only two reasons
* for the A / B NumAccess to ever not be 1, and they are different kinds of reasons:
*
* (logical correctness) Applies to scale MFMA fp8, which due to the index matrix layout does not
* allow arbitrary K perms to simplify layouts. This means the layout can only properly be
* described with a NumAccess value of at least 2.
*
* (load / store manipulation) The load and store tile functions appear to use the size of the
* smallest unmerged K dimension (K0) to determine how many elements should be loaded at a time.
* Different NumAccess values will therefore lead to different load / store behavior, even when
* logically equivalent.
*
* -- A / B Repeat --
* Indicates that all matrix values are represented multiple times in the vector registers,
* typically repeating along the lane dimension. This is always equal to the repeat value used in
* Tile Distribution encodings. There are two reasons to have a non-trivial (non-1) value here:
* MFMA block-hiding to create oblong "virtual" intrinsics, and RDNA3 input repetition.
*
* -- CMPerLane --
* The number of M dim elements in each lane. In terms of unmerge sizes, it is equal to M0 * M2,
* i.e. the product of the sizes of the innermost and outermost dimensions after a double M unmerge.
*
* -- CNumAccess --
* Same as A / B NumAccess but for the M dim (i.e. M2); the mid-level code does not care about this
* one and will not try to request a specific value. It is absolutely needed for logical
* correctness of register mappings, since we cannot perform arbitrary M permutations without
* messing up the A layout.
*/

/**
* @class amdgcn_mma_base
* @brief Base class for amdgcn_mma structs to avoid a lot of code duplication. Also puts
* all generic parameter derivations and static asserts in one place. Houses all of the
* amdgcn struct types and variables, except for the exec() function.
*/
template <typename ADataType_,
typename BDataType_,
typename CDataType_,
uint32_t FragM,
uint32_t FragN,
uint32_t FragK,
uint32_t WaveSize_,
index_t kABKPerLane_,
index_t kAKNumAccess_,
index_t kARepeat_,
index_t kBKNumAccess_,
index_t kBRepeat_,
index_t kCMPerLane_,
index_t kCMNumAccess_,
typename OpType_,
MmaOpFamily OpFamily_>
struct amdgcn_mma_base
{
using OpType = OpType_;
static constexpr MmaOpFamily OpFamily = OpFamily_;

// Data types
using ADataType = ADataType_;
using BDataType = BDataType_;
using CDataType = CDataType_;

// Fragment (MmaTile) sizes, check description above.
static constexpr index_t kM = FragM; // M = M2 * M1 * M0
static constexpr index_t kN = FragN;
static constexpr index_t kK = FragK; // K = K2 * K1 * K0

// Layout constants, check description above.
static constexpr index_t kABKPerLane = kABKPerLane_; // K2 * K0
static constexpr index_t kAKNumAccess = kAKNumAccess_; // K2
static constexpr index_t kARepeat = kARepeat_; // RDNA3 repetition and MFMA block-hiding
static constexpr index_t kBKNumAccess = kBKNumAccess_; // K2
static constexpr index_t kBRepeat = kBRepeat_; // RDNA3 repetition and MFMA block-hiding
static constexpr index_t kCMPerLane = kCMPerLane_; // M2 * M0
static constexpr index_t kCMNumAccess = kCMNumAccess_; // M2

// Register types (derived)
static constexpr index_t WaveSize = WaveSize_;
static_assert((kM * kK * kARepeat) % WaveSize == 0);
static_assert((kN * kK * kBRepeat) % WaveSize == 0);
static_assert((kM * kN) % WaveSize == 0);

using AVecType = ext_vector_t<ADataType, kM * kK * kARepeat / WaveSize>;
using BVecType = ext_vector_t<BDataType, kN * kK * kBRepeat / WaveSize>;
using CVecType = ext_vector_t<CDataType, kM * kN / WaveSize>;
};

/**
* @struct Unsupported
* @brief Meta-tag to indicate unsupported amdgcn_mma instance.
@@ -31,23 +185,24 @@ template <typename MmaOp>
concept MmaOpI = requires(MmaOp op) {
// Requires an op context
typename MmaOp::OpType;
typename MmaOp::OpFamily;

// Captures types for inputs / outputs to mma function
typename MmaOp::ADataType;
typename MmaOp::BDataType;
typename MmaOp::CDataType;
typename MmaOp::AVecType;
typename MmaOp::BVecType;
typename MmaOp::CVecType;

// Captures CK-specific layout properties
{ MmaOp::kAMBlock } -> std::convertible_to<unsigned int>;
{ MmaOp::kBNBlock } -> std::convertible_to<unsigned int>;
{ MmaOp::kAMLane } -> std::convertible_to<unsigned int>;
{ MmaOp::kBNLane } -> std::convertible_to<unsigned int>;
{ MmaOp::kABKLane } -> std::convertible_to<unsigned int>;
{ MmaOp::kABKPerLane } -> std::convertible_to<unsigned int>;
{ MmaOp::kCMLane } -> std::convertible_to<unsigned int>;
{ MmaOp::kCNLane } -> std::convertible_to<unsigned int>;
{ MmaOp::kCM0PerLane } -> std::convertible_to<unsigned int>;
{ MmaOp::kCM1PerLane } -> std::convertible_to<unsigned int>;
{ MmaOp::kAKNumAccess } -> std::convertible_to<unsigned int>;
{ MmaOp::kARepeat } -> std::convertible_to<unsigned int>;
{ MmaOp::kBKNumAccess } -> std::convertible_to<unsigned int>;
{ MmaOp::kBRepeat } -> std::convertible_to<unsigned int>;
{ MmaOp::kCMPerLane } -> std::convertible_to<unsigned int>;
{ MmaOp::kCMNumAccess } -> std::convertible_to<unsigned int>;

// Static exec function
{
@@ -69,52 +224,40 @@ concept MmaOpI = requires(MmaOp op) {
* @tparam ADataType Datatype of input A
* @tparam BDataType Datatype of input B
* @tparam CDataType Datatype of accumulator
* @tparam BlockM M-dimension of mma block
* @tparam BlockN N-dimension of mma block
* @tparam BlockK K-dimension of mma block
* @tparam FragM M-dimension of mma intrinsic (MmaTile)
* @tparam FragN N-dimension of mma intrinsic (MmaTile)
* @tparam FragK K-dimension of mma intrinsic (MmaTile)
* @tparam CtrlFlags Control flags for mma operation
* @tparam CompilerTarget The current compiler target
* @tparam OpFamily_ The type of operation (dense, sparse, scale, etc.)
* @tparam Enabler SFINAE enabler
*/
template <typename ADataType,
typename BDataType,
typename CDataType,
uint32_t BlockM,
uint32_t BlockN,
uint32_t BlockK,
uint32_t FragM,
uint32_t FragN,
uint32_t FragK,
typename CtrlFlags,
typename CompilerTarget,
MmaOpFamily OpFamily_,
typename Enabler = void>
struct amdgcn_mma
// clang-format off
// | A B C DataTypes |MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma : amdgcn_mma_base<fp32_t, fp32_t, fp32_t, 1u, 1u, 1u, 1u, 1, 1, 1, 1, 1, 1, 1, Unsupported, MmaOpFamily::UNDEFINED>
// clang-format on
{
// The base instance is unsupported because there is no __builtin to wrap.
using OpType = Unsupported;
static constexpr MmaOpFamily OpFamily = MmaOpFamily::UNDEFINED;

// Interface types for A, B, C vectors types
using AVecType = ext_vector_t<ADataType, 1>;
using BVecType = ext_vector_t<BDataType, 1>;
using CVecType = ext_vector_t<CDataType, 1>;

// Layout constants - default to 0
static constexpr index_t kAMBlock = 0;
static constexpr index_t kBNBlock = 0;

static constexpr index_t kAMLane = 0;
static constexpr index_t kBNLane = 0;
static constexpr index_t kABKLane = 0;
static constexpr index_t kABKPerLane = 0;

static constexpr index_t kCMLane = 0;
static constexpr index_t kCNLane = 0;
static constexpr index_t kCM0PerLane = 0;
static constexpr index_t kCM1PerLane = 0;

// This is a default pass-through implementation that doesn't do anything practical.
CK_TILE_DEVICE static CVecType const&
exec(AVecType const& regsA, BVecType const& regsB, CVecType const& regsC)
{
// Prints once across all thread blocks and threads.
static __device__ int printed = 0;
if(threadIdx.x == 0 && atomicCAS(&printed, 0, 1) == 0)
{
printf("[WARNING] Running amdgcn_mma dummy exec function!\n");
}

ignore(regsA, regsB);
return regsC; // No-op, just return C
}
@@ -25,48 +25,20 @@ namespace ck_tile::core::arch::mma {
* @brief Specialization of amdgcn_mma for MFMA on GFX9 targets
*
* This specialization implements the MFMA instruction for fp16_t A and B
* matrices, and fp32_t accumulator matrix, with 16x16x16 block sizes.
* matrices, and fp32_t accumulator matrix, with 16x16x16 fragment sizes.
*
* @tparam CtrlFlags Control flags for the MFMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsGfx9I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
struct amdgcn_mma<fp16_t,
fp16_t,
fp32_t,
16u,
16u,
16u,
CtrlFlags,
CompilerTarget,
MmaOpFamily::DENSE,
enable_if_target_family_gfx9_t<CompilerTarget>>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx9_t<CompilerTarget>>
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, 64u, 4, 1, 1, 1, 1, 4, 1, MfmaOp, MmaOpFamily::DENSE>
// clang-format on
{
// Mfma operation type
using OpType = MfmaOp;
static constexpr MmaOpFamily OpFamily = MmaOpFamily::DENSE;

// Register types
using AVecType = ext_vector_t<fp16_t, 4>;
using BVecType = ext_vector_t<fp16_t, 4>;
using CVecType = ext_vector_t<fp32_t, 4>;

// Layout constants
static constexpr index_t kAMBlock = 1;
static constexpr index_t kBNBlock = 1;

static constexpr index_t kAMLane = 16;
static constexpr index_t kBNLane = 16;
static constexpr index_t kABKLane = 4;
static constexpr index_t kABKPerLane = 4;

static constexpr index_t kCMLane = 4;
static constexpr index_t kCNLane = 16;
static constexpr index_t kCM0PerLane = 1;
static constexpr index_t kCM1PerLane = 4;

CK_TILE_DEVICE static auto
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
{
@@ -84,47 +56,20 @@ struct amdgcn_mma<fp16_t,
* @brief Specialization of amdgcn_mma for MFMA on GFX950 targets
*
* This specialization implements the MFMA instruction for fp16_t A and B
* matrices, and fp32_t accumulator matrix, with 16x16x32 block sizes.
* matrices, and fp32_t accumulator matrix, with 16x16x32 fragment sizes.
*
* @tparam CtrlFlags Control flags for the MFMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsGfx9I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
struct amdgcn_mma<fp16_t,
fp16_t,
fp32_t,
16u,
16u,
32u,
CtrlFlags,
CompilerTarget,
MmaOpFamily::DENSE,
enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u, 64u, 8, 1, 1, 1, 1, 4, 1, MfmaOp, MmaOpFamily::DENSE>
// clang-format on
{
using OpType = MfmaOp;
static constexpr MmaOpFamily OpFamily = MmaOpFamily::DENSE;

// Packed register types
using AVecType = ext_vector_t<fp16_t, 8>;
using BVecType = ext_vector_t<fp16_t, 8>;
using CVecType = ext_vector_t<fp32_t, 4>;

// Layout constants
static constexpr index_t kAMBlock = 1;
static constexpr index_t kBNBlock = 1;

static constexpr index_t kAMLane = 16;
static constexpr index_t kBNLane = 16;
static constexpr index_t kABKLane = 8;
static constexpr index_t kABKPerLane = 8;

static constexpr index_t kCMLane = 4;
static constexpr index_t kCNLane = 16;
static constexpr index_t kCM0PerLane = 1;
static constexpr index_t kCM1PerLane = 4;

CK_TILE_DEVICE static auto
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
{