Skip to content

Commit b27b9d5

Browse files
authored
[AMD] GetThreadsPerWarpForOperand interface (#5675)
This PR implements GetThreadsPerWarpForOperand function for WMMA and MFMA layouts.
1 parent 0b2f486 commit b27b9d5

File tree

2 files changed

+80
-8
lines changed

2 files changed

+80
-8
lines changed

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 48 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1964,9 +1964,30 @@ AMDMfmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
19641964

19651965
SmallVector<unsigned>
19661966
AMDMfmaEncodingAttr::getThreadsPerWarpForOperand(int opIdx) const {
1967-
llvm::report_fatal_error(
1968-
"getThreadsPerWarpForOperand not implemented for AMDMfmaEncodingAttr");
1969-
return {};
1967+
auto rank = ::getOrder(*this).size();
1968+
SmallVector<unsigned> threads(rank, 1);
1969+
unsigned kThreads;
1970+
unsigned nonKThreads;
1971+
switch (getMDim()) {
1972+
case 32:
1973+
assert(getNDim() == 32);
1974+
kThreads = 2;
1975+
nonKThreads = 32;
1976+
break;
1977+
case 16:
1978+
assert(getNDim() == 16);
1979+
kThreads = 4;
1980+
nonKThreads = 16;
1981+
break;
1982+
default:
1983+
llvm::report_fatal_error(
1984+
"unexpected mfma shape encountered in getThreadsPerWarpForOperand");
1985+
}
1986+
int kDimIdx = opIdx == 0 ? rank - 1 : rank - 2;
1987+
int nonKDimIdx = opIdx == 0 ? rank - 2 : rank - 1;
1988+
threads[kDimIdx] = kThreads;
1989+
threads[nonKDimIdx] = nonKThreads;
1990+
return threads;
19701991
}
19711992

19721993
SmallVector<int64_t>
@@ -2032,9 +2053,30 @@ AMDWmmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
20322053

20332054
SmallVector<unsigned>
20342055
AMDWmmaEncodingAttr::getThreadsPerWarpForOperand(int opIdx) const {
2035-
llvm::report_fatal_error("getThreadsPerWarpForOperand not implemented for "
2036-
"AMDWmmaEncodingAttr");
2037-
return {};
2056+
auto rank = ::getOrder(*this).size();
2057+
SmallVector<unsigned> threads(rank, 1);
2058+
unsigned kThreads;
2059+
unsigned nonKThreads;
2060+
switch (getVersion()) {
2061+
case 1:
2062+
// kThreads * onKThreads != 32,
2063+
// because values in lanes (n, n + 16) duplicates
2064+
kThreads = 1;
2065+
nonKThreads = 16;
2066+
break;
2067+
case 2:
2068+
kThreads = 2;
2069+
nonKThreads = 16;
2070+
break;
2071+
default:
2072+
llvm::report_fatal_error(
2073+
"unsupported WMMA version in getThreadsPerWarpForOperand");
2074+
}
2075+
int kDimIdx = opIdx == 0 ? rank - 1 : rank - 2;
2076+
int nonKDimIdx = opIdx == 0 ? rank - 2 : rank - 1;
2077+
threads[kDimIdx] = kThreads;
2078+
threads[nonKDimIdx] = nonKThreads;
2079+
return threads;
20382080
}
20392081

20402082
SmallVector<unsigned> AMDWmmaEncodingAttr::getCTAsPerCGA() const {

unittest/Dialect/TritonGPU/DialectTest.cpp

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,8 @@ TEST_F(AMDMfmaLayoutTest, mfma_dot_op) {
368368
auto dot2dOp1 = createDotOperand(1, mfma2d, 4);
369369
ASSERT_THAT(dot2dOp0.getWarpOrder(), mfma2d.getWarpOrder());
370370
ASSERT_THAT(dot2dOp1.getWarpOrder(), mfma2d.getWarpOrder());
371+
ASSERT_THAT(dot2dOp0.getThreadsPerWarp(), testing::ElementsAre(32u, 2u));
372+
ASSERT_THAT(dot2dOp1.getThreadsPerWarp(), testing::ElementsAre(2u, 32u));
371373

372374
auto tmfma2d = createTransposedMFMA(32, 32, {2, 4});
373375
auto tdot2dOp0 = createDotOperand(0, tmfma2d, 4);
@@ -380,12 +382,28 @@ TEST_F(AMDMfmaLayoutTest, mfma_dot_op) {
380382
auto dot3dOp1 = createDotOperand(1, mfma3d, 4);
381383
ASSERT_THAT(dot3dOp0.getWarpOrder(), mfma3d.getWarpOrder());
382384
ASSERT_THAT(dot3dOp1.getWarpOrder(), mfma3d.getWarpOrder());
385+
ASSERT_THAT(dot3dOp0.getThreadsPerWarp(), testing::ElementsAre(1u, 32u, 2u));
386+
ASSERT_THAT(dot3dOp1.getThreadsPerWarp(), testing::ElementsAre(1u, 2u, 32u));
383387

384388
auto tmfma3d = createTransposedMFMA(32, 32, {2, 4, 1});
385389
auto tdot3dOp0 = createDotOperand(0, tmfma3d, 4);
386390
auto tdot3dOp1 = createDotOperand(1, tmfma3d, 4);
387391
ASSERT_THAT(tdot3dOp0.getWarpOrder(), tmfma3d.getWarpOrder());
388392
ASSERT_THAT(tdot3dOp1.getWarpOrder(), tmfma3d.getWarpOrder());
393+
394+
auto mfma16_2d = createMFMA(16, 16, {2, 4});
395+
auto dot16_2dOp0 = createDotOperand(0, mfma16_2d, 4);
396+
auto dot16_2dOp1 = createDotOperand(1, mfma16_2d, 4);
397+
ASSERT_THAT(dot16_2dOp0.getThreadsPerWarp(), testing::ElementsAre(16u, 4u));
398+
ASSERT_THAT(dot16_2dOp1.getThreadsPerWarp(), testing::ElementsAre(4u, 16u));
399+
400+
auto mfma16_3d = createMFMA(16, 16, {2, 4, 1});
401+
auto dot16_3dOp0 = createDotOperand(0, mfma16_3d, 4);
402+
auto dot16_3dOp1 = createDotOperand(1, mfma16_3d, 4);
403+
ASSERT_THAT(dot16_3dOp0.getThreadsPerWarp(),
404+
testing::ElementsAre(1u, 16u, 4u));
405+
ASSERT_THAT(dot16_3dOp1.getThreadsPerWarp(),
406+
testing::ElementsAre(1u, 4u, 16u));
389407
}
390408

391409
TEST_F(AMDWmmaLayoutTest, wmmaV1) {
@@ -434,24 +452,36 @@ TEST_F(AMDWmmaLayoutTest, wmma_dot_op) {
434452
auto dot2dVer1Op1 = createDotOperand(1, wmma2dVer1, 16);
435453
ASSERT_THAT(dot2dVer1Op0.getWarpOrder(), wmma2dVer1.getWarpOrder());
436454
ASSERT_THAT(dot2dVer1Op1.getWarpOrder(), wmma2dVer1.getWarpOrder());
455+
ASSERT_THAT(dot2dVer1Op0.getThreadsPerWarp(), testing::ElementsAre(16u, 1u));
456+
ASSERT_THAT(dot2dVer1Op1.getThreadsPerWarp(), testing::ElementsAre(1u, 16u));
437457

438-
auto wmma3dVer1 = createWMMAv1({2, 4});
458+
auto wmma3dVer1 = createWMMAv1({2, 4, 1});
439459
auto dot3dVer1Op0 = createDotOperand(0, wmma3dVer1, 16);
440460
auto dot3dVer1Op1 = createDotOperand(1, wmma3dVer1, 16);
441461
ASSERT_THAT(dot3dVer1Op0.getWarpOrder(), wmma3dVer1.getWarpOrder());
442462
ASSERT_THAT(dot3dVer1Op1.getWarpOrder(), wmma3dVer1.getWarpOrder());
463+
ASSERT_THAT(dot3dVer1Op0.getThreadsPerWarp(),
464+
testing::ElementsAre(1, 16u, 1u));
465+
ASSERT_THAT(dot3dVer1Op1.getThreadsPerWarp(),
466+
testing::ElementsAre(1, 1u, 16u));
443467

444468
auto wmma2dVer2 = createWMMAv2(false, {2, 4});
445469
auto dot2dVer2Op0 = createDotOperand(0, wmma2dVer2, 16);
446470
auto dot2dVer2Op1 = createDotOperand(1, wmma2dVer2, 16);
447471
ASSERT_THAT(dot2dVer2Op0.getWarpOrder(), wmma2dVer2.getWarpOrder());
448472
ASSERT_THAT(dot2dVer2Op1.getWarpOrder(), wmma2dVer2.getWarpOrder());
473+
ASSERT_THAT(dot2dVer2Op0.getThreadsPerWarp(), testing::ElementsAre(16u, 2u));
474+
ASSERT_THAT(dot2dVer2Op1.getThreadsPerWarp(), testing::ElementsAre(2u, 16u));
449475

450-
auto wmma3dVer2 = createWMMAv2(false, {2, 4});
476+
auto wmma3dVer2 = createWMMAv2(false, {2, 4, 1});
451477
auto dot3dVer2Op0 = createDotOperand(0, wmma3dVer2, 16);
452478
auto dot3dVer2Op1 = createDotOperand(1, wmma3dVer2, 16);
453479
ASSERT_THAT(dot3dVer2Op0.getWarpOrder(), wmma3dVer2.getWarpOrder());
454480
ASSERT_THAT(dot3dVer2Op1.getWarpOrder(), wmma3dVer2.getWarpOrder());
481+
ASSERT_THAT(dot3dVer2Op0.getThreadsPerWarp(),
482+
testing::ElementsAre(1, 16u, 2u));
483+
ASSERT_THAT(dot3dVer2Op1.getThreadsPerWarp(),
484+
testing::ElementsAre(1, 2u, 16u));
455485
}
456486

457487
class LinearEncodingTest : public ::testing::Test {

0 commit comments

Comments
 (0)