From f974282cf90c2054b2193316359259c03e6cba43 Mon Sep 17 00:00:00 2001 From: Chris Tsiaousis Date: Thu, 12 Mar 2026 12:43:34 +0000 Subject: [PATCH 01/16] Added default implementation for `MmaTransformsDefaultSelector` The default implementation should include the pass through transforms and is needed to avoid instantiations of an undefined template. Signed-off-by: Chris Tsiaousis --- .../core/arch/mma/mfma/mfma_transforms.hpp | 8 +++++--- .../ck_tile/core/arch/mma/mma_transforms.hpp | 17 ++++++++++++++++- .../core/arch/mma/wmma/wmma_transforms.hpp | 16 ++++++++++------ 3 files changed, 31 insertions(+), 10 deletions(-) diff --git a/projects/composablekernel/include/ck_tile/core/arch/mma/mfma/mfma_transforms.hpp b/projects/composablekernel/include/ck_tile/core/arch/mma/mfma/mfma_transforms.hpp index b20d2436186c..5a3fc9a7e456 100644 --- a/projects/composablekernel/include/ck_tile/core/arch/mma/mfma/mfma_transforms.hpp +++ b/projects/composablekernel/include/ck_tile/core/arch/mma/mfma/mfma_transforms.hpp @@ -28,9 +28,11 @@ struct MmaDefaultTransformsGfx9 // TODO: c++20 template // TODO: c++20 requires template -struct MmaTransformsDefaultSelector> +struct MmaTransformsDefaultSelector< + MmaOp, + CompilerTarget, + enable_if_all, + std::enable_if_t>> { using SelectedTransforms = MmaDefaultTransformsGfx9; }; diff --git a/projects/composablekernel/include/ck_tile/core/arch/mma/mma_transforms.hpp b/projects/composablekernel/include/ck_tile/core/arch/mma/mma_transforms.hpp index 811df043644f..c41aa0ae1190 100644 --- a/projects/composablekernel/include/ck_tile/core/arch/mma/mma_transforms.hpp +++ b/projects/composablekernel/include/ck_tile/core/arch/mma/mma_transforms.hpp @@ -18,6 +18,18 @@ struct PassThroughTransform } }; +/** + * @struct MmaDefaultPassThroughTransforms + * @brief Implements the default MMA transforms + */ +struct MmaDefaultPassThroughTransforms +{ + using ATransform = PassThroughTransform; + using BTransform = PassThroughTransform; + using CTransform = 
PassThroughTransform; + using DTransform = PassThroughTransform; +}; + /** * @class MmaTransformsDefaultSelector * @brief Default selector for MmaTransforms based on MmaOp and CompilerTarget @@ -27,7 +39,10 @@ struct PassThroughTransform */ template // TODO: c++20 template -struct MmaTransformsDefaultSelector; +struct MmaTransformsDefaultSelector +{ + using SelectedTransforms = MmaDefaultPassThroughTransforms; +}; #if CK_TILE_CONCEPTS diff --git a/projects/composablekernel/include/ck_tile/core/arch/mma/wmma/wmma_transforms.hpp b/projects/composablekernel/include/ck_tile/core/arch/mma/wmma/wmma_transforms.hpp index eb87c38e8777..fd9cd698137e 100644 --- a/projects/composablekernel/include/ck_tile/core/arch/mma/wmma/wmma_transforms.hpp +++ b/projects/composablekernel/include/ck_tile/core/arch/mma/wmma/wmma_transforms.hpp @@ -86,9 +86,11 @@ struct MmaDefaultTransformsGfx12 template // TODO: c++20 template // TODO: c++20 requires -struct MmaTransformsDefaultSelector> +struct MmaTransformsDefaultSelector< + MmaOp, + CompilerTarget, + enable_if_all, + std::enable_if_t>> { using SelectedTransforms = MmaDefaultTransformsGfx11; }; @@ -102,9 +104,11 @@ struct MmaTransformsDefaultSelector // TODO: c++20 template // TODO: c++20 requires -struct MmaTransformsDefaultSelector> +struct MmaTransformsDefaultSelector< + MmaOp, + CompilerTarget, + enable_if_all, + std::enable_if_t>> { using SelectedTransforms = MmaDefaultTransformsGfx12; }; From 69e8516ec5443288ad26278551e7ff7fe5a8acb4 Mon Sep 17 00:00:00 2001 From: Chris Tsiaousis Date: Mon, 16 Mar 2026 16:34:45 +0000 Subject: [PATCH 02/16] Added the MmaPipelineBase struct and renamed `mma.hpp` This is because this filename was too general for what it did. Also transferred basic components to the reusable base class.
Signed-off-by: Chris Tsiaousis --- .../composablekernel/include/ck_tile/core.hpp | 4 +- .../ck_tile/core/arch/mma/mma_pipeline.hpp | 151 ++++++++++++++++++ .../arch/mma/{mma.hpp => mma_wavewise.hpp} | 147 +++++++---------- .../ck_tile/core/arch/mma/test_amdgcn_mma.cpp | 2 +- 4 files changed, 210 insertions(+), 94 deletions(-) create mode 100644 projects/composablekernel/include/ck_tile/core/arch/mma/mma_pipeline.hpp rename projects/composablekernel/include/ck_tile/core/arch/mma/{mma.hpp => mma_wavewise.hpp} (57%) diff --git a/projects/composablekernel/include/ck_tile/core.hpp b/projects/composablekernel/include/ck_tile/core.hpp index ed063b3abbdd..1917e36d7f95 100644 --- a/projects/composablekernel/include/ck_tile/core.hpp +++ b/projects/composablekernel/include/ck_tile/core.hpp @@ -19,11 +19,12 @@ #include "ck_tile/core/arch/mma/mfma/mfma_selector.hpp" #include "ck_tile/core/arch/mma/mfma/mfma_traits.hpp" #include "ck_tile/core/arch/mma/mfma/mfma_transforms.hpp" -#include "ck_tile/core/arch/mma/mma.hpp" #include "ck_tile/core/arch/mma/mma_op_family.hpp" +#include "ck_tile/core/arch/mma/mma_pipeline.hpp" #include "ck_tile/core/arch/mma/mma_selector.hpp" #include "ck_tile/core/arch/mma/mma_traits.hpp" #include "ck_tile/core/arch/mma/mma_transforms.hpp" +#include "ck_tile/core/arch/mma/mma_wavewise.hpp" #include "ck_tile/core/arch/mma/sparse/mfma/selector.hpp" #include "ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp" #include "ck_tile/core/arch/mma/sparse/sparse.hpp" @@ -32,6 +33,7 @@ #include "ck_tile/core/arch/mma/sparse/sparse_transforms.hpp" #include "ck_tile/core/arch/mma/sparse/wmma/selector.hpp" #include "ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp" +#include "ck_tile/core/arch/mma/sparse_mma.hpp" #include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_register_mapper.hpp" #include "ck_tile/core/arch/mma/wmma/wmma.hpp" #include "ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp" diff --git 
a/projects/composablekernel/include/ck_tile/core/arch/mma/mma_pipeline.hpp b/projects/composablekernel/include/ck_tile/core/arch/mma/mma_pipeline.hpp new file mode 100644 index 000000000000..748b167503ca --- /dev/null +++ b/projects/composablekernel/include/ck_tile/core/arch/mma/mma_pipeline.hpp @@ -0,0 +1,151 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +#pragma once +#include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/core/numeric/vector_type.hpp" + +#include "amdgcn_mma.hpp" +#include "mma_selector.hpp" +#include "mma_traits.hpp" +#include "mma_transforms.hpp" +#include +#include +#include + +namespace ck_tile::core::arch::mma { + +enum struct MmaPipelineOptionFlag +{ + NONE = 0x0, + C_TRANSPOSE = 0x1, + SWIZZLE_A = 0x2, + SWIZZLE_B = 0x4, + DOUBLE_ATTR_NUM_ACCESS = 0x8, + QUAD_ATTR_NUM_ACCESS = 0x10, + COMPRESS_A = 0x20, +}; + +struct MmaPipelineOptionFlags +{ + using Type = std::underlying_type::type; + + explicit constexpr MmaPipelineOptionFlags() : mFlags(0) {} + explicit constexpr MmaPipelineOptionFlags(Type value) : mFlags(value) {} + constexpr MmaPipelineOptionFlags(MmaPipelineOptionFlag singleFlag) : mFlags(toType(singleFlag)) + { + } + constexpr MmaPipelineOptionFlags(const MmaPipelineOptionFlags& original) + : mFlags(original.mFlags) + { + } + + constexpr MmaPipelineOptionFlags& operator|=(MmaPipelineOptionFlag addValue) + { + mFlags |= toType(addValue); + return *this; + } + constexpr MmaPipelineOptionFlags operator|(MmaPipelineOptionFlag addValue) const + { + MmaPipelineOptionFlags result(*this); + result |= addValue; + return result; + } + constexpr MmaPipelineOptionFlags& operator&=(MmaPipelineOptionFlag maskValue) + { + mFlags &= toType(maskValue); + return *this; + } + constexpr MmaPipelineOptionFlags operator&(MmaPipelineOptionFlag maskValue) const + { + MmaPipelineOptionFlags result(*this); + result &= maskValue; + return result; + } + constexpr MmaPipelineOptionFlags operator~() 
const + { + MmaPipelineOptionFlags result(*this); + result.mFlags = ~result.mFlags; + return result; + } + constexpr operator bool() const { return mFlags != toType(MmaPipelineOptionFlag::NONE); } + + private: + Type mFlags; + static constexpr Type toType(MmaPipelineOptionFlag f) { return static_cast(f); } +}; + +// TODO: c++20: use MmaPipelineOptionFlags directly +template +struct MmaPipelineBase +{ + static constexpr auto Flags = MmaPipelineOptionFlags(Flags_); + // TODO: Implement those cases + static_assert(!(Flags & MmaPipelineOptionFlag::C_TRANSPOSE), "Flag not yet implemented"); + static_assert(!(Flags & MmaPipelineOptionFlag::SWIZZLE_A), "Flag not yet implemented"); + static_assert(!(Flags & MmaPipelineOptionFlag::SWIZZLE_B), "Flag not yet implemented"); + static_assert(!(Flags & MmaPipelineOptionFlag::DOUBLE_ATTR_NUM_ACCESS), + "Flag not yet implemented"); + static_assert(!(Flags & MmaPipelineOptionFlag::QUAD_ATTR_NUM_ACCESS), + "Flag not yet implemented"); + + private: + template + CK_TILE_DEVICE static bool hasFlag() + { + return Flags & Flag; + } + + template + CK_TILE_DEVICE static auto formatBuffer(SrcT&& inputBuffer) + { + // TODO: Implement formatting logic as needed. + // This is intended to convert input fragments to the native vector types + // required by the BlockWiseMma operation for iteration + static_assert(sizeof(DstT) == sizeof(std::remove_reference_t), + "Size mismatch in formatBuffer"); + + using QualifiedDstT = + std::conditional_t>, DstT const, DstT>; + + return reinterpret_cast(inputBuffer); + } + + protected: + template + CK_TILE_DEVICE static auto preApplyTransform(Args&&... args) + { + return formatBuffer(Transform::exec(std::forward(args)...)); + } + + template + CK_TILE_DEVICE static auto postApplyTransform(Args&&... 
args) + { + return Transform::exec(formatBuffer(std::forward(args)...)); + } + + public: + template + CK_TILE_DEVICE static decltype(auto) exec(VecTA&& a, VecTB&& b, VecTC&& accum) + { + // TODO: c++20: Call template functions with MmaPipelineOptionFlags directly + auto pre = Derived::template preApply( + std::forward(a), std::forward(b), std::forward(accum)); + Derived::execImpl(pre); + return Derived::template postApply(std::move(pre)); + } +}; + +#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER + +#include + +/** + * @concept MmaPipelineI + * @brief Expresses the meta-data interface required for a CRTP MmaPipeline. + */ +template +concept MmaPipelineInterface = std::derived_from>; + +#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER + +} // namespace ck_tile::core::arch::mma diff --git a/projects/composablekernel/include/ck_tile/core/arch/mma/mma.hpp b/projects/composablekernel/include/ck_tile/core/arch/mma/mma_wavewise.hpp similarity index 57% rename from projects/composablekernel/include/ck_tile/core/arch/mma/mma.hpp rename to projects/composablekernel/include/ck_tile/core/arch/mma/mma_wavewise.hpp index b0eb507b4969..897143a907f3 100644 --- a/projects/composablekernel/include/ck_tile/core/arch/mma/mma.hpp +++ b/projects/composablekernel/include/ck_tile/core/arch/mma/mma_wavewise.hpp @@ -5,11 +5,13 @@ #include "ck_tile/core/numeric/vector_type.hpp" #include "amdgcn_mma.hpp" +#include "mma_pipeline.hpp" #include "mma_selector.hpp" #include "mma_transforms.hpp" #include "mfma/mfma.hpp" #include "wmma/wmma.hpp" +#include namespace ck_tile::core::arch::mma { @@ -71,8 +73,13 @@ template ::SelectedOp, typename MmaTransforms = // TODO: c++20 MmaTransformsI MmaTransforms = typename MmaTransformsDefaultSelector::SelectedTransforms> -struct WaveWiseMma +// clang-format off +struct WaveWiseMma : public MmaPipelineBase(MmaPipelineOptionFlag::NONE), + WaveWiseMma> { + using Base = MmaPipelineBase(MmaPipelineOptionFlag::NONE), + WaveWiseMma>; + // clang-format on using 
FragWiseMmaOp = MmaOp; // Fragment dimensions @@ -110,119 +117,75 @@ struct WaveWiseMma static_assert(WaveTileN % FragN == 0u, "WaveTileN must be a multiple of FragN"); static_assert(WaveTileK % FragK == 0u, "WaveTileK must be a multiple of FragK"); - private: - template - CK_TILE_DEVICE static auto formatBuffer(SrcT const& inputBuffer) + template + CK_TILE_DEVICE static decltype(auto) preApply(VecTA&& a, VecTB&& b, VecTC&& accum) { - // TODO: Implement formatting logic as needed. - // This is intended to convert input WaveTiles to the native vector types - // required by the FragWiseMma operation for iteration - static_assert(sizeof(DstT) == sizeof(SrcT), "Size mismatch in formatBuffer"); - return reinterpret_cast(inputBuffer); - } + static_assert(Flags == MmaPipelineOptionFlags(), "No special flags implemented yet."); - template - CK_TILE_DEVICE static auto formatBuffer(SrcT& inputBuffer) - { - // TODO: Implement formatting logic as needed. - // This is intended to convert input WaveTiles to the native vector types - // required by the FragWiseMma operation for iteration - static_assert(sizeof(DstT) == sizeof(SrcT), "Size mismatch in formatBuffer"); - return reinterpret_cast(inputBuffer); - } - - /*! @brief Execute Mma in row-major accumulation order. - * @tparam VecTA The input WaveTile A vector type - * @tparam VecTB The input WaveTile B vector type - * @tparam VecTC The input/output WaveTile C vector type - */ - template - CK_TILE_DEVICE static decltype(auto) exec_col_major(VecTA&& a, VecTB&& b, VecTC&& accum) - { // We implement an example wave-tile pipeline here. // First, we apply the necessary transforms to the input fragments, // then we convert the result into buffers of native vector formats // that we can easily index. Native vector formats are necessary inputs // to the given MmaOp exec function. 
- auto a_frag = formatBuffer(ATransform::exec(a)); - auto b_frag = formatBuffer(BTransform::exec(b)); - auto c_frag = formatBuffer(CTransform::exec(accum)); + auto a_frag = + Base::template preApplyTransform(std::forward(a)); + auto b_frag = + Base::template preApplyTransform(std::forward(b)); + auto c_frag = + Base::template preApplyTransform(std::forward(accum)); + + return std::make_tuple(std::move(a_frag), std::move(b_frag), std::move(c_frag)); + } - // "Col-major" accumulation over the M-dimension fragments first. - // Pseudo code here, but we would basically iterate over the fragments in col-major order - for(uint32_t bn = 0u; bn < FragsN; ++bn) - { - for(uint32_t bm = 0u; bm < FragsM; ++bm) - { - for(uint32_t bk = 0u; bk < FragsK; ++bk) - { - c_frag[bm][bn] = - FragWiseMmaOp::exec(a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn]); - } - } - } + template + CK_TILE_DEVICE static decltype(auto) postApply(std::tuple&& vecs) + { + static_assert(Flags == MmaPipelineOptionFlags(), "No special flags implemented yet."); + auto& [a_frag, b_frag, c_frag] = vecs; // Convert native vector results back to the output WaveTile format // and then return after we apply the final output transform. - return DTransform::exec(formatBuffer>(c_frag)); + return Base::template postApplyTransform, DTransform>(c_frag); } - /*! @brief Execute Mma in row-major accumulation order. - * @tparam VecTA The input WaveTile A vector type - * @tparam VecTB The input WaveTile B vector type - * @tparam VecTC The input/output WaveTile C vector type - */ template - CK_TILE_DEVICE static decltype(auto) exec_row_major(VecTA&& a, VecTB&& b, VecTC&& accum) + CK_TILE_DEVICE static void execImpl(std::tuple& vecs) { - // We implement an example wave-tile pipeline here. - // First, we apply the necessary transforms to the input WaveTiles, - // then we convert the result into buffers of native vector formats - // that we can easily index. 
Native vector formats are necessary inputs - // to the given MmaOp exec function. - auto a_frag = formatBuffer(ATransform::exec(a)); - auto b_frag = formatBuffer(BTransform::exec(b)); - auto c_frag = formatBuffer(CTransform::exec(accum)); - - // "Row-major" accumulation over the N-dimension fragments first. - // Pseudo code here, but we would basically iterate over the fragments in row-major order. - // We also have to ensure that the incoming vector WaveTiles are converted to native vector - // types before passing to the FragWiseMma exec function. - for(uint32_t bm = 0u; bm < FragsM; ++bm) + auto& [a_frag, b_frag, c_frag] = vecs; + + if constexpr(AccumPolicy == MmaAccumPolicy::ROW_MAJOR) { - for(uint32_t bn = 0u; bn < FragsN; ++bn) + // "Row-major" accumulation over the N-dimension fragments first. + // Pseudo code here, but we would basically iterate over the fragments in row-major + // order. We also have to ensure that the incoming vector WaveTiles are converted to + // native vector types before passing to the FragWiseMma exec function. + for(uint32_t bm = 0u; bm < FragsM; ++bm) { - for(uint32_t bk = 0u; bk < FragsK; ++bk) + for(uint32_t bn = 0u; bn < FragsN; ++bn) { - c_frag[bm][bn] = - FragWiseMmaOp::exec(a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn]); + for(uint32_t bk = 0u; bk < FragsK; ++bk) + { + c_frag[bm][bn] = + FragWiseMmaOp::exec(a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn]); + } } } } - - // Convert native vector results back to the output WaveTile format - // and then return after we apply the final output transform. - return DTransform::exec(formatBuffer>(c_frag)); - } - - public: - /*! @brief Forward to Mma operation with specified accumulation order. 
- * @tparam VecTA The input WaveTile A vector type - * @tparam VecTB The input WaveTile B vector type - * @tparam VecTC The input/output WaveTile C vector type - */ - template - CK_TILE_DEVICE static decltype(auto) exec(VecTA&& a, VecTB&& b, VecTC&& accum) - { - if constexpr(AccumPolicy == MmaAccumPolicy::ROW_MAJOR) + else { - return exec_row_major( - std::forward(a), std::forward(b), std::forward(accum)); - } - else // if constexpr(AccumPolicy == MmaAccumPolicy::COL_MAJOR) - { - return exec_col_major( - std::forward(a), std::forward(b), std::forward(accum)); + // "Col-major" accumulation over the M-dimension fragments first. + // Pseudo code here, but we would basically iterate over the blocks in col-major order + for(uint32_t bn = 0u; bn < FragsN; ++bn) + { + for(uint32_t bm = 0u; bm < FragsM; ++bm) + { + for(uint32_t bk = 0u; bk < FragsK; ++bk) + { + c_frag[bm][bn] = + FragWiseMmaOp::exec(a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn]); + } + } + } } } }; diff --git a/projects/composablekernel/test/ck_tile/core/arch/mma/test_amdgcn_mma.cpp b/projects/composablekernel/test/ck_tile/core/arch/mma/test_amdgcn_mma.cpp index 865c3e1011a2..5a8f478f4872 100644 --- a/projects/composablekernel/test/ck_tile/core/arch/mma/test_amdgcn_mma.cpp +++ b/projects/composablekernel/test/ck_tile/core/arch/mma/test_amdgcn_mma.cpp @@ -7,7 +7,7 @@ #include "ck_tile/core/arch/arch.hpp" #include "ck_tile/core/arch/mma/amdgcn_mma.hpp" #include "ck_tile/core/arch/mma/mma_selector.hpp" -#include "ck_tile/core/arch/mma/mma.hpp" +#include "ck_tile/core/arch/mma/mma_wavewise.hpp" #include "ck_tile/core/utility/type_traits.hpp" #include "ck_tile/host/hip_check_error.hpp" From 29893f89339c20b1c9b6c672bc2627f06c287559 Mon Sep 17 00:00:00 2001 From: Chris Tsiaousis Date: Mon, 16 Mar 2026 16:37:13 +0000 Subject: [PATCH 03/16] Added compress A functionality to sparse transforms Signed-off-by: Chris Tsiaousis --- .../arch/mma/sparse/sparse_transforms.hpp | 75 ++++++++++++++++++- 1 file changed, 74 
insertions(+), 1 deletion(-) diff --git a/projects/composablekernel/include/ck_tile/core/arch/mma/sparse/sparse_transforms.hpp b/projects/composablekernel/include/ck_tile/core/arch/mma/sparse/sparse_transforms.hpp index 7da8f4f616ca..bf58b17c78ba 100644 --- a/projects/composablekernel/include/ck_tile/core/arch/mma/sparse/sparse_transforms.hpp +++ b/projects/composablekernel/include/ck_tile/core/arch/mma/sparse/sparse_transforms.hpp @@ -6,9 +6,82 @@ #include "ck_tile/core/arch/arch.hpp" #include "ck_tile/core/arch/mma/mma_op_family.hpp" #include "ck_tile/core/arch/mma/mma_transforms.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/vector_type.hpp" +#include namespace ck_tile::core::arch::mma { +namespace detail { +/** + * @brief Compress A vector for 2:4 structured sparsity instruction by moving all non-zero + * elements into lower part of a_vec to half its effective size. + * @param a_vec Vector to be compressed. + * @tparam ADataType The data type of a_vec + * @tparam CompressedSize The target compression size + * @tparam AVec The vector type of a_vec (deduced) + * @return Packed 32‑bit word containing **CompressedSize** 2‑bit fields. + * Each field encodes the original position (0–3) of the corresponding + * non‑zero element in the input. If fewer than CompressedSize + * non‑zeros are found, remaining fields default to 2 (see below). + */ +template +static CK_TILE_DEVICE int32_t compress_a_impl(AVec& a_vec) +{ + // idx holds one 2‑bit index per output element (total CompressedSize entries). + // It is initialized to the pattern 0b10 for every field. This matches + // what the hardware expects when there are fewer than two non‑zero values + // in a 4‑element group – the unused output is treated as coming from slot 2. + // The loop below will clear and set each field as real non‑zeros are seen. 
+ int32_t idx = 0; + static_for<0, CompressedSize, 1>{}([&](auto k) { idx |= (2 << (2 * k)); }); + + static_for<0, CompressedSize / 2, 1>{}([&](auto i) { + ADataType nonzero_elems[2] = {a_vec[i * 4 + 2], a_vec[i * 4 + 3]}; + int32_t non_zero_pos = 0; + + static_for<0, 3, 1>{}([&](auto j) { + if(a_vec[i * 4 + j] != 0.0f) + { + nonzero_elems[non_zero_pos] = a_vec[i * 4 + j]; + // clear the two‑bit field for this output and insert j + idx &= ~(0b11 << 2 * (i * 2 + non_zero_pos)); + idx |= j << 2 * (i * 2 + non_zero_pos); + ++non_zero_pos; + } + }); + a_vec[i * 2] = nonzero_elems[0]; + a_vec[i * 2 + 1] = nonzero_elems[1]; + }); + + return idx; +} +} // namespace detail + +/** + * @class SparseCompressTransform + * @brief Transform to unpad data from b32 type to original type + */ +template +struct SparseCompressTransform +{ + template + CK_TILE_DEVICE static decltype(auto) exec(VecType&& v, int32_t& idx) + { + using VecTraits = vector_traits>; + using ScalarT = typename VecTraits::scalar_type; + static constexpr auto VecN = VecTraits::vector_size; + static constexpr index_t CompressedSize = VecN / CompressionRatio; + using VecCompressed = ext_vector_t; + + idx = detail::compress_a_impl(v); + + VecCompressed result; + __builtin_memcpy(&result, &v, sizeof(VecCompressed)); + return result; + } +}; + /** * @struct MmaDefaultTransformsSparse * @brief Implements the default transforms for Sparse @@ -21,7 +94,7 @@ namespace ck_tile::core::arch::mma { */ struct MmaDefaultTransformsSparse { - using ATransform = PassThroughTransform; + using ATransform = SparseCompressTransform<2>; using BTransform = PassThroughTransform; using CTransform = PassThroughTransform; using DTransform = PassThroughTransform; From 49f35e99704f0570385a4f6f7aed2a1c54fe8724 Mon Sep 17 00:00:00 2001 From: Chris Tsiaousis Date: Mon, 16 Mar 2026 16:38:10 +0000 Subject: [PATCH 04/16] Added test case for wavewise mma pipeline Signed-off-by: Chris Tsiaousis --- .../test/ck_tile/core/arch/mma/CMakeLists.txt | 2 
+ .../arch/mma/test_amdgcn_wavewise_mma.cpp | 147 ++++++++++++++++++ 2 files changed, 149 insertions(+) create mode 100644 projects/composablekernel/test/ck_tile/core/arch/mma/test_amdgcn_wavewise_mma.cpp diff --git a/projects/composablekernel/test/ck_tile/core/arch/mma/CMakeLists.txt b/projects/composablekernel/test/ck_tile/core/arch/mma/CMakeLists.txt index 964acfb02a1b..cd77589c4d86 100644 --- a/projects/composablekernel/test/ck_tile/core/arch/mma/CMakeLists.txt +++ b/projects/composablekernel/test/ck_tile/core/arch/mma/CMakeLists.txt @@ -14,6 +14,8 @@ endif() if(GPU_TARGETS MATCHES "gfx9") add_gtest_executable(test_amdgcn_mma test_amdgcn_mma.cpp) target_compile_options(test_amdgcn_mma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + add_gtest_executable(test_amdgcn_wavewise_mma test_amdgcn_wavewise_mma.cpp) + target_compile_options(test_amdgcn_wavewise_mma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) else() message(DEBUG "Skipping ck_tile_gemm tests for current target") endif() diff --git a/projects/composablekernel/test/ck_tile/core/arch/mma/test_amdgcn_wavewise_mma.cpp b/projects/composablekernel/test/ck_tile/core/arch/mma/test_amdgcn_wavewise_mma.cpp new file mode 100644 index 000000000000..3f05552b415f --- /dev/null +++ b/projects/composablekernel/test/ck_tile/core/arch/mma/test_amdgcn_wavewise_mma.cpp @@ -0,0 +1,147 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/core/arch/mma/mma_op_family.hpp" +#include "ck_tile/core/arch/mma/mma_wavewise.hpp" + +#include "get_wave_size_helper.hpp" + +#include +#include + +using namespace ck_tile; +using namespace ck_tile::core::arch; +using namespace ck_tile::core::arch::mma; + +template +__global__ void test_pipeline(void* a, void* b, void* c) +{ + using CompilerTarget = decltype(get_compiler_target()); + using MmaOp = typename MmaDefaultSelector::SelectedOp; + + using MmaTraits = MmaOpTraits; + + if constexpr(MmaTraits::IsSupported) + { + using Pipeline = WaveWiseMma; + + using AVecType = typename Pipeline::AVecType; + using BVecType = typename Pipeline::BVecType; + using CVecType = typename Pipeline::CVecType; + + Pipeline::exec(*reinterpret_cast(a), + *reinterpret_cast(b), + *reinterpret_cast(c)); + } +} + +TEST(WaveWiseMmaPipeline, testKIter) +{ + int devCount; + hipDevice_t dev; + HIP_CHECK_ERROR(hipGetDevice(&dev)); + HIP_CHECK_ERROR(hipGetDeviceCount(&devCount)); + + hipDeviceProp_t devProp; + HIP_CHECK_ERROR(hipGetDeviceProperties(&devProp, dev)); + + auto currentArchId = hip_device_prop_gcn_arch_name_to_amdgcn_target_id(devProp.gcnArchName); + bool hasDevice = static_cast(devCount > 0); + int deviceWarpSize = devProp.warpSize; + + bool isSupportedWmma = false; + bool isSupportedMfma = + (currentArchId >= amdgcn_target_id::GFX942) && (currentArchId <= amdgcn_target_id::GFX950); + // TODO: c++20 add check for arch id + if(!hasDevice || (currentArchId == amdgcn_target_id::HOST) || + !(isSupportedWmma || isSupportedMfma)) + { + GTEST_SKIP() << "No HIP device found. Skipping test."; + } + + using AType = fp16_t; + using BType = fp16_t; + using CType = fp32_t; + + // WaveTile size, also the expected fragment size (MmaTile) from the selector. 
+ // Note: Actual FragK might be slightly different due to hardware implementation, but the + // test_accum_over_k kernel will loop over the K dimension to ensure that the total K is + // correct. + static constexpr uint32_t WaveTileM = 16; + static constexpr uint32_t WaveTileN = 16; + static constexpr uint32_t WaveTileK = 32; + static constexpr uint32_t FragM = WaveTileM; + static constexpr uint32_t FragN = WaveTileN; + static constexpr uint32_t FragK = WaveTileK; + + // The number of elements per thread + uint32_t AElements = FragM * FragK / deviceWarpSize; + uint32_t BElements = FragN * FragK / deviceWarpSize; + uint32_t CElements = FragM * FragN / deviceWarpSize; + + uint32_t ASize = AElements * sizeof(AType); + uint32_t BSize = BElements * sizeof(BType); + uint32_t CSize = CElements * sizeof(CType); + + // Initialize A and B to all 1's, C to all 0's + std::vector h_a(AElements, static_cast(1)); + std::vector h_b(BElements, static_cast(1)); + std::vector h_c(CElements, static_cast(0)); + std::vector h_out(CElements, static_cast(0)); + + AType* d_a; + BType* d_b; + CType* d_c; + CType* d_out; + + HIP_CHECK_ERROR(hipMalloc(&d_a, ASize)); + HIP_CHECK_ERROR(hipMalloc(&d_b, BSize)); + HIP_CHECK_ERROR(hipMalloc(&d_c, CSize)); + HIP_CHECK_ERROR(hipMalloc(&d_out, CSize)); + + // Copy inputs to device + HIP_CHECK_ERROR(hipMemcpy(d_a, h_a.data(), ASize, hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemcpy(d_b, h_b.data(), BSize, hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemcpy(d_c, h_c.data(), CSize, hipMemcpyHostToDevice)); + + const auto wave_size = getDeviceWaveSize(); + test_pipeline<<<1, wave_size>>>(d_a, d_b, d_c); + HIP_CHECK_ERROR(hipDeviceSynchronize()); + + HIP_CHECK_ERROR(hipMemcpy(h_out.data(), d_out, CSize, hipMemcpyDeviceToHost)); + + // Output should be FragK for all elements, because the inputs are all 1's + for(size_t i = 0; i < CElements; ++i) + { + CType expected = static_cast(FragK); + + EXPECT_NEAR(h_out[i], expected, 1e-3); + } + + 
HIP_CHECK_ERROR(hipFree(d_a)); + HIP_CHECK_ERROR(hipFree(d_b)); + HIP_CHECK_ERROR(hipFree(d_c)); + HIP_CHECK_ERROR(hipFree(d_out)); +} From 892165371ba01cede971d622340d11b99e46a58c Mon Sep 17 00:00:00 2001 From: Chris Tsiaousis Date: Tue, 17 Mar 2026 08:57:21 +0000 Subject: [PATCH 05/16] Made sparse amdgcn structs depict the actual builtin signature (halved A vector size) Signed-off-by: Chris Tsiaousis --- .../ck_tile/core/arch/mma/amdgcn_mma.hpp | 20 ++++++++++--------- .../core/arch/mma/sparse/mfma/sparse_gfx9.hpp | 19 +++--------------- .../arch/mma/sparse/wmma/sparse_gfx12.hpp | 20 +++---------------- 3 files changed, 17 insertions(+), 42 deletions(-) diff --git a/projects/composablekernel/include/ck_tile/core/arch/mma/amdgcn_mma.hpp b/projects/composablekernel/include/ck_tile/core/arch/mma/amdgcn_mma.hpp index 63148faf99c9..e20bb9a81ba9 100644 --- a/projects/composablekernel/include/ck_tile/core/arch/mma/amdgcn_mma.hpp +++ b/projects/composablekernel/include/ck_tile/core/arch/mma/amdgcn_mma.hpp @@ -132,7 +132,8 @@ template + MmaOpFamily OpFamily_, + index_t kCompressionRatio_ = 1> struct amdgcn_mma_base { using OpType = OpType_; @@ -149,13 +150,14 @@ struct amdgcn_mma_base static constexpr index_t kK = FragK; // K = K2 * K1 * K0 // Layout constants, check description above. 
- static constexpr index_t kABKPerLane = kABKPerLane_; // K2 * K0 - static constexpr index_t kAKNumAccess = kAKNumAccess_; // K2 - static constexpr index_t kARepeat = kARepeat_; // RDNA3 repetition and MFMA block-hiding - static constexpr index_t kBKNumAccess = kBKNumAccess_; // K2 - static constexpr index_t kBRepeat = kBRepeat_; // RDNA3 repetition and MFMA block-hiding - static constexpr index_t kCMPerLane = kCMPerLane_; // M2 * M0 - static constexpr index_t kCMNumAccess = kCMNumAccess_; // M2 + static constexpr index_t kABKPerLane = kABKPerLane_; // K2 * K0 + static constexpr index_t kAKNumAccess = kAKNumAccess_; // K2 + static constexpr index_t kARepeat = kARepeat_; // RDNA3 repetition and MFMA block-hiding + static constexpr index_t kBKNumAccess = kBKNumAccess_; // K2 + static constexpr index_t kBRepeat = kBRepeat_; // RDNA3 repetition and MFMA block-hiding + static constexpr index_t kCMPerLane = kCMPerLane_; // M2 * M0 + static constexpr index_t kCMNumAccess = kCMNumAccess_; // M2 + static constexpr index_t kCompressionRatio = kCompressionRatio_; // Sparse intrinsics matrix A compression // Register types (derived) static constexpr index_t WaveSize = WaveSize_; @@ -163,7 +165,7 @@ struct amdgcn_mma_base static_assert((kN * kK * kBRepeat) % WaveSize == 0); static_assert((kM * kN) % WaveSize == 0); - using AVecType = ext_vector_t; + using AVecType = ext_vector_t; using BVecType = ext_vector_t; using CVecType = ext_vector_t; }; diff --git a/projects/composablekernel/include/ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp b/projects/composablekernel/include/ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp index 0941f5cbec0f..9f949c84959d 100644 --- a/projects/composablekernel/include/ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp +++ b/projects/composablekernel/include/ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp @@ -27,29 +27,16 @@ template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | struct amdgcn_mma> -:
amdgcn_mma_base +: amdgcn_mma_base // clang-format on { CK_TILE_DEVICE static auto - exec(AVecType& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType + exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) -> CVecType { - static constexpr index_t ABVecN = vector_traits::vector_size; - static constexpr index_t kCompressionRatio = 2; - static constexpr index_t CompressedSize = ABVecN / kCompressionRatio; - using AVecCompressed = ext_vector_t; - - static_assert(CompressedSize == 4); - // TODO: Compressing A on-the-fly should be OK for now, but we need to validate - // and evaluate changing this to a transform at a higher level. - // aVec not being const can cause problems when running multiple intrinsics. - const uint32_t idx = ck_tile::compress_a_impl(aVec); - - const AVecCompressed a_vec_pruned = {aVec[0], aVec[1], aVec[2], aVec[3]}; - using namespace sparse::detail; static constexpr BuiltinParams PARAMS = getBuiltinParams(); return {__builtin_amdgcn_smfmac_f32_16x16x32_f16( - a_vec_pruned, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; } }; diff --git a/projects/composablekernel/include/ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp b/projects/composablekernel/include/ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp index 7981fd91aa3d..23ac7d47cb8f 100644 --- a/projects/composablekernel/include/ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp +++ b/projects/composablekernel/include/ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp @@ -17,27 +17,13 @@ namespace ck_tile::core::arch::mma { template // clang-format off struct amdgcn_mma> -: amdgcn_mma_base +: amdgcn_mma_base // clang-format on { CK_TILE_DEVICE static auto - exec(AVecType& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType + exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) -> CVecType { - static constexpr 
index_t ABVecN = vector_traits::vector_size; - static constexpr index_t kCompressionRatio = 2; - static constexpr index_t CompressedSize = ABVecN / kCompressionRatio; - using AVecCompressed = ext_vector_t; - - static_assert(CompressedSize == 8); - // TODO: Compressing A on-the-fly should be OK for now, but we need to validate - // and evaluate changing this to a transform at a higher level. - // aVec not being const can cause problems when running multiple intrinsics. - const uint32_t idx = ck_tile::compress_a_impl(aVec); - - const AVecCompressed a_vec_pruned = { - aVec[0], aVec[1], aVec[2], aVec[3], aVec[4], aVec[5], aVec[6], aVec[7]}; - - return {__builtin_amdgcn_swmmac_f32_16x16x32_f16_w32(a_vec_pruned, bVec, cVec, idx)}; + return {__builtin_amdgcn_swmmac_f32_16x16x32_f16_w32(aVec, bVec, cVec, idx)}; } }; From dc0043392f26907d385923b24255f9dd5817a31b Mon Sep 17 00:00:00 2001 From: Chris Tsiaousis Date: Tue, 17 Mar 2026 08:58:27 +0000 Subject: [PATCH 06/16] Defined the sparse pipeline and used it in test Signed-off-by: Chris Tsiaousis --- .../ck_tile/core/arch/mma/mma_selector.hpp | 1 + .../ck_tile/core/arch/mma/sparse_mma.hpp | 106 ++++++++++++++++++ .../core/arch/mma/test_amdgcn_sparse_mma.cpp | 57 ++++++---- 3 files changed, 141 insertions(+), 23 deletions(-) create mode 100644 projects/composablekernel/include/ck_tile/core/arch/mma/sparse_mma.hpp diff --git a/projects/composablekernel/include/ck_tile/core/arch/mma/mma_selector.hpp b/projects/composablekernel/include/ck_tile/core/arch/mma/mma_selector.hpp index 208b90d273e3..740f0f3c337a 100644 --- a/projects/composablekernel/include/ck_tile/core/arch/mma/mma_selector.hpp +++ b/projects/composablekernel/include/ck_tile/core/arch/mma/mma_selector.hpp @@ -72,3 +72,4 @@ concept MmaSelectorI = requires(MmaSelector op) { // Include the implementations #include "wmma/wmma_selector.hpp" #include "mfma/mfma_selector.hpp" +#include "sparse/sparse_selector.hpp" diff --git 
a/projects/composablekernel/include/ck_tile/core/arch/mma/sparse_mma.hpp b/projects/composablekernel/include/ck_tile/core/arch/mma/sparse_mma.hpp new file mode 100644 index 000000000000..3cb23bde5323 --- /dev/null +++ b/projects/composablekernel/include/ck_tile/core/arch/mma/sparse_mma.hpp @@ -0,0 +1,106 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +#pragma once + +#include "ck_tile/core/arch/mma/mma_pipeline.hpp" +#include "ck_tile/core/arch/mma/mma_selector.hpp" +#include "ck_tile/core/arch/mma/mma_traits.hpp" +#include "ck_tile/core/arch/mma/sparse/sparse_transforms.hpp" +#include "ck_tile/core/numeric/vector_type.hpp" +#include +#include + +namespace ck_tile::core::arch::mma { + +template ::SelectedOp, + typename MmaTransforms = // TODO: c++20 MmaTransformsI MmaTransforms = + typename MmaTransformsDefaultSelector::SelectedTransforms> +// clang-format off +struct SparseMma : public MmaPipelineBase(MmaPipelineOptionFlag::COMPRESS_A), + SparseMma> +{ + static_assert(MmaOpTraits::IsSupported && MmaOpTraits::IsSparse); + using Base = MmaPipelineBase(MmaPipelineOptionFlag::COMPRESS_A), + SparseMma>; + // clang-format on + + // Calculate the uncompressed A vector type + struct InternalAVecCalculator + { + using AVecTraits = vector_traits; + static constexpr index_t ASize = AVecTraits::vector_size * MmaOp::kCompressionRatio; + using AVecType = ext_vector_t; + }; + + // Expose caller-side vector types + using AVecType = InternalAVecCalculator::AVecType; + using BVecType = typename MmaOp::BVecType; + using CVecType = typename MmaOp::CVecType; + + // Transforms + using ATransform = typename MmaTransforms::ATransform; + using BTransform = typename MmaTransforms::BTransform; + using CTransform = typename MmaTransforms::CTransform; + using DTransform = typename MmaTransforms::DTransform; + + template + CK_TILE_DEVICE static decltype(auto) preApply(VecTA&& a, VecTB&& b, VecTC&& accum) + { + static_assert(Flags == 
MmaPipelineOptionFlags(MmaPipelineOptionFlag::COMPRESS_A)); + static_assert( + std::is_same_v>); + + using InternalAVecT = typename MmaOp::AVecType; + using InternalBVecT = typename MmaOp::BVecType; + using InternalCVecT = typename MmaOp::CVecType; + + int32_t idx{}; + auto a_frag = Base::template preApplyTransform( + std::forward(a), idx); + auto b_frag = + Base::template preApplyTransform(std::forward(b)); + auto c_frag = + Base::template preApplyTransform(std::forward(accum)); + + return std::make_tuple( + std::move(a_frag), std::move(b_frag), std::move(c_frag), std::move(idx)); + } + + template + CK_TILE_DEVICE static decltype(auto) postApply(std::tuple&& vecs) + { + static_assert(Flags == MmaPipelineOptionFlags(MmaPipelineOptionFlag::COMPRESS_A)); + + auto& [a_frag, b_frag, c_frag, idx] = vecs; + // Convert native vector results back to the output fragment format + // and then return after we apply the final output transform. + return Base::template postApplyTransform, DTransform>(c_frag); + } + + template + CK_TILE_DEVICE static void execImpl(std::tuple& vecs) + { + auto& [a_frag, b_frag, c_frag, idx] = vecs; + c_frag = MmaOp::exec(a_frag, b_frag, c_frag, idx); + } +}; + +} // namespace ck_tile::core::arch::mma diff --git a/projects/composablekernel/test/ck_tile/core/arch/mma/test_amdgcn_sparse_mma.cpp b/projects/composablekernel/test/ck_tile/core/arch/mma/test_amdgcn_sparse_mma.cpp index 03abcb577238..cfd4001e0506 100644 --- a/projects/composablekernel/test/ck_tile/core/arch/mma/test_amdgcn_sparse_mma.cpp +++ b/projects/composablekernel/test/ck_tile/core/arch/mma/test_amdgcn_sparse_mma.cpp @@ -8,7 +8,9 @@ #include "ck_tile/core/arch/mma/amdgcn_mma.hpp" #include "ck_tile/core/arch/mma/mma_op_family.hpp" #include "ck_tile/core/arch/mma/mma_selector.hpp" +#include "ck_tile/core/arch/mma/sparse_mma.hpp" #include +#include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/host/hip_check_error.hpp" #include "ck_tile/core/arch/mma/mma_traits.hpp" #include 
"ck_tile/core/utility/type_traits.hpp" @@ -150,31 +152,40 @@ template ; - using MmaOp = typename Selector::SelectedOp; - using CVecType = typename MmaOp::CVecType; - - static constexpr uint32_t kIters = WaveTileK / MmaOp::kK; - - // Initialize the accumulator - CVecType result = *reinterpret_cast(c); - - // Accumulate input AxB over WaveTileK/FragK iterations - for(uint32_t i = 0; i < kIters; ++i) + using MmaOp = typename MmaDefaultSelector::SelectedOp; + + using MmaTraits = MmaOpTraits; + + if constexpr(MmaTraits::IsSupported) { - result = MmaOp::exec(*reinterpret_cast(a), - *reinterpret_cast(b), - result); - } + using Pipeline = SparseMma; + + using AVecType = typename Pipeline::AVecType; + using BVecType = typename Pipeline::BVecType; + using CVecType = typename Pipeline::CVecType; + + static constexpr uint32_t kIters = WaveTileK / MmaOp::kK; + + // Initialize the accumulator + CVecType result = *reinterpret_cast(c); - *reinterpret_cast(out) = result; + // Accumulate input AxB over FragK/BlockK iterations + for(uint32_t i = 0; i < kIters; ++i) + { + result = Pipeline::exec( + *reinterpret_cast(a), *reinterpret_cast(b), result); + } + + *reinterpret_cast(out) = result; + } } // Live test on real hardware for sparse selection and execution. From 6839f3846e3913f7bcef4b8b13584549cc1147f6 Mon Sep 17 00:00:00 2001 From: Chris Tsiaousis Date: Tue, 17 Mar 2026 09:46:48 +0000 Subject: [PATCH 07/16] Made concept in line with the Pipeline interface We should, in the future remove the `::Type`. 
Signed-off-by: Chris Tsiaousis --- .../include/ck_tile/core/arch/mma/mma_pipeline.hpp | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/projects/composablekernel/include/ck_tile/core/arch/mma/mma_pipeline.hpp b/projects/composablekernel/include/ck_tile/core/arch/mma/mma_pipeline.hpp index 748b167503ca..125d5f77fea1 100644 --- a/projects/composablekernel/include/ck_tile/core/arch/mma/mma_pipeline.hpp +++ b/projects/composablekernel/include/ck_tile/core/arch/mma/mma_pipeline.hpp @@ -89,12 +89,6 @@ struct MmaPipelineBase "Flag not yet implemented"); private: - template - CK_TILE_DEVICE static bool hasFlag() - { - return Flags & Flag; - } - template CK_TILE_DEVICE static auto formatBuffer(SrcT&& inputBuffer) { @@ -111,6 +105,12 @@ struct MmaPipelineBase } protected: + template + CK_TILE_DEVICE static bool hasFlag() + { + return Flags & Flag; + } + template CK_TILE_DEVICE static auto preApplyTransform(Args&&... args) { @@ -143,7 +143,7 @@ struct MmaPipelineBase * @concept MmaPipelineI * @brief Expresses the meta-data interface required for a CRTP MmaPipeline. */ -template +template concept MmaPipelineInterface = std::derived_from>; #endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER From 28b3dd059517eaa50ee5ef9ad81de55a0b0031b1 Mon Sep 17 00:00:00 2001 From: Chris Tsiaousis Date: Tue, 17 Mar 2026 09:48:32 +0000 Subject: [PATCH 08/16] Expanded the MmaOpI concept to support an extra int arg to the exec function Also added a test for this. 
Signed-off-by: Chris Tsiaousis --- .../ck_tile/core/arch/mma/amdgcn_mma.hpp | 25 +++++++++++++------ .../core/arch/mma/test_amdgcn_sparse_mma.cpp | 18 +++++++++++++ 2 files changed, 35 insertions(+), 8 deletions(-) diff --git a/projects/composablekernel/include/ck_tile/core/arch/mma/amdgcn_mma.hpp b/projects/composablekernel/include/ck_tile/core/arch/mma/amdgcn_mma.hpp index e20bb9a81ba9..dbe515ca477c 100644 --- a/projects/composablekernel/include/ck_tile/core/arch/mma/amdgcn_mma.hpp +++ b/projects/composablekernel/include/ck_tile/core/arch/mma/amdgcn_mma.hpp @@ -179,6 +179,20 @@ struct Unsupported; #if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER #include +/** + * @concept HasExecSignature + * @brief Helper concept for exec signature check. + */ +template +concept HasExecSignature = requires { + { + MmaOp::exec(typename MmaOp::AVecType{}, + typename MmaOp::BVecType{}, + typename MmaOp::CVecType{}, + std::declval()...) + } -> std::convertible_to; +}; + /** * @concept MmaOpI * @brief Expresses the meta-data interface required for each MmaOp policy. 
@@ -187,7 +201,7 @@ template concept MmaOpI = requires(MmaOp op) { // Requires an op context typename MmaOp::OpType; - typename MmaOp::OpFamily; + { MmaOp::OpFamily } -> std::convertible_to; // Captures types for inputs / outputs to mma function typename MmaOp::ADataType; @@ -205,13 +219,8 @@ concept MmaOpI = requires(MmaOp op) { { MmaOp::kBRepeat } -> std::convertible_to; { MmaOp::kCMPerLane } -> std::convertible_to; { MmaOp::kCMNumAccess } -> std::convertible_to; - - // Static exec function - { - MmaOp::exec( - typename MmaOp::AVecType{}, typename MmaOp::BVecType{}, typename MmaOp::CVecType{}) - } -> std::convertible_to; -}; + { MmaOp::kCompressionRatio } -> std::convertible_to; +} && (HasExecSignature || HasExecSignature); #endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER diff --git a/projects/composablekernel/test/ck_tile/core/arch/mma/test_amdgcn_sparse_mma.cpp b/projects/composablekernel/test/ck_tile/core/arch/mma/test_amdgcn_sparse_mma.cpp index cfd4001e0506..4c7ea93d73b4 100644 --- a/projects/composablekernel/test/ck_tile/core/arch/mma/test_amdgcn_sparse_mma.cpp +++ b/projects/composablekernel/test/ck_tile/core/arch/mma/test_amdgcn_sparse_mma.cpp @@ -71,6 +71,24 @@ TEST(SparseMMATrait, MmaOpTraitsIntegration) std::cout << "MmaOpTraits correctly integrates sparse operations" << std::endl; } +TEST(SparseMMATrait, TestConceptRequirements) +{ +#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER + using TestSparseMmma = amdgcn_mma; + static_assert(MmaOpI); +#else + GTEST_SKIP() << "Not compiled with concepts. 
Skipping test."; +#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER +} + TEST(SparseMMATrait, DenseVsSparseDistinction) { // Dense MFMA from mfma/mfma_gfx9.hpp From 0e1cfb3309c2ba93efe41810b2c07e0f6eee1711 Mon Sep 17 00:00:00 2001 From: Chris Tsiaousis Date: Tue, 17 Mar 2026 14:48:23 +0000 Subject: [PATCH 09/16] Revert changes done to CK Tile's internal compress A function Signed-off-by: Chris Tsiaousis --- .../core/arch/mma/sparse/mfma/sparse_gfx9.hpp | 1 - .../arch/mma/sparse/wmma/sparse_gfx12.hpp | 1 - .../ops/gemm/warp/warp_gemm_smfmac_impl.hpp | 90 ++++++++----------- 3 files changed, 36 insertions(+), 56 deletions(-) diff --git a/projects/composablekernel/include/ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp b/projects/composablekernel/include/ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp index 9f949c84959d..41fa89b9266f 100644 --- a/projects/composablekernel/include/ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp +++ b/projects/composablekernel/include/ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp @@ -6,7 +6,6 @@ #include "ck_tile/core/config.hpp" #include "ck_tile/core/arch/arch.hpp" #include "ck_tile/core/numeric/vector_type.hpp" -#include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp" #include "ck_tile/core/arch/mma/sparse/sparse_traits.hpp" namespace ck_tile::core::arch::mma { diff --git a/projects/composablekernel/include/ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp b/projects/composablekernel/include/ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp index 23ac7d47cb8f..e9638fc94beb 100644 --- a/projects/composablekernel/include/ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp +++ b/projects/composablekernel/include/ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp @@ -7,7 +7,6 @@ #include "ck_tile/core/arch/arch.hpp" #include "ck_tile/core/arch/mma/amdgcn_mma.hpp" #include "ck_tile/core/numeric/vector_type.hpp" -#include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp" #include 
"ck_tile/core/arch/mma/sparse/sparse_traits.hpp" namespace ck_tile::core::arch::mma { diff --git a/projects/composablekernel/include/ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp b/projects/composablekernel/include/ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp index 0a184cfacfe0..b99fc91fa76d 100644 --- a/projects/composablekernel/include/ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp +++ b/projects/composablekernel/include/ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp @@ -5,51 +5,8 @@ #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/tensor/static_distributed_tensor.hpp" -namespace ck_tile { -/** - * @brief Compress A vector for 2:4 structured sparsity instruction by moving all non-zero - * elements into lower part of a_vec to half its effective size. - * @param a_vec Vector to be compressed. - * @tparam ADataType The data type of a_vec - * @tparam CompressedSize The target compression size - * @tparam AVec The vector type of a_vec (deduced) - * @return Packed 32‑bit word containing **CompressedSize** 2‑bit fields. - * Each field encodes the original position (0–3) of the corresponding - * non‑zero element in the input. If fewer than CompressedSize - * non‑zeros are found, remaining fields default to 2 (see below). - */ -template -static CK_TILE_DEVICE int32_t compress_a_impl(AVec& a_vec) -{ - // idx holds one 2‑bit index per output element (total CompressedSize entries). - // It is initialized to the pattern 0b10 for every field. This matches - // what the hardware expects when there are fewer than two non‑zero values - // in a 4‑element group – the unused output is treated as coming from slot 2. - // The loop below will clear and set each field as real non‑zeros are seen. 
- int32_t idx = 0; - static_for<0, CompressedSize, 1>{}([&](auto k) { idx |= (2 << (2 * k)); }); - - static_for<0, CompressedSize / 2, 1>{}([&](auto i) { - ADataType nonzero_elems[2] = {a_vec[i * 4 + 2], a_vec[i * 4 + 3]}; - int32_t non_zero_pos = 0; - - static_for<0, 3, 1>{}([&](auto j) { - if(a_vec[i * 4 + j] != 0.0f) - { - nonzero_elems[non_zero_pos] = a_vec[i * 4 + j]; - // clear the two‑bit field for this output and insert j - idx &= ~(0b11 << 2 * (i * 2 + non_zero_pos)); - idx |= j << 2 * (i * 2 + non_zero_pos); - ++non_zero_pos; - } - }); - a_vec[i * 2] = nonzero_elems[0]; - a_vec[i * 2 + 1] = nonzero_elems[1]; - }); - - return idx; -} +namespace ck_tile { template struct WarpGemmSmfmacImpl @@ -86,10 +43,37 @@ struct WarpGemmSmfmacImpl return WarpGemmAttribute_::get_num_of_access(); } - template - CK_TILE_DEVICE int32_t compress_a_vec(AVec& a_vec) + //---------------------------------------------------------------------------------------------- + /// @brief Compress A vector for 2:4 structured sparsity instruction by moving all non-zero + /// elements into lower part of a_vec to half its effective size. + /// + /// @param a_vec Vector to be compressed. 
+ /// + /// @return Four 2-bit indexes of non-zero elements locations + /// + template + CK_TILE_DEVICE int32_t compress_a(AVec& a_vec) const { - return compress_a_impl(a_vec); + int32_t idx = 0b11101110; + + static_for<0, 2, 1>{}([&](auto i) { + ADataType nonzero_elems[2] = {a_vec[i * 4 + 2], a_vec[i * 4 + 3]}; + int32_t non_zero_pos = 0; + + static_for<0, 3, 1>{}([&](auto j) { + if(a_vec[i * 4 + j] != 0.0f) + { + nonzero_elems[non_zero_pos] = a_vec[i * 4 + j]; + idx &= ~(0b11 << 2 * (i * 2 + non_zero_pos)); + idx |= j << 2 * (i * 2 + non_zero_pos); + ++non_zero_pos; + } + }); + a_vec[i * 2] = nonzero_elems[0]; + a_vec[i * 2 + 1] = nonzero_elems[1]; + }); + + return idx; } template @@ -102,11 +86,10 @@ struct WarpGemmSmfmacImpl constexpr auto CompressionRatio = WarpGemmAttribute::kCompressionRatio; using AVec = ext_vector_t; - static constexpr index_t CompressedSize = - ATensor::get_thread_buffer_size() / CompressionRatio; - using AVecCompressed = ext_vector_t; - using BVec = ext_vector_t; - using CVec = ext_vector_t; + using AVecCompressed = + ext_vector_t; + using BVec = ext_vector_t; + using CVec = ext_vector_t; constexpr auto I0 = number<0>{}; @@ -114,9 +97,8 @@ struct WarpGemmSmfmacImpl const auto b_vec = b.get_thread_buffer().template get_as()[I0]; auto c_vec = c.get_thread_buffer().template get_as()[I0]; - const int32_t idx = compress_a_vec(a_vec); + const int32_t idx = compress_a(a_vec); - static_assert(CompressedSize == 4); // @TODO can we simply set a_vec_pruned to a_vec[0:3]? 
const AVecCompressed a_vec_pruned = {a_vec[0], a_vec[1], a_vec[2], a_vec[3]}; From 1b144524ea063daa5b1bf369e82b40a7d6163dbc Mon Sep 17 00:00:00 2001 From: Chris Tsiaousis Date: Wed, 18 Mar 2026 11:54:29 +0000 Subject: [PATCH 10/16] Run clang-format after rebasing Signed-off-by: Chris Tsiaousis --- .../ck_tile/core/arch/mma/amdgcn_mma.hpp | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/projects/composablekernel/include/ck_tile/core/arch/mma/amdgcn_mma.hpp b/projects/composablekernel/include/ck_tile/core/arch/mma/amdgcn_mma.hpp index dbe515ca477c..797c5e23872b 100644 --- a/projects/composablekernel/include/ck_tile/core/arch/mma/amdgcn_mma.hpp +++ b/projects/composablekernel/include/ck_tile/core/arch/mma/amdgcn_mma.hpp @@ -150,14 +150,15 @@ struct amdgcn_mma_base static constexpr index_t kK = FragK; // K = K2 * K1 * K0 // Layout constants, check description above. - static constexpr index_t kABKPerLane = kABKPerLane_; // K2 * K0 - static constexpr index_t kAKNumAccess = kAKNumAccess_; // K2 - static constexpr index_t kARepeat = kARepeat_; // RDNA3 repetition and MFMA block-hiding - static constexpr index_t kBKNumAccess = kBKNumAccess_; // K2 - static constexpr index_t kBRepeat = kBRepeat_; // RDNA3 repetition and MFMA block-hiding - static constexpr index_t kCMPerLane = kCMPerLane_; // M2 * M0 - static constexpr index_t kCMNumAccess = kCMNumAccess_; // M2 - static constexpr index_t kCompressionRatio = kCompressionRatio_; // Sparse intrisics matrix A compression + static constexpr index_t kABKPerLane = kABKPerLane_; // K2 * K0 + static constexpr index_t kAKNumAccess = kAKNumAccess_; // K2 + static constexpr index_t kARepeat = kARepeat_; // RDNA3 repetition and MFMA block-hiding + static constexpr index_t kBKNumAccess = kBKNumAccess_; // K2 + static constexpr index_t kBRepeat = kBRepeat_; // RDNA3 repetition and MFMA block-hiding + static constexpr index_t kCMPerLane = kCMPerLane_; // M2 * M0 + static constexpr index_t kCMNumAccess = 
kCMNumAccess_; // M2 + static constexpr index_t kCompressionRatio = + kCompressionRatio_; // Sparse intrisics matrix A compression // Register types (derived) static constexpr index_t WaveSize = WaveSize_; From 80887051cfa4d7a775c91d9905fb1e0c26fd547f Mon Sep 17 00:00:00 2001 From: Chris Tsiaousis Date: Wed, 18 Mar 2026 16:01:14 +0000 Subject: [PATCH 11/16] Addressed some review comments Signed-off-by: Chris Tsiaousis --- .../composablekernel/include/ck_tile/core.hpp | 2 +- .../ck_tile/core/arch/mma/mma_pipeline.hpp | 21 ++++--------------- .../ck_tile/core/arch/mma/mma_wavewise.hpp | 6 +++++- .../sparse_mma_pipeline.hpp} | 0 .../arch/mma/sparse/sparse_transforms.hpp | 9 ++++---- .../core/arch/mma/test_amdgcn_sparse_mma.cpp | 2 +- 6 files changed, 15 insertions(+), 25 deletions(-) rename projects/composablekernel/include/ck_tile/core/arch/mma/{sparse_mma.hpp => sparse/sparse_mma_pipeline.hpp} (100%) diff --git a/projects/composablekernel/include/ck_tile/core.hpp b/projects/composablekernel/include/ck_tile/core.hpp index 1917e36d7f95..3414fc9ed380 100644 --- a/projects/composablekernel/include/ck_tile/core.hpp +++ b/projects/composablekernel/include/ck_tile/core.hpp @@ -28,12 +28,12 @@ #include "ck_tile/core/arch/mma/sparse/mfma/selector.hpp" #include "ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp" #include "ck_tile/core/arch/mma/sparse/sparse.hpp" +#include "ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp" #include "ck_tile/core/arch/mma/sparse/sparse_selector.hpp" #include "ck_tile/core/arch/mma/sparse/sparse_traits.hpp" #include "ck_tile/core/arch/mma/sparse/sparse_transforms.hpp" #include "ck_tile/core/arch/mma/sparse/wmma/selector.hpp" #include "ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp" -#include "ck_tile/core/arch/mma/sparse_mma.hpp" #include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_register_mapper.hpp" #include "ck_tile/core/arch/mma/wmma/wmma.hpp" #include "ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp" diff --git 
a/projects/composablekernel/include/ck_tile/core/arch/mma/mma_pipeline.hpp b/projects/composablekernel/include/ck_tile/core/arch/mma/mma_pipeline.hpp index 125d5f77fea1..4213ec49158b 100644 --- a/projects/composablekernel/include/ck_tile/core/arch/mma/mma_pipeline.hpp +++ b/projects/composablekernel/include/ck_tile/core/arch/mma/mma_pipeline.hpp @@ -8,21 +8,14 @@ #include "mma_selector.hpp" #include "mma_traits.hpp" #include "mma_transforms.hpp" -#include -#include -#include namespace ck_tile::core::arch::mma { enum struct MmaPipelineOptionFlag { - NONE = 0x0, - C_TRANSPOSE = 0x1, - SWIZZLE_A = 0x2, - SWIZZLE_B = 0x4, - DOUBLE_ATTR_NUM_ACCESS = 0x8, - QUAD_ATTR_NUM_ACCESS = 0x10, - COMPRESS_A = 0x20, + NONE = 0x0, + C_TRANSPOSE = 0x1, + COMPRESS_A = 0x2, }; struct MmaPipelineOptionFlags @@ -81,12 +74,6 @@ struct MmaPipelineBase static constexpr auto Flags = MmaPipelineOptionFlags(Flags_); // TODO: Implement those cases static_assert(!(Flags & MmaPipelineOptionFlag::C_TRANSPOSE), "Flag not yet implemented"); - static_assert(!(Flags & MmaPipelineOptionFlag::SWIZZLE_A), "Flag not yet implemented"); - static_assert(!(Flags & MmaPipelineOptionFlag::SWIZZLE_B), "Flag not yet implemented"); - static_assert(!(Flags & MmaPipelineOptionFlag::DOUBLE_ATTR_NUM_ACCESS), - "Flag not yet implemented"); - static_assert(!(Flags & MmaPipelineOptionFlag::QUAD_ATTR_NUM_ACCESS), - "Flag not yet implemented"); private: template @@ -106,7 +93,7 @@ struct MmaPipelineBase protected: template - CK_TILE_DEVICE static bool hasFlag() + constexpr CK_TILE_DEVICE static bool hasFlag() { return Flags & Flag; } diff --git a/projects/composablekernel/include/ck_tile/core/arch/mma/mma_wavewise.hpp b/projects/composablekernel/include/ck_tile/core/arch/mma/mma_wavewise.hpp index 897143a907f3..3a0b26dd674a 100644 --- a/projects/composablekernel/include/ck_tile/core/arch/mma/mma_wavewise.hpp +++ b/projects/composablekernel/include/ck_tile/core/arch/mma/mma_wavewise.hpp @@ -171,7 +171,7 @@ struct 
WaveWiseMma : public MmaPipelineBase(MmaPipelineOptionFl } } } - else + else if constexpr(AccumPolicy == MmaAccumPolicy::COL_MAJOR) { // "Col-major" accumulation over the M-dimension fragments first. // Pseudo code here, but we would basically iterate over the blocks in col-major order @@ -187,6 +187,10 @@ struct WaveWiseMma : public MmaPipelineBase(MmaPipelineOptionFl } } } + else + { + static_assert(false); + } } }; diff --git a/projects/composablekernel/include/ck_tile/core/arch/mma/sparse_mma.hpp b/projects/composablekernel/include/ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp similarity index 100% rename from projects/composablekernel/include/ck_tile/core/arch/mma/sparse_mma.hpp rename to projects/composablekernel/include/ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp diff --git a/projects/composablekernel/include/ck_tile/core/arch/mma/sparse/sparse_transforms.hpp b/projects/composablekernel/include/ck_tile/core/arch/mma/sparse/sparse_transforms.hpp index bf58b17c78ba..2e69af6b85f8 100644 --- a/projects/composablekernel/include/ck_tile/core/arch/mma/sparse/sparse_transforms.hpp +++ b/projects/composablekernel/include/ck_tile/core/arch/mma/sparse/sparse_transforms.hpp @@ -41,7 +41,7 @@ static CK_TILE_DEVICE int32_t compress_a_impl(AVec& a_vec) int32_t non_zero_pos = 0; static_for<0, 3, 1>{}([&](auto j) { - if(a_vec[i * 4 + j] != 0.0f) + if(static_cast(a_vec[i * 4 + j]) != 0.0f) { nonzero_elems[non_zero_pos] = a_vec[i * 4 + j]; // clear the two‑bit field for this output and insert j @@ -68,7 +68,7 @@ struct SparseCompressTransform template CK_TILE_DEVICE static decltype(auto) exec(VecType&& v, int32_t& idx) { - using VecTraits = vector_traits>; + using VecTraits = vector_traits>; using ScalarT = typename VecTraits::scalar_type; static constexpr auto VecN = VecTraits::vector_size; static constexpr index_t CompressedSize = VecN / CompressionRatio; @@ -76,9 +76,8 @@ struct SparseCompressTransform idx = detail::compress_a_impl(v); - VecCompressed 
result; - __builtin_memcpy(&result, &v, sizeof(VecCompressed)); - return result; + // TODO c++20: Use bit_cast + return *std::launder(reinterpret_cast(&v)); } }; diff --git a/projects/composablekernel/test/ck_tile/core/arch/mma/test_amdgcn_sparse_mma.cpp b/projects/composablekernel/test/ck_tile/core/arch/mma/test_amdgcn_sparse_mma.cpp index 4c7ea93d73b4..337a865369b8 100644 --- a/projects/composablekernel/test/ck_tile/core/arch/mma/test_amdgcn_sparse_mma.cpp +++ b/projects/composablekernel/test/ck_tile/core/arch/mma/test_amdgcn_sparse_mma.cpp @@ -8,7 +8,7 @@ #include "ck_tile/core/arch/mma/amdgcn_mma.hpp" #include "ck_tile/core/arch/mma/mma_op_family.hpp" #include "ck_tile/core/arch/mma/mma_selector.hpp" -#include "ck_tile/core/arch/mma/sparse_mma.hpp" +#include "ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp" #include #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/host/hip_check_error.hpp" From d1b02240350018a5767a1fd8668adb7e7fe31c3f Mon Sep 17 00:00:00 2001 From: Chris Tsiaousis Date: Wed, 18 Mar 2026 17:41:27 +0000 Subject: [PATCH 12/16] Added test for the sparse transform Signed-off-by: Chris Tsiaousis --- .../arch/mma/sparse/sparse_transforms.hpp | 10 ++- .../core/arch/mma/test_amdgcn_sparse_mma.cpp | 82 +++++++++++++++++++ 2 files changed, 89 insertions(+), 3 deletions(-) diff --git a/projects/composablekernel/include/ck_tile/core/arch/mma/sparse/sparse_transforms.hpp b/projects/composablekernel/include/ck_tile/core/arch/mma/sparse/sparse_transforms.hpp index 2e69af6b85f8..df0ba3c82dc9 100644 --- a/projects/composablekernel/include/ck_tile/core/arch/mma/sparse/sparse_transforms.hpp +++ b/projects/composablekernel/include/ck_tile/core/arch/mma/sparse/sparse_transforms.hpp @@ -74,6 +74,9 @@ struct SparseCompressTransform static constexpr index_t CompressedSize = VecN / CompressionRatio; using VecCompressed = ext_vector_t; + static_assert(VecN % CompressionRatio == 0, "VecN must be divisible by CompressionRatio"); + 
static_assert(CompressedSize > 0, "CompressedSize must be > 0"); + idx = detail::compress_a_impl(v); // TODO c++20: Use bit_cast @@ -82,7 +85,7 @@ struct SparseCompressTransform }; /** - * @struct MmaDefaultTransformsSparse + * @class MmaDefaultTransformsSparse * @brief Implements the default transforms for Sparse * * For 2:4 structured sparsity with inline register metadata: @@ -91,9 +94,10 @@ struct SparseCompressTransform * - CTransform: Pass-through (input accumulator) * - DTransform: Pass-through (output accumulator as-is) */ +template struct MmaDefaultTransformsSparse { - using ATransform = SparseCompressTransform<2>; + using ATransform = SparseCompressTransform; using BTransform = PassThroughTransform; using CTransform = PassThroughTransform; using DTransform = PassThroughTransform; @@ -114,7 +118,7 @@ struct MmaTransformsDefaultSelector> { - using SelectedTransforms = MmaDefaultTransformsSparse; + using SelectedTransforms = MmaDefaultTransformsSparse; }; } // namespace ck_tile::core::arch::mma diff --git a/projects/composablekernel/test/ck_tile/core/arch/mma/test_amdgcn_sparse_mma.cpp b/projects/composablekernel/test/ck_tile/core/arch/mma/test_amdgcn_sparse_mma.cpp index 337a865369b8..7e172ca2fe84 100644 --- a/projects/composablekernel/test/ck_tile/core/arch/mma/test_amdgcn_sparse_mma.cpp +++ b/projects/composablekernel/test/ck_tile/core/arch/mma/test_amdgcn_sparse_mma.cpp @@ -1,8 +1,10 @@ // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
// SPDX-License-Identifier: MIT +#include #include #include +#include #include "ck_tile/core/arch/arch.hpp" #include "ck_tile/core/arch/mma/amdgcn_mma.hpp" @@ -10,6 +12,9 @@ #include "ck_tile/core/arch/mma/mma_selector.hpp" #include "ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp" #include +#include "ck_tile/core/numeric/bfloat16.hpp" +#include "ck_tile/core/numeric/float8.hpp" +#include "ck_tile/core/numeric/half.hpp" #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/host/hip_check_error.hpp" #include "ck_tile/core/arch/mma/mma_traits.hpp" @@ -298,3 +303,80 @@ TEST(SparseMMATrait, MmaSelector_Sparse_F16_F16_F32_16x16x32_Real) HIP_CHECK_ERROR(hipFree(d_c)); HIP_CHECK_ERROR(hipFree(d_out)); } + +template +__global__ void test_sparse_transform(void* a, void* idx) +{ + using ResultT = decltype(SparseCompressTransform<2>::exec(*static_cast(a), + *reinterpret_cast(idx))); + *reinterpret_cast*>(a) = + SparseCompressTransform<2>::exec(*static_cast(a), *reinterpret_cast(idx)); +} + +// 1. Basic correctness: valid divisible sizes +template +void sparse_transform_test_case() +{ + int devCount; + hipDevice_t dev; + HIP_CHECK_ERROR(hipGetDevice(&dev)); + HIP_CHECK_ERROR(hipGetDeviceCount(&devCount)); + + hipDeviceProp_t devProp; + HIP_CHECK_ERROR(hipGetDeviceProperties(&devProp, dev)); + + auto currentArchId = hip_device_prop_gcn_arch_name_to_amdgcn_target_id(devProp.gcnArchName); + bool hasDevice = static_cast(devCount > 0); + + // TODO: c++20 add check for arch id + if(!hasDevice || (currentArchId == amdgcn_target_id::HOST)) + { + GTEST_SKIP() << "No HIP device found. Skipping test."; + } + + std::vector v(NUM); + for(int i = 0; i < NUM; ++i) + { + v[i] = i % 2 == 0 ? 
i + 1 : 0; + } + + float* d_v; + int32_t* d_idx; + + static constexpr auto Size = sizeof(Type) * NUM; + HIP_CHECK_ERROR(hipMalloc(&d_v, Size)); + HIP_CHECK_ERROR(hipMalloc(&d_idx, sizeof(int32_t))); + + // Copy inputs to device + HIP_CHECK_ERROR(hipMemcpy(d_v, v.data(), Size, hipMemcpyHostToDevice)); + + test_sparse_transform><<<1, 32>>>(d_v, d_idx); + HIP_CHECK_ERROR(hipDeviceSynchronize()); + + std::vector h_out(NUM / RATIO, static_cast(0)); + HIP_CHECK_ERROR(hipMemcpy(h_out.data(), d_v, Size / RATIO, hipMemcpyDeviceToHost)); + int32_t h_idx; + HIP_CHECK_ERROR(hipMemcpy(&h_idx, d_idx, sizeof(float), hipMemcpyDeviceToHost)); + + EXPECT_NE(h_idx, -1) << "idx should have been written"; + if constexpr(NUM == 8) + { + EXPECT_EQ(h_idx, 0b10001000); + } + else if constexpr(NUM == 16) + { + EXPECT_EQ(h_idx, 0b1000100010001000); + } + for(int i = 0; i < NUM / RATIO; ++i) + { + EXPECT_EQ(h_out[i], v[i * 2]); + } +} + +TEST(SparseTransformsTest, ValidCompressionRatio) +{ + // TODO: extend those when new sparse builtins are + // introduced and use different type combinations + sparse_transform_test_case<8, 2, fp16_t>(); + sparse_transform_test_case<16, 2, fp16_t>(); +} From 1240413178bb6bf97805d92a2780bd5b8a4bd28f Mon Sep 17 00:00:00 2001 From: Chris Tsiaousis Date: Thu, 19 Mar 2026 09:02:19 +0000 Subject: [PATCH 13/16] Added test for 'MmaPipelineOptionFlags' Signed-off-by: Chris Tsiaousis --- .../ck_tile/core/arch/mma/mma_pipeline.hpp | 12 +++- .../ck_tile/core/arch/mma/mma_wavewise.hpp | 4 +- .../arch/mma/sparse/sparse_mma_pipeline.hpp | 4 +- .../test/ck_tile/core/arch/mma/CMakeLists.txt | 3 + .../arch/mma/test_amdgcn_mma_pipeline.cpp | 63 +++++++++++++++++++ 5 files changed, 81 insertions(+), 5 deletions(-) create mode 100644 projects/composablekernel/test/ck_tile/core/arch/mma/test_amdgcn_mma_pipeline.cpp diff --git a/projects/composablekernel/include/ck_tile/core/arch/mma/mma_pipeline.hpp b/projects/composablekernel/include/ck_tile/core/arch/mma/mma_pipeline.hpp index 
4213ec49158b..e903d70b2ff6 100644 --- a/projects/composablekernel/include/ck_tile/core/arch/mma/mma_pipeline.hpp +++ b/projects/composablekernel/include/ck_tile/core/arch/mma/mma_pipeline.hpp @@ -60,13 +60,23 @@ struct MmaPipelineOptionFlags result.mFlags = ~result.mFlags; return result; } + constexpr bool testFlag(MmaPipelineOptionFlag flag) const + { + return (flag == MmaPipelineOptionFlag::NONE) ? mFlags == toType(flag) : *this & flag; + } constexpr operator bool() const { return mFlags != toType(MmaPipelineOptionFlag::NONE); } + constexpr bool operator==(Type rhs) const { return mFlags == rhs; } private: Type mFlags; static constexpr Type toType(MmaPipelineOptionFlag f) { return static_cast(f); } }; +constexpr bool operator==(MmaPipelineOptionFlags::Type lhs, const MmaPipelineOptionFlags& rhs) +{ + return rhs == lhs; +} + // TODO: c++20: use MmaPipelineOptionFlags directly template struct MmaPipelineBase @@ -95,7 +105,7 @@ struct MmaPipelineBase template constexpr CK_TILE_DEVICE static bool hasFlag() { - return Flags & Flag; + return Flags.testFlag(Flag); } template diff --git a/projects/composablekernel/include/ck_tile/core/arch/mma/mma_wavewise.hpp b/projects/composablekernel/include/ck_tile/core/arch/mma/mma_wavewise.hpp index 3a0b26dd674a..1dfa4e132418 100644 --- a/projects/composablekernel/include/ck_tile/core/arch/mma/mma_wavewise.hpp +++ b/projects/composablekernel/include/ck_tile/core/arch/mma/mma_wavewise.hpp @@ -74,10 +74,10 @@ template ::SelectedTransforms> // clang-format off -struct WaveWiseMma : public MmaPipelineBase(MmaPipelineOptionFlag::NONE), +struct WaveWiseMma : public MmaPipelineBase(MmaPipelineOptionFlag::NONE), // TODO: c++20: use MmaPipelineOptionFlags directly WaveWiseMma> { - using Base = MmaPipelineBase(MmaPipelineOptionFlag::NONE), + using Base = MmaPipelineBase(MmaPipelineOptionFlag::NONE), // TODO: c++20: use MmaPipelineOptionFlags directly WaveWiseMma>; // clang-format on using FragWiseMmaOp = MmaOp; diff --git 
a/projects/composablekernel/include/ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp b/projects/composablekernel/include/ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp index 3cb23bde5323..e54b5c3385bf 100644 --- a/projects/composablekernel/include/ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp +++ b/projects/composablekernel/include/ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp @@ -34,11 +34,11 @@ template ::SelectedTransforms> // clang-format off -struct SparseMma : public MmaPipelineBase(MmaPipelineOptionFlag::COMPRESS_A), +struct SparseMma : public MmaPipelineBase(MmaPipelineOptionFlag::COMPRESS_A), // TODO: c++20: use MmaPipelineOptionFlags directly SparseMma> { static_assert(MmaOpTraits::IsSupported && MmaOpTraits::IsSparse); - using Base = MmaPipelineBase(MmaPipelineOptionFlag::COMPRESS_A), + using Base = MmaPipelineBase(MmaPipelineOptionFlag::COMPRESS_A), // TODO: c++20: use MmaPipelineOptionFlags directly SparseMma>; // clang-format on diff --git a/projects/composablekernel/test/ck_tile/core/arch/mma/CMakeLists.txt b/projects/composablekernel/test/ck_tile/core/arch/mma/CMakeLists.txt index cd77589c4d86..fc21e16fbcdb 100644 --- a/projects/composablekernel/test/ck_tile/core/arch/mma/CMakeLists.txt +++ b/projects/composablekernel/test/ck_tile/core/arch/mma/CMakeLists.txt @@ -27,3 +27,6 @@ else() message(DEBUG "Skipping gfx9|gfx11|gfx12 mma layout validation tests for current target") endif() +add_gtest_executable(test_amdgcn_mma_pipeline test_amdgcn_mma_pipeline.cpp) +target_compile_options(test_amdgcn_mma_layout PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + diff --git a/projects/composablekernel/test/ck_tile/core/arch/mma/test_amdgcn_mma_pipeline.cpp b/projects/composablekernel/test/ck_tile/core/arch/mma/test_amdgcn_mma_pipeline.cpp new file mode 100644 index 000000000000..83c1d896efe5 --- /dev/null +++ b/projects/composablekernel/test/ck_tile/core/arch/mma/test_amdgcn_mma_pipeline.cpp @@ -0,0 +1,63 @@ +// Copyright (c) Advanced Micro 
Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include +#include +#include + +#include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/core/arch/mma/mma_pipeline.hpp" + +namespace { +using namespace ck_tile::core::arch::mma; +} + +TEST(MmaPipelineOptionFlagsTests, ConversionTests) +{ + MmaPipelineOptionFlags flags_0{}; + MmaPipelineOptionFlags flags_1{MmaPipelineOptionFlag::C_TRANSPOSE}; + MmaPipelineOptionFlags flags_2{MmaPipelineOptionFlag::COMPRESS_A}; + MmaPipelineOptionFlags flags_3{0b11}; + + EXPECT_EQ(flags_0, 0); + EXPECT_TRUE(flags_0.testFlag(MmaPipelineOptionFlag::NONE)); + + EXPECT_EQ(flags_1, 1); + EXPECT_TRUE(flags_1.testFlag(MmaPipelineOptionFlag::C_TRANSPOSE)); + + EXPECT_EQ(flags_2, 2); + EXPECT_TRUE(flags_2.testFlag(MmaPipelineOptionFlag::COMPRESS_A)); + + EXPECT_EQ(flags_3, 3); + EXPECT_TRUE(flags_3.testFlag(MmaPipelineOptionFlag::COMPRESS_A)); + EXPECT_TRUE(flags_3.testFlag(MmaPipelineOptionFlag::C_TRANSPOSE)); +} + +TEST(MmaPipelineOptionFlagsTests, OperatorsTests) +{ + MmaPipelineOptionFlags flags{}; + + EXPECT_TRUE(flags.testFlag(MmaPipelineOptionFlag::NONE)); + + flags |= MmaPipelineOptionFlag::C_TRANSPOSE; + + EXPECT_FALSE(flags.testFlag(MmaPipelineOptionFlag::NONE)); + EXPECT_TRUE(flags.testFlag(MmaPipelineOptionFlag::C_TRANSPOSE)); + + flags |= MmaPipelineOptionFlag::COMPRESS_A; + + EXPECT_FALSE(flags.testFlag(MmaPipelineOptionFlag::NONE)); + EXPECT_TRUE(flags.testFlag(MmaPipelineOptionFlag::C_TRANSPOSE)); + EXPECT_TRUE(flags.testFlag(MmaPipelineOptionFlag::COMPRESS_A)); + + flags &= MmaPipelineOptionFlag::COMPRESS_A; + + EXPECT_FALSE(flags.testFlag(MmaPipelineOptionFlag::NONE)); + EXPECT_FALSE(flags.testFlag(MmaPipelineOptionFlag::C_TRANSPOSE)); + EXPECT_TRUE(flags.testFlag(MmaPipelineOptionFlag::COMPRESS_A)); + + EXPECT_FALSE((~flags).testFlag(MmaPipelineOptionFlag::NONE)); + EXPECT_TRUE((~flags).testFlag(MmaPipelineOptionFlag::C_TRANSPOSE)); + 
EXPECT_FALSE((~flags).testFlag(MmaPipelineOptionFlag::COMPRESS_A)); +} From d8bc46693d0b53e7810d0293aa9005f8eb113706 Mon Sep 17 00:00:00 2001 From: Chris Tsiaousis Date: Thu, 19 Mar 2026 10:43:58 +0000 Subject: [PATCH 14/16] Deduce kCompressionRatio automatically from MmaOpFamily Signed-off-by: Chris Tsiaousis --- .../include/ck_tile/core/arch/mma/amdgcn_mma.hpp | 7 +++---- .../ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp | 2 +- .../ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp | 2 +- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/projects/composablekernel/include/ck_tile/core/arch/mma/amdgcn_mma.hpp b/projects/composablekernel/include/ck_tile/core/arch/mma/amdgcn_mma.hpp index 797c5e23872b..10117a79bbfd 100644 --- a/projects/composablekernel/include/ck_tile/core/arch/mma/amdgcn_mma.hpp +++ b/projects/composablekernel/include/ck_tile/core/arch/mma/amdgcn_mma.hpp @@ -132,8 +132,7 @@ template + MmaOpFamily OpFamily_> struct amdgcn_mma_base { using OpType = OpType_; @@ -157,8 +156,8 @@ struct amdgcn_mma_base static constexpr index_t kBRepeat = kBRepeat_; // RDNA3 repetition and MFMA block-hiding static constexpr index_t kCMPerLane = kCMPerLane_; // M2 * M0 static constexpr index_t kCMNumAccess = kCMNumAccess_; // M2 - static constexpr index_t kCompressionRatio = - kCompressionRatio_; // Sparse intrisics matrix A compression + // K-dimension compression ratio for A matrix, always 2 for sparse intrinsics + static constexpr index_t kCompressionRatio = (OpFamily == MmaOpFamily::SPARSE) ? 
2 : 1; // Register types (derived) static constexpr index_t WaveSize = WaveSize_; diff --git a/projects/composablekernel/include/ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp b/projects/composablekernel/include/ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp index 41fa89b9266f..781d496e5a81 100644 --- a/projects/composablekernel/include/ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp +++ b/projects/composablekernel/include/ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp @@ -26,7 +26,7 @@ template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | struct amdgcn_mma> -: amdgcn_mma_base +: amdgcn_mma_base // clang-format on { CK_TILE_DEVICE static auto diff --git a/projects/composablekernel/include/ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp b/projects/composablekernel/include/ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp index e9638fc94eb..0648a45b2914 100644 --- a/projects/composablekernel/include/ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp +++ b/projects/composablekernel/include/ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp @@ -16,7 +16,7 @@ namespace ck_tile::core::arch::mma { template // clang-format off struct amdgcn_mma> -: amdgcn_mma_base +: amdgcn_mma_base // clang-format on { CK_TILE_DEVICE static auto From 0283390e2d60d21cc8ae22b55ad1f36178781f53 Mon Sep 17 00:00:00 2001 From: Chris Tsiaousis Date: Thu, 19 Mar 2026 11:19:51 +0000 Subject: [PATCH 15/16] De-couple intertwined compile passes for host/device code Allow pipeline instantiations on the host side without having to constexpr conditionally use them within the kernel. Also utilize the warning added to the host/unsupported amdgcn struct.
Signed-off-by: Chris Tsiaousis --- .../ck_tile/core/arch/mma/mma_pipeline.hpp | 21 ++++++-- .../ck_tile/core/arch/mma/mma_transforms.hpp | 2 + .../arch/mma/sparse/sparse_mma_pipeline.hpp | 3 +- .../core/arch/mma/test_amdgcn_sparse_mma.cpp | 34 ++++++------- .../arch/mma/test_amdgcn_wavewise_mma.cpp | 50 +++++++------------ 5 files changed, 53 insertions(+), 57 deletions(-) diff --git a/projects/composablekernel/include/ck_tile/core/arch/mma/mma_pipeline.hpp b/projects/composablekernel/include/ck_tile/core/arch/mma/mma_pipeline.hpp index e903d70b2ff6..b45652b24be7 100644 --- a/projects/composablekernel/include/ck_tile/core/arch/mma/mma_pipeline.hpp +++ b/projects/composablekernel/include/ck_tile/core/arch/mma/mma_pipeline.hpp @@ -124,11 +124,22 @@ struct MmaPipelineBase template CK_TILE_DEVICE static decltype(auto) exec(VecTA&& a, VecTB&& b, VecTC&& accum) { - // TODO: c++20: Call template functions with MmaPipelineOptionFlags directly - auto pre = Derived::template preApply( - std::forward(a), std::forward(b), std::forward(accum)); - Derived::execImpl(pre); - return Derived::template postApply(std::move(pre)); + if constexpr(MmaOpTraits::IsSupported) + { + // TODO: c++20: Call template functions with MmaPipelineOptionFlags directly + auto pre = Derived::template preApply( + std::forward(a), std::forward(b), std::forward(accum)); + Derived::execImpl(pre); + return Derived::template postApply(std::move(pre)); + } + else + { + // Return the unsupported exec. This should print a runtime warning. (amdgcn_mma.hpp) + // Code should not reach here, but HOST/DEVICE compile passes are + // weirdly intertwined and instead of having constexpr in the calling + // site (tests) we do this. See also changes by this commit. 
+ return Derived::FragWiseMmaOp::exec({}, {}, {}); + } } }; diff --git a/projects/composablekernel/include/ck_tile/core/arch/mma/mma_transforms.hpp b/projects/composablekernel/include/ck_tile/core/arch/mma/mma_transforms.hpp index c41aa0ae1190..854254023320 100644 --- a/projects/composablekernel/include/ck_tile/core/arch/mma/mma_transforms.hpp +++ b/projects/composablekernel/include/ck_tile/core/arch/mma/mma_transforms.hpp @@ -42,6 +42,8 @@ template struct MmaTransformsDefaultSelector { using SelectedTransforms = MmaDefaultPassThroughTransforms; + static_assert(CompilerTarget::TARGET_ID == amdgcn_target_id::HOST, + "Device code should use another specialization."); }; #if CK_TILE_CONCEPTS diff --git a/projects/composablekernel/include/ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp b/projects/composablekernel/include/ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp index e54b5c3385bf..ff9bbefc4b01 100644 --- a/projects/composablekernel/include/ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp +++ b/projects/composablekernel/include/ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp @@ -37,11 +37,12 @@ template (MmaPipelineOptionFlag::COMPRESS_A), // TODO: c++20: use MmaPipelineOptionFlags directly SparseMma> { - static_assert(MmaOpTraits::IsSupported && MmaOpTraits::IsSparse); using Base = MmaPipelineBase(MmaPipelineOptionFlag::COMPRESS_A), // TODO: c++20: use MmaPipelineOptionFlags directly SparseMma>; // clang-format on + using FragWiseMmaOp = MmaOp; // Expose the selected MmaOp + // Calculate the uncompressed A vector type struct InternalAVecCalculator { diff --git a/projects/composablekernel/test/ck_tile/core/arch/mma/test_amdgcn_sparse_mma.cpp b/projects/composablekernel/test/ck_tile/core/arch/mma/test_amdgcn_sparse_mma.cpp index 7e172ca2fe84..9789b0468386 100644 --- a/projects/composablekernel/test/ck_tile/core/arch/mma/test_amdgcn_sparse_mma.cpp +++ b/projects/composablekernel/test/ck_tile/core/arch/mma/test_amdgcn_sparse_mma.cpp @@ -185,30 
+185,26 @@ __global__ void test_sparse_accum_over_k(void* a, void* b, void* c, void* out) CompilerTarget, MmaOpFamily::SPARSE>::SelectedOp; - using MmaTraits = MmaOpTraits; + using Pipeline = + SparseMma; - if constexpr(MmaTraits::IsSupported) - { - using Pipeline = SparseMma; - - using AVecType = typename Pipeline::AVecType; - using BVecType = typename Pipeline::BVecType; - using CVecType = typename Pipeline::CVecType; + using AVecType = typename Pipeline::AVecType; + using BVecType = typename Pipeline::BVecType; + using CVecType = typename Pipeline::CVecType; - static constexpr uint32_t kIters = WaveTileK / MmaOp::kK; + static constexpr uint32_t kIters = WaveTileK / MmaOp::kK; - // Initialize the accumulator - CVecType result = *reinterpret_cast(c); + // Initialize the accumulator + CVecType result = *reinterpret_cast(c); - // Accumulate input AxB over FragK/BlockK iterations - for(uint32_t i = 0; i < kIters; ++i) - { - result = Pipeline::exec( - *reinterpret_cast(a), *reinterpret_cast(b), result); - } - - *reinterpret_cast(out) = result; + // Accumulate input AxB over FragK/BlockK iterations + for(uint32_t i = 0; i < kIters; ++i) + { + result = Pipeline::exec( + *reinterpret_cast(a), *reinterpret_cast(b), result); } + + *reinterpret_cast(out) = result; } // Live test on real hardware for sparse selection and execution. 
diff --git a/projects/composablekernel/test/ck_tile/core/arch/mma/test_amdgcn_wavewise_mma.cpp b/projects/composablekernel/test/ck_tile/core/arch/mma/test_amdgcn_wavewise_mma.cpp index 3f05552b415f..5a4b9e2ab82d 100644 --- a/projects/composablekernel/test/ck_tile/core/arch/mma/test_amdgcn_wavewise_mma.cpp +++ b/projects/composablekernel/test/ck_tile/core/arch/mma/test_amdgcn_wavewise_mma.cpp @@ -23,38 +23,24 @@ template ::SelectedOp; - - using MmaTraits = MmaOpTraits; - - if constexpr(MmaTraits::IsSupported) - { - using Pipeline = WaveWiseMma; - - using AVecType = typename Pipeline::AVecType; - using BVecType = typename Pipeline::BVecType; - using CVecType = typename Pipeline::CVecType; - - Pipeline::exec(*reinterpret_cast(a), - *reinterpret_cast(b), - *reinterpret_cast(c)); - } + + using Pipeline = WaveWiseMma; + + using AVecType = typename Pipeline::AVecType; + using BVecType = typename Pipeline::BVecType; + using CVecType = typename Pipeline::CVecType; + + Pipeline::exec(*reinterpret_cast(a), + *reinterpret_cast(b), + *reinterpret_cast(c)); } TEST(WaveWiseMmaPipeline, testKIter) From 4c8c2625e9014a75edbacb744030fae9d6ce8ec2 Mon Sep 17 00:00:00 2001 From: Chris Tsiaousis Date: Thu, 19 Mar 2026 15:00:55 +0000 Subject: [PATCH 16/16] Reorganise pipeline test code, re-use boilerplate and add TransposeC support Signed-off-by: Chris Tsiaousis --- .../composablekernel/include/ck_tile/core.hpp | 1 + .../ck_tile/core/arch/mma/mma_pipeline.hpp | 8 +- .../arch/mma/sparse/sparse_mma_pipeline.hpp | 15 +- .../test/ck_tile/core/arch/mma/CMakeLists.txt | 6 +- .../mma/pipeline/pipeline_tests_helper.hpp | 107 ++++++++++++++ .../test_amdgcn_mma_pipeline.cpp | 0 .../{ => pipeline}/test_amdgcn_sparse_mma.cpp | 129 ++++------------- .../mma/pipeline/test_amdgcn_wavewise_mma.cpp | 88 ++++++++++++ .../arch/mma/test_amdgcn_wavewise_mma.cpp | 133 ------------------ 9 files changed, 239 insertions(+), 248 deletions(-) create mode 100644 
projects/composablekernel/test/ck_tile/core/arch/mma/pipeline/pipeline_tests_helper.hpp rename projects/composablekernel/test/ck_tile/core/arch/mma/{ => pipeline}/test_amdgcn_mma_pipeline.cpp (100%) rename projects/composablekernel/test/ck_tile/core/arch/mma/{ => pipeline}/test_amdgcn_sparse_mma.cpp (70%) create mode 100644 projects/composablekernel/test/ck_tile/core/arch/mma/pipeline/test_amdgcn_wavewise_mma.cpp delete mode 100644 projects/composablekernel/test/ck_tile/core/arch/mma/test_amdgcn_wavewise_mma.cpp diff --git a/projects/composablekernel/include/ck_tile/core.hpp b/projects/composablekernel/include/ck_tile/core.hpp index 3414fc9ed380..7367da30727c 100644 --- a/projects/composablekernel/include/ck_tile/core.hpp +++ b/projects/composablekernel/include/ck_tile/core.hpp @@ -57,6 +57,7 @@ #include "ck_tile/core/container/tuple.hpp" #include "ck_tile/core/numeric/bfloat16.hpp" #include "ck_tile/core/numeric/e8m0.hpp" +#include "ck_tile/core/numeric/ext_vector_base.hpp" #include "ck_tile/core/numeric/float8.hpp" #include "ck_tile/core/numeric/half.hpp" #include "ck_tile/core/numeric/int8.hpp" diff --git a/projects/composablekernel/include/ck_tile/core/arch/mma/mma_pipeline.hpp b/projects/composablekernel/include/ck_tile/core/arch/mma/mma_pipeline.hpp index b45652b24be7..cb8ddf49566f 100644 --- a/projects/composablekernel/include/ck_tile/core/arch/mma/mma_pipeline.hpp +++ b/projects/composablekernel/include/ck_tile/core/arch/mma/mma_pipeline.hpp @@ -82,8 +82,6 @@ template struct MmaPipelineBase { static constexpr auto Flags = MmaPipelineOptionFlags(Flags_); - // TODO: Implement those cases - static_assert(!(Flags & MmaPipelineOptionFlag::C_TRANSPOSE), "Flag not yet implemented"); private: template @@ -128,7 +126,11 @@ struct MmaPipelineBase { // TODO: c++20: Call template functions with MmaPipelineOptionFlags directly auto pre = Derived::template preApply( - std::forward(a), std::forward(b), std::forward(accum)); + hasFlag() ? 
std::forward(b) + : std::forward(a), + hasFlag() ? std::forward(a) + : std::forward(b), + std::forward(accum)); Derived::execImpl(pre); return Derived::template postApply(std::move(pre)); } diff --git a/projects/composablekernel/include/ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp b/projects/composablekernel/include/ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp index ff9bbefc4b01..0d1e8490358f 100644 --- a/projects/composablekernel/include/ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp +++ b/projects/composablekernel/include/ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp @@ -12,6 +12,11 @@ namespace ck_tile::core::arch::mma { +namespace sparse::detail { +// TODO: c++20: return MmaPipelineOptionFlags directly] +constexpr inline int getFlags() { return static_cast(MmaPipelineOptionFlag::COMPRESS_A); } +} // namespace sparse::detail + template ::SelectedTransforms> // clang-format off -struct SparseMma : public MmaPipelineBase(MmaPipelineOptionFlag::COMPRESS_A), // TODO: c++20: use MmaPipelineOptionFlags directly - SparseMma> +struct SparseMma : public MmaPipelineBase> { - using Base = MmaPipelineBase(MmaPipelineOptionFlag::COMPRESS_A), // TODO: c++20: use MmaPipelineOptionFlags directly - SparseMma>; + using Base = MmaPipelineBase>; // clang-format on using FragWiseMmaOp = MmaOp; // Expose the selected MmaOp @@ -65,7 +68,7 @@ struct SparseMma : public MmaPipelineBase(MmaPipelineOptionFlag template CK_TILE_DEVICE static decltype(auto) preApply(VecTA&& a, VecTB&& b, VecTC&& accum) { - static_assert(Flags == MmaPipelineOptionFlags(MmaPipelineOptionFlag::COMPRESS_A)); + static_assert(MmaPipelineOptionFlags(Flags).testFlag(MmaPipelineOptionFlag::COMPRESS_A)); static_assert( std::is_same_v>); @@ -88,7 +91,7 @@ struct SparseMma : public MmaPipelineBase(MmaPipelineOptionFlag template CK_TILE_DEVICE static decltype(auto) postApply(std::tuple&& vecs) { - static_assert(Flags == MmaPipelineOptionFlags(MmaPipelineOptionFlag::COMPRESS_A)); + 
static_assert(MmaPipelineOptionFlags(Flags).testFlag(MmaPipelineOptionFlag::COMPRESS_A)); auto& [a_frag, b_frag, c_frag, idx] = vecs; // Convert native vector results back to the output fragment format diff --git a/projects/composablekernel/test/ck_tile/core/arch/mma/CMakeLists.txt b/projects/composablekernel/test/ck_tile/core/arch/mma/CMakeLists.txt index fc21e16fbcdb..d31c8e972b7f 100644 --- a/projects/composablekernel/test/ck_tile/core/arch/mma/CMakeLists.txt +++ b/projects/composablekernel/test/ck_tile/core/arch/mma/CMakeLists.txt @@ -8,13 +8,13 @@ if(CK_USE_OCP_FP8) endif() if(GPU_TARGETS MATCHES "gfx9|gfx12") - add_gtest_executable(test_amdgcn_sparse_mma test_amdgcn_sparse_mma.cpp) + add_gtest_executable(test_amdgcn_sparse_mma pipeline/test_amdgcn_sparse_mma.cpp) target_compile_options(test_amdgcn_sparse_mma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) endif() if(GPU_TARGETS MATCHES "gfx9") add_gtest_executable(test_amdgcn_mma test_amdgcn_mma.cpp) target_compile_options(test_amdgcn_mma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) - add_gtest_executable(test_amdgcn_wavewise_mma test_amdgcn_wavewise_mma.cpp) + add_gtest_executable(test_amdgcn_wavewise_mma pipeline/test_amdgcn_wavewise_mma.cpp) target_compile_options(test_amdgcn_wavewise_mma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) else() message(DEBUG "Skipping ck_tile_gemm tests for current target") @@ -27,6 +27,6 @@ else() message(DEBUG "Skipping gfx9|gfx11|gfx12 mma layout validation tests for current target") endif() -add_gtest_executable(test_amdgcn_mma_pipeline test_amdgcn_mma_pipeline.cpp) +add_gtest_executable(test_amdgcn_mma_pipeline pipeline/test_amdgcn_mma_pipeline.cpp) target_compile_options(test_amdgcn_mma_layout PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) diff --git a/projects/composablekernel/test/ck_tile/core/arch/mma/pipeline/pipeline_tests_helper.hpp b/projects/composablekernel/test/ck_tile/core/arch/mma/pipeline/pipeline_tests_helper.hpp new file mode 100644 index 000000000000..74a5a9d58734 --- 
/dev/null +++ b/projects/composablekernel/test/ck_tile/core/arch/mma/pipeline/pipeline_tests_helper.hpp @@ -0,0 +1,107 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include + +#include "ck_tile/core/arch/arch.hpp" +#include +#include "ck_tile/host/hip_check_error.hpp" + +#include "../get_wave_size_helper.hpp" + +template +struct MmaPipelineTest +{ + using AType = AType_; + using BType = BType_; + using CType = CType_; + static constexpr auto WaveTileM = WaveTileM_; + static constexpr auto WaveTileN = WaveTileN_; + static constexpr auto WaveTileK = WaveTileK_; + + void test_pipeline(std::function shouldSkip, + std::function kernel, + std::function getExpected) + { + using namespace ck_tile; + using namespace ck_tile::core::arch; + + int devCount; + hipDevice_t dev; + HIP_CHECK_ERROR(hipGetDevice(&dev)); + HIP_CHECK_ERROR(hipGetDeviceCount(&devCount)); + + hipDeviceProp_t devProp; + HIP_CHECK_ERROR(hipGetDeviceProperties(&devProp, dev)); + + auto currentArchId = hip_device_prop_gcn_arch_name_to_amdgcn_target_id(devProp.gcnArchName); + bool hasDevice = static_cast(devCount > 0); + int deviceWarpSize = devProp.warpSize; + + if(!hasDevice || shouldSkip(currentArchId)) + { + GTEST_SKIP() << "No HIP device found. Skipping test."; + } + + // WaveTile size, also the expected fragment size (MmaTile) from the selector. + // Note: Actual FragK might be slightly different due to hardware implementation, but the + // test_accum_over_k kernel will loop over the K dimension to ensure that the total K is + // correct. 
+ static constexpr uint32_t FragM = WaveTileM; + static constexpr uint32_t FragN = WaveTileN; + static constexpr uint32_t FragK = WaveTileK; + + // The number of elements per thread + uint32_t AElements = FragM * FragK / deviceWarpSize; + uint32_t BElements = FragN * FragK / deviceWarpSize; + uint32_t CElements = FragM * FragN / deviceWarpSize; + + uint32_t ASize = AElements * sizeof(AType); + uint32_t BSize = BElements * sizeof(BType); + uint32_t CSize = CElements * sizeof(CType); + + // Initialize A and B to all 1's, C to all 0's + std::vector h_a(AElements, static_cast(1)); + std::vector h_b(BElements, static_cast(1)); + std::vector h_c(CElements, static_cast(0)); + std::vector h_out(CElements, static_cast(0)); + + AType* d_a; + BType* d_b; + CType* d_c; + CType* d_out; + + HIP_CHECK_ERROR(hipMalloc(&d_a, ASize)); + HIP_CHECK_ERROR(hipMalloc(&d_b, BSize)); + HIP_CHECK_ERROR(hipMalloc(&d_c, CSize)); + HIP_CHECK_ERROR(hipMalloc(&d_out, CSize)); + + // Copy inputs to device + HIP_CHECK_ERROR(hipMemcpy(d_a, h_a.data(), ASize, hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemcpy(d_b, h_b.data(), BSize, hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemcpy(d_c, h_c.data(), CSize, hipMemcpyHostToDevice)); + + const auto wave_size = getDeviceWaveSize(); + kernel(wave_size, d_a, d_b, d_c, d_out); + HIP_CHECK_ERROR(hipDeviceSynchronize()); + + HIP_CHECK_ERROR(hipMemcpy(h_out.data(), d_out, CSize, hipMemcpyDeviceToHost)); + + // Output should be FragK for all elements, because the inputs are all 1's + for(size_t i = 0; i < CElements; ++i) + { + EXPECT_NEAR(h_out[i], getExpected(FragK), 1e-3); + } + + HIP_CHECK_ERROR(hipFree(d_a)); + HIP_CHECK_ERROR(hipFree(d_b)); + HIP_CHECK_ERROR(hipFree(d_c)); + HIP_CHECK_ERROR(hipFree(d_out)); + } +}; diff --git a/projects/composablekernel/test/ck_tile/core/arch/mma/test_amdgcn_mma_pipeline.cpp b/projects/composablekernel/test/ck_tile/core/arch/mma/pipeline/test_amdgcn_mma_pipeline.cpp similarity index 100% rename from 
projects/composablekernel/test/ck_tile/core/arch/mma/test_amdgcn_mma_pipeline.cpp rename to projects/composablekernel/test/ck_tile/core/arch/mma/pipeline/test_amdgcn_mma_pipeline.cpp diff --git a/projects/composablekernel/test/ck_tile/core/arch/mma/test_amdgcn_sparse_mma.cpp b/projects/composablekernel/test/ck_tile/core/arch/mma/pipeline/test_amdgcn_sparse_mma.cpp similarity index 70% rename from projects/composablekernel/test/ck_tile/core/arch/mma/test_amdgcn_sparse_mma.cpp rename to projects/composablekernel/test/ck_tile/core/arch/mma/pipeline/test_amdgcn_sparse_mma.cpp index 9789b0468386..f1c998299ee8 100644 --- a/projects/composablekernel/test/ck_tile/core/arch/mma/test_amdgcn_sparse_mma.cpp +++ b/projects/composablekernel/test/ck_tile/core/arch/mma/pipeline/test_amdgcn_sparse_mma.cpp @@ -20,7 +20,7 @@ #include "ck_tile/core/arch/mma/mma_traits.hpp" #include "ck_tile/core/utility/type_traits.hpp" -#include "get_wave_size_helper.hpp" +#include "pipeline_tests_helper.hpp" using namespace ck_tile; using namespace ck_tile::core::arch; @@ -174,25 +174,13 @@ template __global__ void test_sparse_accum_over_k(void* a, void* b, void* c, void* out) { - using CompilerTarget = decltype(get_compiler_target()); - using MmaOp = typename MmaDefaultSelector::SelectedOp; - - using Pipeline = - SparseMma; + using Pipeline = SparseMma; using AVecType = typename Pipeline::AVecType; using BVecType = typename Pipeline::BVecType; using CVecType = typename Pipeline::CVecType; - static constexpr uint32_t kIters = WaveTileK / MmaOp::kK; + static constexpr uint32_t kIters = WaveTileK / Pipeline::FragWiseMmaOp::kK; // Initialize the accumulator CVecType result = *reinterpret_cast(c); @@ -210,94 +198,26 @@ __global__ void test_sparse_accum_over_k(void* a, void* b, void* c, void* out) // Live test on real hardware for sparse selection and execution. 
TEST(SparseMMATrait, MmaSelector_Sparse_F16_F16_F32_16x16x32_Real) { - int devCount; - hipDevice_t dev; - HIP_CHECK_ERROR(hipGetDevice(&dev)); - HIP_CHECK_ERROR(hipGetDeviceCount(&devCount)); - - hipDeviceProp_t devProp; - HIP_CHECK_ERROR(hipGetDeviceProperties(&devProp, dev)); - - auto currentArchId = hip_device_prop_gcn_arch_name_to_amdgcn_target_id(devProp.gcnArchName); - bool hasDevice = static_cast(devCount > 0); - int deviceWarpSize = devProp.warpSize; - - bool isSupportedWmma = (currentArchId >= amdgcn_target_id::GFX1200) && - (currentArchId <= amdgcn_target_id::GFX12_GENERIC); - bool isSupportedMfma = - (currentArchId >= amdgcn_target_id::GFX942) && (currentArchId <= amdgcn_target_id::GFX950); - // TODO: c++20 add check for arch id - if(!hasDevice || (currentArchId == amdgcn_target_id::HOST) || - !(isSupportedWmma || isSupportedMfma)) - { - GTEST_SKIP() << "No HIP device found. Skipping test."; - } - - using AType = fp16_t; - using BType = fp16_t; - using CType = fp32_t; - - // WaveTile size, also the expected fragment size (MmaTile) from the selector. - // Note: Actual FragK might be slightly different due to hardware implementation, but the - // test_accum_over_k kernel will loop over the K dimension to ensure that the total K is - // correct. 
- static constexpr uint32_t WaveTileM = 16; - static constexpr uint32_t WaveTileN = 16; - static constexpr uint32_t WaveTileK = 32; - static constexpr uint32_t FragM = WaveTileM; - static constexpr uint32_t FragN = WaveTileN; - static constexpr uint32_t FragK = WaveTileK; - - // The number of elements per thread - uint32_t AElements = FragM * FragK / deviceWarpSize; - uint32_t BElements = FragN * FragK / deviceWarpSize; - uint32_t CElements = FragM * FragN / deviceWarpSize; - - uint32_t ASize = AElements * sizeof(AType); - uint32_t BSize = BElements * sizeof(BType); - uint32_t CSize = CElements * sizeof(CType); - - // Initialize A and B to all 1's, C to all 0's - std::vector h_a(AElements, static_cast(1)); - std::vector h_b(BElements, static_cast(1)); - std::vector h_c(CElements, static_cast(0)); - std::vector h_out(CElements, static_cast(0)); - - AType* d_a; - BType* d_b; - CType* d_c; - CType* d_out; - - HIP_CHECK_ERROR(hipMalloc(&d_a, ASize)); - HIP_CHECK_ERROR(hipMalloc(&d_b, BSize)); - HIP_CHECK_ERROR(hipMalloc(&d_c, CSize)); - HIP_CHECK_ERROR(hipMalloc(&d_out, CSize)); - - // Copy inputs to device - HIP_CHECK_ERROR(hipMemcpy(d_a, h_a.data(), ASize, hipMemcpyHostToDevice)); - HIP_CHECK_ERROR(hipMemcpy(d_b, h_b.data(), BSize, hipMemcpyHostToDevice)); - HIP_CHECK_ERROR(hipMemcpy(d_c, h_c.data(), CSize, hipMemcpyHostToDevice)); - - const auto wave_size = getDeviceWaveSize(); - test_sparse_accum_over_k - <<<1, wave_size>>>(d_a, d_b, d_c, d_out); - HIP_CHECK_ERROR(hipDeviceSynchronize()); - - HIP_CHECK_ERROR(hipMemcpy(h_out.data(), d_out, CSize, hipMemcpyDeviceToHost)); - - // Output should be FragK for all elements, because the inputs are all 1's - for(size_t i = 0; i < CElements; ++i) - { - // In sparse only half of the A values are non-zero, thus the /2. 
- CType expected = static_cast(FragK) / 2; - - EXPECT_NEAR(h_out[i], expected, 1e-3); - } - - HIP_CHECK_ERROR(hipFree(d_a)); - HIP_CHECK_ERROR(hipFree(d_b)); - HIP_CHECK_ERROR(hipFree(d_c)); - HIP_CHECK_ERROR(hipFree(d_out)); + MmaPipelineTest<> test; + const auto should_skip = [](amdgcn_target_id currentArchId) { + bool isSupportedWmma = (currentArchId >= amdgcn_target_id::GFX1200) && + (currentArchId <= amdgcn_target_id::GFX12_GENERIC); + bool isSupportedMfma = (currentArchId >= amdgcn_target_id::GFX942) && + (currentArchId <= amdgcn_target_id::GFX950); + return ((currentArchId == amdgcn_target_id::HOST) || !(isSupportedWmma || isSupportedMfma)); + }; + const std::function validator = [](uint32_t fragK) { + return static_cast(fragK) / 2; + }; + const auto kernel = [](uint32_t waveSize, void* a, void* b, void* c, void* out) { + test_sparse_accum_over_k::AType, + MmaPipelineTest<>::BType, + MmaPipelineTest<>::CType, + MmaPipelineTest<>::WaveTileM, + MmaPipelineTest<>::WaveTileN, + MmaPipelineTest<>::WaveTileK><<<1, waveSize>>>(a, b, c, out); + }; + test.test_pipeline(should_skip, kernel, validator); } template @@ -367,6 +287,9 @@ void sparse_transform_test_case() { EXPECT_EQ(h_out[i], v[i * 2]); } + + HIP_CHECK_ERROR(hipFree(d_v)); + HIP_CHECK_ERROR(hipFree(d_idx)); } TEST(SparseTransformsTest, ValidCompressionRatio) diff --git a/projects/composablekernel/test/ck_tile/core/arch/mma/pipeline/test_amdgcn_wavewise_mma.cpp b/projects/composablekernel/test/ck_tile/core/arch/mma/pipeline/test_amdgcn_wavewise_mma.cpp new file mode 100644 index 000000000000..0b7badb1baaa --- /dev/null +++ b/projects/composablekernel/test/ck_tile/core/arch/mma/pipeline/test_amdgcn_wavewise_mma.cpp @@ -0,0 +1,88 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/core/arch/mma/mma_op_family.hpp" +#include "ck_tile/core/arch/mma/mma_wavewise.hpp" + +#include "pipeline_tests_helper.hpp" + +using namespace ck_tile; +using namespace ck_tile::core::arch; +using namespace ck_tile::core::arch::mma; + +template +__global__ void test_wavewise_pipeline(void* a, void* b, void* c, void* out) +{ + using CompilerTarget = decltype(get_compiler_target()); + + using Pipeline = WaveWiseMma; + + using AVecType = typename Pipeline::AVecType; + using BVecType = typename Pipeline::BVecType; + using CVecType = typename Pipeline::CVecType; + + auto result = + Pipeline::exec(*reinterpret_cast(a), + *reinterpret_cast(b), + *reinterpret_cast(c)); + + // *reinterpret_cast(out) = result; + memcpy(out, &result, sizeof(result)); +} + +namespace { +const auto should_skip = [](amdgcn_target_id currentArchId) { + bool isSupportedWmma = false; + bool isSupportedMfma = + (currentArchId >= amdgcn_target_id::GFX942) && (currentArchId <= amdgcn_target_id::GFX950); + return ((currentArchId == amdgcn_target_id::HOST) || !(isSupportedWmma || isSupportedMfma)); +}; +const std::function validator = [](uint32_t fragK) { + return static_cast(fragK); +}; +} // namespace + +TEST(WaveWiseMmaPipeline, testKIter) +{ + MmaPipelineTest<> test; + const auto kernel = [](uint32_t waveSize, void* a, void* b, void* c, void* out) { + test_wavewise_pipeline::AType, + MmaPipelineTest<>::BType, + MmaPipelineTest<>::CType, + MmaPipelineTest<>::WaveTileM, + MmaPipelineTest<>::WaveTileN, + MmaPipelineTest<>::WaveTileK, + false><<<1, waveSize>>>(a, b, c, out); + }; + test.test_pipeline(should_skip, kernel, validator); +} + +TEST(WaveWiseMmaPipeline, testKIterTransposeC) +{ + MmaPipelineTest<> test; + const auto kernel = [](uint32_t waveSize, void* a, void* b, void* c, void* out) { + test_wavewise_pipeline::AType, + MmaPipelineTest<>::BType, + MmaPipelineTest<>::CType, + 
MmaPipelineTest<>::WaveTileM, + MmaPipelineTest<>::WaveTileN, + MmaPipelineTest<>::WaveTileK, + true><<<1, waveSize>>>(a, b, c, out); + }; + test.test_pipeline(should_skip, kernel, validator); +} diff --git a/projects/composablekernel/test/ck_tile/core/arch/mma/test_amdgcn_wavewise_mma.cpp b/projects/composablekernel/test/ck_tile/core/arch/mma/test_amdgcn_wavewise_mma.cpp deleted file mode 100644 index 5a4b9e2ab82d..000000000000 --- a/projects/composablekernel/test/ck_tile/core/arch/mma/test_amdgcn_wavewise_mma.cpp +++ /dev/null @@ -1,133 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "ck_tile/core/arch/arch.hpp" -#include "ck_tile/core/arch/mma/mma_op_family.hpp" -#include "ck_tile/core/arch/mma/mma_wavewise.hpp" - -#include "get_wave_size_helper.hpp" - -#include -#include - -using namespace ck_tile; -using namespace ck_tile::core::arch; -using namespace ck_tile::core::arch::mma; - -template -__global__ void test_pipeline(void* a, void* b, void* c) -{ - using CompilerTarget = decltype(get_compiler_target()); - - using Pipeline = WaveWiseMma; - - using AVecType = typename Pipeline::AVecType; - using BVecType = typename Pipeline::BVecType; - using CVecType = typename Pipeline::CVecType; - - Pipeline::exec(*reinterpret_cast(a), - *reinterpret_cast(b), - *reinterpret_cast(c)); -} - -TEST(WaveWiseMmaPipeline, testKIter) -{ - int devCount; - hipDevice_t dev; - HIP_CHECK_ERROR(hipGetDevice(&dev)); - HIP_CHECK_ERROR(hipGetDeviceCount(&devCount)); - - hipDeviceProp_t devProp; - HIP_CHECK_ERROR(hipGetDeviceProperties(&devProp, dev)); - - auto currentArchId = hip_device_prop_gcn_arch_name_to_amdgcn_target_id(devProp.gcnArchName); - bool hasDevice = static_cast(devCount > 0); - int deviceWarpSize = devProp.warpSize; - - bool isSupportedWmma = false; - bool isSupportedMfma = - (currentArchId >= amdgcn_target_id::GFX942) && (currentArchId <= amdgcn_target_id::GFX950); - // TODO: c++20 add check for arch id 
- if(!hasDevice || (currentArchId == amdgcn_target_id::HOST) || - !(isSupportedWmma || isSupportedMfma)) - { - GTEST_SKIP() << "No HIP device found. Skipping test."; - } - - using AType = fp16_t; - using BType = fp16_t; - using CType = fp32_t; - - // WaveTile size, also the expected fragment size (MmaTile) from the selector. - // Note: Actual FragK might be slightly different due to hardware implementation, but the - // test_accum_over_k kernel will loop over the K dimension to ensure that the total K is - // correct. - static constexpr uint32_t WaveTileM = 16; - static constexpr uint32_t WaveTileN = 16; - static constexpr uint32_t WaveTileK = 32; - static constexpr uint32_t FragM = WaveTileM; - static constexpr uint32_t FragN = WaveTileN; - static constexpr uint32_t FragK = WaveTileK; - - // The number of elements per thread - uint32_t AElements = FragM * FragK / deviceWarpSize; - uint32_t BElements = FragN * FragK / deviceWarpSize; - uint32_t CElements = FragM * FragN / deviceWarpSize; - - uint32_t ASize = AElements * sizeof(AType); - uint32_t BSize = BElements * sizeof(BType); - uint32_t CSize = CElements * sizeof(CType); - - // Initialize A and B to all 1's, C to all 0's - std::vector h_a(AElements, static_cast(1)); - std::vector h_b(BElements, static_cast(1)); - std::vector h_c(CElements, static_cast(0)); - std::vector h_out(CElements, static_cast(0)); - - AType* d_a; - BType* d_b; - CType* d_c; - CType* d_out; - - HIP_CHECK_ERROR(hipMalloc(&d_a, ASize)); - HIP_CHECK_ERROR(hipMalloc(&d_b, BSize)); - HIP_CHECK_ERROR(hipMalloc(&d_c, CSize)); - HIP_CHECK_ERROR(hipMalloc(&d_out, CSize)); - - // Copy inputs to device - HIP_CHECK_ERROR(hipMemcpy(d_a, h_a.data(), ASize, hipMemcpyHostToDevice)); - HIP_CHECK_ERROR(hipMemcpy(d_b, h_b.data(), BSize, hipMemcpyHostToDevice)); - HIP_CHECK_ERROR(hipMemcpy(d_c, h_c.data(), CSize, hipMemcpyHostToDevice)); - - const auto wave_size = getDeviceWaveSize(); - test_pipeline<<<1, wave_size>>>(d_a, d_b, d_c); - 
HIP_CHECK_ERROR(hipDeviceSynchronize()); - - HIP_CHECK_ERROR(hipMemcpy(h_out.data(), d_out, CSize, hipMemcpyDeviceToHost)); - - // Output should be FragK for all elements, because the inputs are all 1's - for(size_t i = 0; i < CElements; ++i) - { - CType expected = static_cast(FragK); - - EXPECT_NEAR(h_out[i], expected, 1e-3); - } - - HIP_CHECK_ERROR(hipFree(d_a)); - HIP_CHECK_ERROR(hipFree(d_b)); - HIP_CHECK_ERROR(hipFree(d_c)); - HIP_CHECK_ERROR(hipFree(d_out)); -}