-
Notifications
You must be signed in to change notification settings - Fork 241
[CK Tile] Unification work - mma transformations pipeline #5508
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
f974282
69e8516
29893f8
49f35e9
8921653
dc00433
6839f38
28b3dd0
0e1cfb3
1b14452
8088705
d1b0224
1240413
d8bc466
0283390
4c8c262
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,161 @@ | ||
| // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. | ||
| // SPDX-License-Identifier: MIT | ||
| #pragma once | ||
| #include "ck_tile/core/arch/arch.hpp" | ||
| #include "ck_tile/core/numeric/vector_type.hpp" | ||
|
|
||
| #include "amdgcn_mma.hpp" | ||
| #include "mma_selector.hpp" | ||
| #include "mma_traits.hpp" | ||
| #include "mma_transforms.hpp" | ||
|
|
||
| namespace ck_tile::core::arch::mma { | ||
|
|
||
| enum struct MmaPipelineOptionFlag | ||
| { | ||
| NONE = 0x0, | ||
| C_TRANSPOSE = 0x1, | ||
| COMPRESS_A = 0x2, | ||
| }; | ||
|
|
||
| struct MmaPipelineOptionFlags | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Very verbose but I guess necessary if we really don't want to allow raw enums... |
||
| { | ||
| using Type = std::underlying_type<MmaPipelineOptionFlag>::type; | ||
|
|
||
| explicit constexpr MmaPipelineOptionFlags() : mFlags(0) {} | ||
| explicit constexpr MmaPipelineOptionFlags(Type value) : mFlags(value) {} | ||
| constexpr MmaPipelineOptionFlags(MmaPipelineOptionFlag singleFlag) : mFlags(toType(singleFlag)) | ||
| { | ||
| } | ||
| constexpr MmaPipelineOptionFlags(const MmaPipelineOptionFlags& original) | ||
| : mFlags(original.mFlags) | ||
| { | ||
| } | ||
|
|
||
| constexpr MmaPipelineOptionFlags& operator|=(MmaPipelineOptionFlag addValue) | ||
| { | ||
| mFlags |= toType(addValue); | ||
| return *this; | ||
| } | ||
| constexpr MmaPipelineOptionFlags operator|(MmaPipelineOptionFlag addValue) const | ||
| { | ||
| MmaPipelineOptionFlags result(*this); | ||
| result |= addValue; | ||
| return result; | ||
| } | ||
| constexpr MmaPipelineOptionFlags& operator&=(MmaPipelineOptionFlag maskValue) | ||
| { | ||
| mFlags &= toType(maskValue); | ||
| return *this; | ||
| } | ||
| constexpr MmaPipelineOptionFlags operator&(MmaPipelineOptionFlag maskValue) const | ||
| { | ||
| MmaPipelineOptionFlags result(*this); | ||
| result &= maskValue; | ||
| return result; | ||
| } | ||
| constexpr MmaPipelineOptionFlags operator~() const | ||
| { | ||
| MmaPipelineOptionFlags result(*this); | ||
| result.mFlags = ~result.mFlags; | ||
| return result; | ||
| } | ||
| constexpr bool testFlag(MmaPipelineOptionFlag flag) const | ||
| { | ||
| return (flag == MmaPipelineOptionFlag::NONE) ? mFlags == toType(flag) : *this & flag; | ||
| } | ||
| constexpr operator bool() const { return mFlags != toType(MmaPipelineOptionFlag::NONE); } | ||
| constexpr bool operator==(Type rhs) const { return mFlags == rhs; } | ||
|
|
||
| private: | ||
| Type mFlags; | ||
| static constexpr Type toType(MmaPipelineOptionFlag f) { return static_cast<Type>(f); } | ||
| }; | ||
|
|
||
| constexpr bool operator==(MmaPipelineOptionFlags::Type lhs, const MmaPipelineOptionFlags& rhs) | ||
| { | ||
| return rhs == lhs; | ||
| } | ||
|
|
||
| // TODO: c++20: use MmaPipelineOptionFlags directly | ||
| template <MmaPipelineOptionFlags::Type Flags_, typename Derived> | ||
| struct MmaPipelineBase | ||
| { | ||
| static constexpr auto Flags = MmaPipelineOptionFlags(Flags_); | ||
|
|
||
| private: | ||
| template <typename DstT, typename SrcT> | ||
| CK_TILE_DEVICE static auto formatBuffer(SrcT&& inputBuffer) | ||
| { | ||
| // TODO: Implement formatting logic as needed. | ||
| // This is intended to convert input fragments to the native vector types | ||
| // required by the BlockWiseMma operation for iteration | ||
| static_assert(sizeof(DstT) == sizeof(std::remove_reference_t<SrcT>), | ||
| "Size mismatch in formatBuffer"); | ||
|
|
||
| using QualifiedDstT = | ||
| std::conditional_t<std::is_const_v<std::remove_reference_t<SrcT>>, DstT const, DstT>; | ||
|
|
||
| return reinterpret_cast<QualifiedDstT&>(inputBuffer); | ||
| } | ||
|
|
||
| protected: | ||
| template <MmaPipelineOptionFlag Flag> | ||
| constexpr CK_TILE_DEVICE static bool hasFlag() | ||
| { | ||
| return Flags.testFlag(Flag); | ||
| } | ||
|
|
||
| template <typename DstT, typename Transform, typename... Args> | ||
| CK_TILE_DEVICE static auto preApplyTransform(Args&&... args) | ||
| { | ||
| return formatBuffer<DstT>(Transform::exec(std::forward<Args>(args)...)); | ||
| } | ||
|
|
||
| template <typename DstT, typename Transform, typename... Args> | ||
| CK_TILE_DEVICE static auto postApplyTransform(Args&&... args) | ||
| { | ||
| return Transform::exec(formatBuffer<DstT>(std::forward<Args>(args)...)); | ||
| } | ||
|
|
||
| public: | ||
| template <typename VecTA, typename VecTB, typename VecTC> | ||
| CK_TILE_DEVICE static decltype(auto) exec(VecTA&& a, VecTB&& b, VecTC&& accum) | ||
| { | ||
| if constexpr(MmaOpTraits<typename Derived::FragWiseMmaOp>::IsSupported) | ||
| { | ||
| // TODO: c++20: Call template functions with MmaPipelineOptionFlags directly | ||
| auto pre = Derived::template preApply<Flags_>( | ||
| hasFlag<MmaPipelineOptionFlag::C_TRANSPOSE>() ? std::forward<VecTB>(b) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this works only if A and B have the same type and size
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, CTranspose will not be available for all intrinsics. Also I don't think CTranspose is possible for sparse intrinsics.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'll disable them for sparse then! |
||
| : std::forward<VecTA>(a), | ||
| hasFlag<MmaPipelineOptionFlag::C_TRANSPOSE>() ? std::forward<VecTA>(a) | ||
| : std::forward<VecTB>(b), | ||
| std::forward<VecTC>(accum)); | ||
| Derived::execImpl(pre); | ||
| return Derived::template postApply<Flags_>(std::move(pre)); | ||
| } | ||
| else | ||
| { | ||
| // Return the unsupported exec. This should print a runtime warning. (amdgcn_mma.hpp) | ||
| // Code should not reach here, but HOST/DEVICE compile passes are | ||
| // weirdly intertwined and instead of having constexpr in the calling | ||
| // site (tests) we do this. See also changes by this commit. | ||
| return Derived::FragWiseMmaOp::exec({}, {}, {}); | ||
| } | ||
| } | ||
|
Comment on lines
+122
to
+145
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Aha we are making a second-order wrapper for the intrinsic just like in CK Tile, making more sense to me now. |
||
| }; | ||
|
|
||
| #if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER | ||
|
|
||
| #include <concepts> | ||
|
|
||
| /** | ||
| * @concept MmaPipelineI | ||
| * @brief Expresses the meta-data interface required for a CRTP MmaPipeline. | ||
| */ | ||
| template <typename Derived, MmaPipelineOptionFlags::Type Flags> | ||
| concept MmaPipelineInterface = std::derived_from<Derived, MmaPipelineBase<Flags, Derived>>; | ||
|
|
||
| #endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER | ||
|
|
||
| } // namespace ck_tile::core::arch::mma | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,6 +18,18 @@ struct PassThroughTransform | |
| } | ||
| }; | ||
|
|
||
| /** | ||
| * @struct MmaDefaultPassThroughTransforms | ||
| * @brief Implements the default MMA transforms | ||
| */ | ||
| struct MmaDefaultPassThroughTransforms | ||
| { | ||
| using ATransform = PassThroughTransform; | ||
| using BTransform = PassThroughTransform; | ||
| using CTransform = PassThroughTransform; | ||
| using DTransform = PassThroughTransform; | ||
| }; | ||
|
|
||
| /** | ||
| * @class MmaTransformsDefaultSelector | ||
| * @brief Default selector for MmaTransforms based on MmaOp and CompilerTarget | ||
|
|
@@ -27,7 +39,12 @@ struct PassThroughTransform | |
| */ | ||
| template <typename MmaOp, typename CompilerTarget, typename Enable = void> | ||
| // TODO: c++20 template <MmaOpI MmaOp, amdgcn_target_arch_id CompilerTarget, typename Enable = void> | ||
| struct MmaTransformsDefaultSelector; | ||
| struct MmaTransformsDefaultSelector | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This change is not good. It was done because DEVICE and HOST code are weirdly intertwined and I'll think of a way to revert this. |
||
| { | ||
| using SelectedTransforms = MmaDefaultPassThroughTransforms; | ||
| static_assert(CompilerTarget::TARGET_ID == amdgcn_target_id::HOST, | ||
| "Device code should use another specialization."); | ||
| }; | ||
|
|
||
| #if CK_TILE_CONCEPTS | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.