Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
f974282
Added default implementation for `MmaTransformsDefaultSelector`
chris-tsiaousis-hpc Mar 12, 2026
69e8516
Added the MmaPipelineBase struct and renamed `mma.hpp`
chris-tsiaousis-hpc Mar 16, 2026
29893f8
Added compress A functionality to sparse transforms
chris-tsiaousis-hpc Mar 16, 2026
49f35e9
Added test case for wavewise mma pipeline
chris-tsiaousis-hpc Mar 16, 2026
8921653
Made sparse amdgcn structs depict the actual builtin signature (halve…
chris-tsiaousis-hpc Mar 17, 2026
dc00433
Defined the sparse pipeline and used it in test
chris-tsiaousis-hpc Mar 17, 2026
6839f38
Made concept in line with the Pipeline interface
chris-tsiaousis-hpc Mar 17, 2026
28b3dd0
Expanded the MmaOpI concept to support an extra int arg to the exec f…
chris-tsiaousis-hpc Mar 17, 2026
0e1cfb3
Revert changes done to CK Tile's internal compress A function
chris-tsiaousis-hpc Mar 17, 2026
1b14452
Run clang-format after rebasing
chris-tsiaousis-hpc Mar 18, 2026
8088705
Addressed some review comments
chris-tsiaousis-hpc Mar 18, 2026
d1b0224
Added test for the sparse transform
chris-tsiaousis-hpc Mar 18, 2026
1240413
Added test for 'MmaPipelineOptionFlags'
chris-tsiaousis-hpc Mar 19, 2026
d8bc466
Deduce kCompressionRatio automatically from MmaOpFamily
chris-tsiaousis-hpc Mar 19, 2026
0283390
De-couple intertwined compile passes for host/device code
chris-tsiaousis-hpc Mar 19, 2026
4c8c262
Reorganise pipeline test code, re-use boilerplate and add TransposeC …
chris-tsiaousis-hpc Mar 19, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion projects/composablekernel/include/ck_tile/core.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,16 @@
#include "ck_tile/core/arch/mma/mfma/mfma_selector.hpp"
#include "ck_tile/core/arch/mma/mfma/mfma_traits.hpp"
#include "ck_tile/core/arch/mma/mfma/mfma_transforms.hpp"
#include "ck_tile/core/arch/mma/mma.hpp"
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
#include "ck_tile/core/arch/mma/mma_pipeline.hpp"
#include "ck_tile/core/arch/mma/mma_selector.hpp"
#include "ck_tile/core/arch/mma/mma_traits.hpp"
#include "ck_tile/core/arch/mma/mma_transforms.hpp"
#include "ck_tile/core/arch/mma/mma_wavewise.hpp"
#include "ck_tile/core/arch/mma/sparse/mfma/selector.hpp"
#include "ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp"
#include "ck_tile/core/arch/mma/sparse/sparse.hpp"
#include "ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp"
#include "ck_tile/core/arch/mma/sparse/sparse_selector.hpp"
#include "ck_tile/core/arch/mma/sparse/sparse_traits.hpp"
#include "ck_tile/core/arch/mma/sparse/sparse_transforms.hpp"
Expand Down Expand Up @@ -55,6 +57,7 @@
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/e8m0.hpp"
#include "ck_tile/core/numeric/ext_vector_base.hpp"
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/int8.hpp"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,14 +156,16 @@ struct amdgcn_mma_base
static constexpr index_t kBRepeat = kBRepeat_; // RDNA3 repetition and MFMA block-hiding
static constexpr index_t kCMPerLane = kCMPerLane_; // M2 * M0
static constexpr index_t kCMNumAccess = kCMNumAccess_; // M2
// K-dimension compression ratio for A matrix, always 2 for sparse intrinsics
static constexpr index_t kCompressionRatio = (OpFamily == MmaOpFamily::SPARSE) ? 2 : 1;

// Register types (derived)
static constexpr index_t WaveSize = WaveSize_;
static_assert((kM * kK * kARepeat) % WaveSize == 0);
static_assert((kN * kK * kBRepeat) % WaveSize == 0);
static_assert((kM * kN) % WaveSize == 0);

using AVecType = ext_vector_t<ADataType, kM * kK * kARepeat / WaveSize>;
using AVecType = ext_vector_t<ADataType, kM * kK * kARepeat / WaveSize / kCompressionRatio>;
using BVecType = ext_vector_t<BDataType, kN * kK * kBRepeat / WaveSize>;
using CVecType = ext_vector_t<CDataType, kM * kN / WaveSize>;
};
Expand All @@ -177,6 +179,20 @@ struct Unsupported;
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER

