[CK_Tile] Refactor amdgcn_mma policy structs#5272

Open
krithalith wants to merge 9 commits into develop from users/krithalith/ck/unification_policy_struct_refactor

Contributor

@krithalith krithalith commented Mar 10, 2026

Motivation

This MR simplifies the intrinsic layout parameters and makes them clearer and more flexible. It also includes a number of simple refactors that reduce boilerplate and code duplication.

Technical Details

In CK Tile and old CK, the full set of information available in the intrinsic wrappers, for WMMA and MFMA combined, would be something like:

// Basic info
using ADataType = void;
using BDataType = void;
using CDataType = void;

using AVecType = ext_vector_t<ADataType, 0>;
using BVecType = ext_vector_t<BDataType, 0>;
using CVecType = ext_vector_t<CDataType, 0>;

// Fragment sizes
static constexpr index_t kM;
static constexpr index_t kN;
static constexpr index_t kK;

// Layout parameters
static constexpr index_t kAMBlock;
static constexpr index_t kBNBlock;

static constexpr index_t kRepeat;
static constexpr index_t kAMLane;
static constexpr index_t kBNLane;
static constexpr index_t kABK0PerLane;
static constexpr index_t kABKLane;
static constexpr index_t kABK1PerLane;

static constexpr index_t kCMLane;
static constexpr index_t kCNLane;
static constexpr index_t kCM0PerLane;
static constexpr index_t kCM1PerLane;

using kABPs2RHssMajor = sequence<2, 1>;
using kABPs2RHssMinor = sequence<1, 0>;
using kABYs2RHsMajor  = sequence<2, 2>;
using kABYs2RHsMinor  = sequence<0, 2>;

using kCPs2RHssMajor = sequence<1, 2>;
using kCPs2RHssMinor = sequence<1, 0>;
using kCYs2RHsMajor  = sequence<1, 1>;
using kCYs2RHsMinor  = sequence<0, 2>;

using kCTPs2RHssMajor = sequence<2, 1>;
using kCTPs2RHssMinor = sequence<1, 0>;
using kCTYs2RHsMajor  = sequence<2, 2>;
using kCTYs2RHsMinor  = sequence<0, 2>;   

Note that on top of the intrinsic sizes, we have 12 layout parameters. I have reduced this in the new design to:

// Basic info
using ADataType = void;
using BDataType = void;
using CDataType = void;

// Fragment sizes
static constexpr index_t kM;
static constexpr index_t kN;
static constexpr index_t kK;

// Layout parameters
static constexpr index_t kABKPerLane;  // K2 * K0, Always the same, even for diff A / B layouts
static constexpr index_t kAKNumAccess; // K2
static constexpr index_t kARepeat;     // Used for RDNA3 repeated inputs and CDNA block hiding.
static constexpr index_t kBKNumAccess; // K2
static constexpr index_t kBRepeat;     // Used for RDNA3 repeated inputs and CDNA block hiding.
static constexpr index_t kCMPerLane;   // M2 * M0
static constexpr index_t kCMNumAccess; // M2

// Derived properties
using AVecType = ext_vector_t<ADataType, 0>;
using BVecType = ext_vector_t<BDataType, 0>;
using CVecType = ext_vector_t<CDataType, 0>;

Note that there are now only 7 layout parameters and no more dimensionality orderings. Believe it or not, these 7 parameters are more general than the original 12 and can handle intrinsic and mid-level features that are currently awkward in CK Tile, such as AttrNumAccess, different A / B layouts, more general block-hiding (currently very limited in CK Tile), and future arch features.

Furthermore, the A, B and C vec types are now derived directly from the layout parameters to ensure internal consistency.
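
As a rough illustration of deriving the vector types from the layout parameters, here is a minimal sketch. It is not the code from this MR: `mma_params_sketch` and `vec_sketch_t` are hypothetical names, `std::array` stands in for `ext_vector_t`, only two of the seven parameters are modelled, and the rule "vector length equals the per-lane element count" is an assumption made for the example.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

using index_t = std::int32_t; // stand-in for ck_tile::index_t

// Hypothetical stand-in for ext_vector_t<T, N>; std::array keeps the sketch portable.
template <typename T, index_t N>
using vec_sketch_t = std::array<T, static_cast<std::size_t>(N)>;

// Hypothetical parameter struct; only two of the seven layout parameters are modelled.
template <typename ADataType_, typename BDataType_, typename CDataType_,
          index_t FragM, index_t FragN, index_t FragK,
          index_t ABKPerLane, index_t CMPerLane>
struct mma_params_sketch
{
    using ADataType = ADataType_;
    using BDataType = BDataType_;
    using CDataType = CDataType_;

    // Fragment sizes
    static constexpr index_t kM = FragM;
    static constexpr index_t kN = FragN;
    static constexpr index_t kK = FragK;

    // Layout parameters (subset)
    static constexpr index_t kABKPerLane = ABKPerLane;
    static constexpr index_t kCMPerLane  = CMPerLane;

    static_assert(FragK % ABKPerLane == 0, "K must be a multiple of kABKPerLane");
    static_assert(FragM % CMPerLane == 0, "M must be a multiple of kCMPerLane");

    // ASSUMPTION: vector lengths equal the per-lane element counts.
    using AVecType = vec_sketch_t<ADataType, ABKPerLane>;
    using BVecType = vec_sketch_t<BDataType, ABKPerLane>;
    using CVecType = vec_sketch_t<CDataType, CMPerLane>;
};

// Illustrative values only, not taken from a real intrinsic.
using example_params =
    mma_params_sketch<std::uint16_t, std::uint16_t, float, 16, 16, 16, 8, 8>;

static_assert(std::tuple_size<example_params::AVecType>::value == 8,
              "A vector length follows kABKPerLane");
static_assert(std::tuple_size<example_params::CVecType>::value == 8,
              "C vector length follows kCMPerLane");
```

Because the vector types are computed inside the struct, a specialization cannot declare a vector length that disagrees with its layout parameters, which is the consistency property described above.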

I added a detailed explanation of the new params in terms of register mappings at the top of amdgcn_mma.hpp.

Other refactorings I did in this MR:

  • Make an amdgcn_mma_base struct to drastically reduce code duplication and potential bugs. This should also make auto-generating the amdgcn_mma specializations much easier.

  • Simplify the MmaOpTraits significantly by only including those parameters that cannot be obtained directly from the MmaOp itself. This removes duplicated variables and simplifies higher level code.

  • Remove the overloaded "Block" term for intrinsic dimensions and replace it with "Frag". Some spots were already using the term "Frag" for combined intrinsics; there I changed the term to "Chunk" instead.

  • Remove some tests that had become somewhat pointless (setting variables and then checking their values immediately).

  • Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

@krithalith krithalith requested a review from a team as a code owner March 10, 2026 10:40
@krithalith krithalith added the organization: streamhpc contributors from streamhpc label Mar 10, 2026
@krithalith krithalith force-pushed the users/krithalith/ck/unification_policy_struct_refactor branch from b37eacf to 6fb3ad3 March 10, 2026 10:55
@krithalith krithalith marked this pull request as draft March 10, 2026 10:57
@wj-laskowski wj-laskowski self-requested a review March 10, 2026 11:00
Contributor

@chris-tsiaousis-hpc chris-tsiaousis-hpc left a comment

Thanks for those changes @krithalith, things look much simpler now! I've added some comments and proposals for improvement.

// TODO: Describe layout params.
/**
* @class amdgcn_mma_base
* @brief Helper base class for amdgcn_mma structs to avoid a lot of code duplication. Also puts

Contributor

"Helper" here is not correct IMO. This is just a base class.

static constexpr index_t kN = FragN;
static constexpr index_t kK = FragK;

// Layout constants

Contributor

In the description of the PR you mention:

// Layout parameters
static constexpr index_t kABKPerLane;  // K2 * K0, Always the same, even for diff A / B layouts
static constexpr index_t kAKNumAccess; // K2
static constexpr index_t kARepeat;     // Used for RDNA3 repeated inputs and CDNA block hiding.
static constexpr index_t kBKNumAccess; // K2
static constexpr index_t kBRepeat;     // Used for RDNA3 repeated inputs and CDNA block hiding.
static constexpr index_t kCMPerLane;   // M2 * M0
static constexpr index_t kCMNumAccess; // M2

I'd like to have those comments here as well. It would also be useful for future devs taking a peek at this to mention that M = M0 * M1 * M2 and so on...

Contributor Author

Yes I'll add some inline comments and plan to add a more detailed description of the parameters, maybe with some ASCII art at the top of the file somewhere.
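
The relations discussed in this thread (M = M0 * M1 * M2, kCMPerLane = M2 * M0, kCMNumAccess = M2) can be checked with a small compile-time sketch. `m_split_sketch` and `split_m` are hypothetical names introduced here, and the numbers are illustrative rather than taken from a real intrinsic.

```cpp
#include <cstdint>

using index_t = std::int32_t; // stand-in for ck_tile::index_t

// Hypothetical holder for the three factors of M.
struct m_split_sketch
{
    index_t m0;
    index_t m1;
    index_t m2;
};

// Recover M0, M1 and M2 from the fragment size and the two C layout parameters.
// kCMNumAccess = M2 and kCMPerLane = M2 * M0, so M0 and M1 follow by division.
constexpr m_split_sketch split_m(index_t kM, index_t kCMPerLane, index_t kCMNumAccess)
{
    return m_split_sketch{kCMPerLane / kCMNumAccess, kM / kCMPerLane, kCMNumAccess};
}

// True when the three factors multiply back to M.
constexpr bool is_consistent(m_split_sketch s, index_t kM)
{
    return s.m0 * s.m1 * s.m2 == kM;
}

// Illustrative values: M = 16, kCMPerLane = 8, kCMNumAccess = 2.
constexpr m_split_sketch kExampleSplit = split_m(16, 8, 2);

static_assert(kExampleSplit.m0 == 4, "M0 = kCMPerLane / kCMNumAccess");
static_assert(kExampleSplit.m1 == 2, "M1 = kM / kCMPerLane");
static_assert(kExampleSplit.m2 == 2, "M2 = kCMNumAccess");
static_assert(is_consistent(kExampleSplit, 16), "M0 * M1 * M2 must equal M");
```

The same arithmetic would apply to K with kABKPerLane and kAKNumAccess or kBKNumAccess, assuming the K dimension follows the analogous convention.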

* @tparam CompilerTarget The current compiler target
* @tparam Enabler SFINAE enabler
*/
// clang-format off

Contributor

Why turn off clang-format here and for such a long section? If you only need this for the instantiation of the base class, maybe just do it one line before and enable it right after?
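
One way to keep the disabled region short, as suggested here, is to wrap only the hand-aligned base instantiation in the `// clang-format off` / `// clang-format on` pair. The sketch below uses hypothetical names (`base_sketch`, `spec_sketch`, `exec_sketch`) and illustrative values; it is not the code under review.

```cpp
#include <cstdint>

using index_t = std::int32_t; // stand-in for ck_tile::index_t

// Hypothetical base with a long template argument list (names are illustrative).
template <index_t FragM, index_t FragN, index_t FragK, index_t ABKPerLane, index_t CMPerLane>
struct base_sketch
{
    static constexpr index_t kM          = FragM;
    static constexpr index_t kN          = FragN;
    static constexpr index_t kK          = FragK;
    static constexpr index_t kABKPerLane = ABKPerLane;
    static constexpr index_t kCMPerLane  = CMPerLane;
};

// Only the hand-aligned argument list is excluded from formatting.
// clang-format off
using spec_base_sketch = base_sketch</*FragM*/      16, /*FragN*/     16, /*FragK*/ 16,
                                     /*ABKPerLane*/  8, /*CMPerLane*/  8>;
// clang-format on

// The formatter is active again for the body of the specialization.
struct spec_sketch : spec_base_sketch
{
    // Placeholder for the intrinsic call; a scalar multiply-add keeps the sketch portable.
    static constexpr index_t exec_sketch(index_t a, index_t b, index_t c) { return a * b + c; }
};
```

Pulling the instantiation into an alias means the `exec` body and everything after it stays under the formatter's control.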

exec(AVecType& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
{
- static constexpr index_t CompressedSize = ABVecN / kCompressionRatio;
+ static constexpr index_t CompressedSize = vector_traits<AVecType>::vector_size / 2;

Contributor

I'd prefer this 2 to be a constexpr variable like it used to be, since we are still waiting for an answer on whether other compression ratios will be supported. It is also cleaner than a hardcoded number within the exec function...

Contributor Author

Yes I knew you would bring this up and you are right, I was just in line-reduction mode :)
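
A minimal sketch of the named-constant variant discussed in this thread follows. `compressed_size_sketch` is a hypothetical helper; the name `kCompressionRatio` is taken from the line that was removed, and the value 2 from the hardcoded literal.

```cpp
#include <cstdint>

using index_t = std::int32_t; // stand-in for ck_tile::index_t

// Named once, so supporting another ratio later touches a single line.
static constexpr index_t kCompressionRatio = 2;

// Hypothetical helper: size of the A vector after compression.
template <index_t VectorSize>
struct compressed_size_sketch
{
    static_assert(VectorSize % kCompressionRatio == 0,
                  "vector size must be divisible by the compression ratio");
    static constexpr index_t value = VectorSize / kCompressionRatio;
};

// Illustrative value: a 16-element A vector compresses to 8 elements.
static_assert(compressed_size_sketch<16>::value == 8, "16 / 2 == 8");
```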

// and evaluate changing this to a transform at a higher level.
// aVec not being const can cause problems when running multiple intrinsics.
- const int32_t idx = ck_tile::compress_a_impl<fp16_t, CompressedSize>(aVec);
+ const index_t idx = ck_tile::compress_a_impl<fp16_t, CompressedSize>(aVec);

Contributor

While it compiles, this is not correct since the function returns an int32_t and the builtin expects an int as a fourth parameter.

a_vec_pruned, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
}
};
// clang-format on

Contributor

Again, avoid prolonging the disabled clang-format section

return {__builtin_amdgcn_swmmac_f32_16x16x32_f16_w32(a_vec_pruned, bVec, cVec, idx)};
}
};
// clang-format on

Contributor

Avoid prolonging the disabled clang-format section

return {__builtin_amdgcn_wmma_f32_16x16x16_f16_w32(aVec, bVec, cVec)};
}
};
// clang-format on

Contributor

Avoid prolonging the disabled clang-format section

return {__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(aVec, bVec, cVec)};
}
};
// clang-format on

Contributor

Avoid prolonging the disabled clang-format section

Contributor

@wj-laskowski wj-laskowski left a comment

Looks great! My comments are mostly nits and some thinking out loud.

@wj-laskowski wj-laskowski self-requested a review March 11, 2026 11:00
* Note that K0 = 2 (first unmerge size, fastest changing), K1 = 3 (second unmerge size,
* second-fastest changing), and K2 = 12 / 2 / 3 = 2 (outermost dimension, whatever is left).
*
* If we were to use this unmerge op to decribe an A matrix layout in registers, we might have for

Contributor

describe
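
For readers following the quoted comment, the unmerge of a K dimension of size 12 into K0 = 2 (fastest changing), K1 = 3 and K2 = 2 (outermost) can be sketched as plain index arithmetic. This is an interpretation of the quoted text, not code from the MR; `k_coord_sketch`, `unmerge_k` and `merge_k` are hypothetical names.

```cpp
#include <cstdint>

using index_t = std::int32_t; // stand-in for ck_tile::index_t

constexpr index_t kTotalK = 12;
constexpr index_t kK0     = 2;                   // first unmerge size, fastest changing
constexpr index_t kK1     = 3;                   // second unmerge size
constexpr index_t kK2     = kTotalK / kK0 / kK1; // outermost dimension, whatever is left

static_assert(kK2 == 2, "12 / 2 / 3 == 2");

// Hypothetical coordinate triple, listed from outermost to fastest changing.
struct k_coord_sketch
{
    index_t k2;
    index_t k1;
    index_t k0;
};

// Split a linear K index into (k2, k1, k0).
constexpr k_coord_sketch unmerge_k(index_t k)
{
    return k_coord_sketch{k / (kK0 * kK1), (k / kK0) % kK1, k % kK0};
}

// Inverse mapping: rebuild the linear index from the coordinates.
constexpr index_t merge_k(k_coord_sketch c)
{
    return (c.k2 * kK1 + c.k1) * kK0 + c.k0;
}

// k = 7 lands at k2 = 1, k1 = 0, k0 = 1, and merging gives 7 back.
static_assert(unmerge_k(7).k2 == 1 && unmerge_k(7).k1 == 0 && unmerge_k(7).k0 == 1,
              "coordinates of k = 7");
static_assert(merge_k(unmerge_k(7)) == 7, "merge inverts unmerge");
```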

@krithalith krithalith marked this pull request as ready for review March 11, 2026 15:57
@krithalith krithalith force-pushed the users/krithalith/ck/unification_policy_struct_refactor branch 2 times, most recently from 0e26f69 to 78119c5 March 12, 2026 09:03
@krithalith krithalith changed the title from "[WMMA / MFMA unification] Refactor amdgcn_mma policy structs" to "[CK_Tile] Refactor amdgcn_mma policy structs" Mar 16, 2026
@krithalith krithalith force-pushed the users/krithalith/ck/unification_policy_struct_refactor branch from e5b031e to 2dbdd57 March 18, 2026 10:43
@krithalith krithalith requested a review from cgmillette March 18, 2026 10:53
// Test MmaDefaultSelector for supported DummyAmdgcnMma on fragment sizes other than 16x16x16
// This tests that the selector can still pick the correct MMA op even if the fragment sizes differ
TEST(TestAmdgcnMma, MmaDefaultSelectorSupportedFragment)
// Test MmaDefaultSelector for supported DummyAmdgcnMma on chunk sizes other than 16x16x16

Contributor

Some tests still have references to "chunk"

Contributor Author

Oops, my search range was too small, nice catch!

Contributor

@cgmillette cgmillette left a comment

Great PR! Excellent work.
I really enjoyed seeing the reduction of boilerplate code with the extraction of the base class.

@krithalith krithalith force-pushed the users/krithalith/ck/unification_policy_struct_refactor branch from 9f4f0a7 to aa5bcbc March 18, 2026 15:41