#include <concepts>
/**
* @concept HasExecSignature
* @brief Helper concept for exec signature check.
*
* Satisfied when the static call
* `MmaOp::exec(AVecType{}, BVecType{}, CVecType{}, <ExecArgs>...)` is well-formed and its
* result is convertible to `MmaOp::CVecType`.
*
* @tparam MmaOp    Candidate op type; must expose AVecType, BVecType, CVecType and exec.
* @tparam ExecArgs Types of any extra arguments passed to exec after the three vectors
*                  (an empty pack checks the plain three-argument form).
*/
template <typename MmaOp, typename... ExecArgs>
concept HasExecSignature = requires {
{
MmaOp::exec(typename MmaOp::AVecType{},
typename MmaOp::BVecType{},
typename MmaOp::CVecType{},
std::declval<ExecArgs>()...)
} -> std::convertible_to<typename MmaOp::CVecType>;
};

/**
* @concept MmaOpI
* @brief Expresses the meta-data interface required for each MmaOp policy.
Expand All @@ -185,7 +201,7 @@ template <typename MmaOp>
concept MmaOpI = requires(MmaOp op) {
// Requires an op context
typename MmaOp::OpType;
typename MmaOp::OpFamily;
{ MmaOp::OpFamily } -> std::convertible_to<MmaOpFamily>;

// Captures types for inputs / outputs to mma function
typename MmaOp::ADataType;
Expand All @@ -203,13 +219,8 @@ concept MmaOpI = requires(MmaOp op) {
{ MmaOp::kBRepeat } -> std::convertible_to<unsigned int>;
{ MmaOp::kCMPerLane } -> std::convertible_to<unsigned int>;
{ MmaOp::kCMNumAccess } -> std::convertible_to<unsigned int>;

// Static exec function
{
MmaOp::exec(
typename MmaOp::AVecType{}, typename MmaOp::BVecType{}, typename MmaOp::CVecType{})
} -> std::convertible_to<typename MmaOp::CVecType>;
};
{ MmaOp::kCompressionRatio } -> std::convertible_to<unsigned int>;
} && (HasExecSignature<MmaOp> || HasExecSignature<MmaOp, int>);

#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,11 @@ struct MmaDefaultTransformsGfx9
// TODO: c++20 template <MmaOpI MmaOp, amdgcn_target_arch_id CompilerTarget>
// TODO: c++20 requires
template <typename MmaOp, typename CompilerTarget>
struct MmaTransformsDefaultSelector<MmaOp,
CompilerTarget,
enable_if_target_family_gfx9_t<CompilerTarget>>
struct MmaTransformsDefaultSelector<
MmaOp,
CompilerTarget,
enable_if_all<enable_if_target_family_gfx9_t<CompilerTarget>,
std::enable_if_t<MmaOp::OpFamily == MmaOpFamily::DENSE>>>
{
using SelectedTransforms = MmaDefaultTransformsGfx9;
};
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"

#include "amdgcn_mma.hpp"
#include "mma_selector.hpp"
#include "mma_traits.hpp"
#include "mma_transforms.hpp"

// Standard headers must be included at file scope, outside any namespace. The guard macros
// are expected to be provided by the project headers above (undefined macros evaluate to 0).
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
#include <concepts>
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER

namespace ck_tile::core::arch::mma {

// Individual option bits understood by the MMA pipeline. Each enumerator occupies its own
// bit so that several options can be combined in a MmaPipelineOptionFlags value.
enum class MmaPipelineOptionFlag
{
    NONE        = 0,      // no option selected
    C_TRANSPOSE = 1 << 0, // 0x1: MmaPipelineBase::exec swaps the A and B operands
    COMPRESS_A  = 1 << 1, // 0x2: presumably requests A-matrix compression - not consumed here
};

struct MmaPipelineOptionFlags
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very verbose but I guess necessary if we really don't want to allow raw enums...

{
// Integral type that backs MmaPipelineOptionFlag; all bit arithmetic is done in this type.
using Type = std::underlying_type_t<MmaPipelineOptionFlag>;

explicit constexpr MmaPipelineOptionFlags() : mFlags(0) {}
explicit constexpr MmaPipelineOptionFlags(Type value) : mFlags(value) {}
constexpr MmaPipelineOptionFlags(MmaPipelineOptionFlag singleFlag) : mFlags(toType(singleFlag))
{
}
constexpr MmaPipelineOptionFlags(const MmaPipelineOptionFlags& original)
: mFlags(original.mFlags)
{
}

// Sets the bit(s) of addValue in this set and returns *this for chaining.
constexpr MmaPipelineOptionFlags& operator|=(MmaPipelineOptionFlag addValue)
{
    mFlags = mFlags | toType(addValue);
    return *this;
}
// Returns a new set holding this set's bits plus the bit(s) of addValue; *this is unchanged.
constexpr MmaPipelineOptionFlags operator|(MmaPipelineOptionFlag addValue) const
{
    return MmaPipelineOptionFlags(mFlags | toType(addValue));
}
// Keeps only the bit(s) that are also present in maskValue and returns *this for chaining.
constexpr MmaPipelineOptionFlags& operator&=(MmaPipelineOptionFlag maskValue)
{
    mFlags = mFlags & toType(maskValue);
    return *this;
}
// Returns a new set holding only the bits shared by this set and maskValue; *this is
// unchanged.
constexpr MmaPipelineOptionFlags operator&(MmaPipelineOptionFlag maskValue) const
{
    return MmaPipelineOptionFlags(mFlags & toType(maskValue));
}
// Returns a new set with every bit of the underlying value inverted; *this is unchanged.
constexpr MmaPipelineOptionFlags operator~() const
{
    return MmaPipelineOptionFlags(~mFlags);
}
// Reports whether `flag` is present in this set.
// NONE is special-cased: it tests true only when no bit at all is set. For any other
// enumerator the result is true when at least one of its bits is set here.
constexpr bool testFlag(MmaPipelineOptionFlag flag) const
{
    const Type noBits = toType(MmaPipelineOptionFlag::NONE);
    if(flag == MmaPipelineOptionFlag::NONE)
    {
        return mFlags == noBits;
    }
    return (mFlags & toType(flag)) != noBits;
}
// True when at least one option bit is set (i.e. the set differs from NONE).
constexpr operator bool() const { return !(mFlags == toType(MmaPipelineOptionFlag::NONE)); }
// Compares the stored bit pattern against a raw underlying value.
constexpr bool operator==(Type rhs) const { return rhs == mFlags; }

private:
// Raw bit pattern of the options currently held by this set.
Type mFlags;
// Converts an enumerator to its underlying integral value.
static constexpr Type toType(MmaPipelineOptionFlag f) { return static_cast<Type>(f); }
};

// Symmetric counterpart of MmaPipelineOptionFlags::operator==(Type): allows the raw value to
// appear on the left-hand side of the comparison.
constexpr bool operator==(MmaPipelineOptionFlags::Type lhs, const MmaPipelineOptionFlags& rhs)
{
    return rhs.operator==(lhs);
}

// TODO: c++20: use MmaPipelineOptionFlags directly
template <MmaPipelineOptionFlags::Type Flags_, typename Derived>
struct MmaPipelineBase
{
static constexpr auto Flags = MmaPipelineOptionFlags(Flags_);

private:
// Reinterprets inputBuffer as a DstT of identical size and returns it.
// The const-ness of the source is carried over to the reinterpreted view, and because the
// return type is plain `auto`, the caller receives the value by copy rather than a reference.
template <typename DstT, typename SrcT>
CK_TILE_DEVICE static auto formatBuffer(SrcT&& inputBuffer)
{
    // TODO: Implement formatting logic as needed.
    // This is intended to convert input fragments to the native vector types
    // required by the BlockWiseMma operation for iteration
    using SrcValueT = std::remove_reference_t<SrcT>;

    // A reinterpretation is only meaningful between types of the same size.
    static_assert(sizeof(DstT) == sizeof(SrcValueT), "Size mismatch in formatBuffer");

    using ViewT = std::conditional_t<std::is_const_v<SrcValueT>, std::add_const_t<DstT>, DstT>;

    return reinterpret_cast<ViewT&>(inputBuffer);
}

protected:
// Returns whether option `Flag` is set in this pipeline's `Flags`
// (delegates to MmaPipelineOptionFlags::testFlag).
template <MmaPipelineOptionFlag Flag>
constexpr CK_TILE_DEVICE static bool hasFlag()
{
return Flags.testFlag(Flag);
}

// Runs Transform::exec on the forwarded arguments first, then passes its result through
// formatBuffer<DstT> and returns that.
template <typename DstT, typename Transform, typename... Args>
CK_TILE_DEVICE static auto preApplyTransform(Args&&... args)
{
return formatBuffer<DstT>(Transform::exec(std::forward<Args>(args)...));
}

// Passes the forwarded argument through formatBuffer<DstT> first, then hands the result to
// Transform::exec and returns that. formatBuffer takes a single parameter, so the pack is
// expected to hold exactly one argument.
template <typename DstT, typename Transform, typename... Args>
CK_TILE_DEVICE static auto postApplyTransform(Args&&... args)
{
return Transform::exec(formatBuffer<DstT>(std::forward<Args>(args)...));
}

public:
/**
* @brief Entry point of the pipeline: pre-process the operands, run the op, post-process.
*
* When MmaOpTraits<typename Derived::FragWiseMmaOp>::IsSupported is true, the operands are
* handed to Derived::preApply<Flags_> (with a and b swapped if C_TRANSPOSE is set), the
* result is passed to Derived::execImpl, and the value of Derived::postApply<Flags_> is
* returned. Otherwise the call is forwarded to Derived::FragWiseMmaOp::exec with three
* empty-brace ({}) arguments.
*
* NOTE(review): the a/b swap is written with the conditional operator, which needs both
* operands to share a common type, so as written it presumably only compiles when the A
* and B fragments have compatible types - confirm before enabling C_TRANSPOSE broadly.
*/
template <typename VecTA, typename VecTB, typename VecTC>
CK_TILE_DEVICE static decltype(auto) exec(VecTA&& a, VecTB&& b, VecTC&& accum)
{
if constexpr(MmaOpTraits<typename Derived::FragWiseMmaOp>::IsSupported)
{
// TODO: c++20: Call template functions with MmaPipelineOptionFlags directly
auto pre = Derived::template preApply<Flags_>(
hasFlag<MmaPipelineOptionFlag::C_TRANSPOSE>() ? std::forward<VecTB>(b)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this works only if A and B have the same type and size

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, CTranspose will not be available for all intrinsics. Also I don't think CTranspose is possible for sparse intrinsics.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll disable them for sparse then!

: std::forward<VecTA>(a),
hasFlag<MmaPipelineOptionFlag::C_TRANSPOSE>() ? std::forward<VecTA>(a)
: std::forward<VecTB>(b),
std::forward<VecTC>(accum));
Derived::execImpl(pre);
return Derived::template postApply<Flags_>(std::move(pre));
}
else
{
// Return the unsupported exec. This should print a runtime warning. (amdgcn_mma.hpp)
// Code should not reach here, but HOST/DEVICE compile passes are
// weirdly intertwined and instead of having constexpr in the calling
// site (tests) we do this. See also changes by this commit.
return Derived::FragWiseMmaOp::exec({}, {}, {});
}
}
Comment on lines +122 to +145
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aha we are making a second-order wrapper for the intrinsic just like in CK Tile, making more sense to me now.

};

#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER

// NOTE: <concepts> is included at the top of this header, at file scope. A standard header
// must not be included from inside a namespace, which is where this section lives.

/**
* @concept MmaPipelineInterface
* @brief Expresses the meta-data interface required for a CRTP MmaPipeline.
*
* Satisfied when Derived is derived from MmaPipelineBase<Flags, Derived>.
*
* @tparam Derived The concrete pipeline type being checked.
* @tparam Flags   Raw option bits (MmaPipelineOptionFlags::Type) the pipeline was built with.
*/
template <typename Derived, MmaPipelineOptionFlags::Type Flags>
concept MmaPipelineInterface = std::derived_from<Derived, MmaPipelineBase<Flags, Derived>>;

#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER

} // namespace ck_tile::core::arch::mma
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,4 @@ concept MmaSelectorI = requires(MmaSelector op) {
// Include the implementations
#include "wmma/wmma_selector.hpp"
#include "mfma/mfma_selector.hpp"
#include "sparse/sparse_selector.hpp"
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,18 @@ struct PassThroughTransform
}
};

/**
* @struct MmaDefaultPassThroughTransforms
* @brief Implements the default MMA transforms
*
* Bundles one transform per pipeline stage (A, B, C and D); every stage is mapped to
* PassThroughTransform.
*/
struct MmaDefaultPassThroughTransforms
{
using ATransform = PassThroughTransform;
using BTransform = PassThroughTransform;
using CTransform = PassThroughTransform;
using DTransform = PassThroughTransform;
};

/**
* @class MmaTransformsDefaultSelector
* @brief Default selector for MmaTransforms based on MmaOp and CompilerTarget
Expand All @@ -27,7 +39,12 @@ struct PassThroughTransform
*/
template <typename MmaOp, typename CompilerTarget, typename Enable = void>
// TODO: c++20 template <MmaOpI MmaOp, amdgcn_target_arch_id CompilerTarget, typename Enable = void>
struct MmaTransformsDefaultSelector;
struct MmaTransformsDefaultSelector
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change is not good. It was done because DEVICE and HOST code are weirdly intertwined and I'll think of a way to revert this.

{
using SelectedTransforms = MmaDefaultPassThroughTransforms;
static_assert(CompilerTarget::TARGET_ID == amdgcn_target_id::HOST,
"Device code should use another specialization.");
};

#if CK_TILE_CONCEPTS

Expand Down
Loading
Loading