From c9a40b26d4efeb1f35958169a465839d60f902d2 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Wed, 23 Oct 2024 08:15:42 -0700 Subject: [PATCH 1/7] [Build] Remove unnecessary `NVGPUIR` from `TritonGPUToLLVM` (#4977) Note, there are no uses of `nvgpu::` in this lib. Unblocks building `*-opt` tools with "custom" LLVM that was built with `-DLLVM_TARGETS_TO_BUILD="host;AMDGPU"` (i.e., no `NVPTX`). --- lib/Conversion/TritonGPUToLLVM/CMakeLists.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt index cca2830b04..4d57131d02 100644 --- a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt +++ b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt @@ -34,5 +34,4 @@ add_triton_library(TritonGPUToLLVM TritonGPUIR TritonGPUTransforms TritonNvidiaGPUTransforms - NVGPUIR ) From a20ce645ffcb4992fb20b083853e5051f6b0f9e3 Mon Sep 17 00:00:00 2001 From: Alexander Efimov Date: Wed, 23 Oct 2024 17:43:37 +0200 Subject: [PATCH 2/7] [AMD] Add MFMA dot operand to LinearLayout conversion (#4961) This PR implements general conversion of MFMA dot operand to Linear Layout. --- lib/Dialect/TritonGPU/IR/Dialect.cpp | 5 + .../TritonGPU/IR/LinearLayoutConversions.cpp | 85 ++-- unittest/Dialect/TritonGPU/DialectTest.cpp | 32 ++ .../TritonGPU/LinearLayoutConversionsTest.cpp | 382 +++++++++++++++++- 4 files changed, 464 insertions(+), 40 deletions(-) diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 70eaf5d3b6..71506ecbb9 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -256,6 +256,11 @@ SmallVector getOrderForDotOperand(unsigned opIdx, unsigned rank, } SmallVector getWarpOrder(Attribute layout) { + if (auto dotLayout = dyn_cast(layout)) { + if (isa(dotLayout.getParent())) { + return getWarpOrder(dotLayout.getParent()); + } + } auto order = getOrder(layout); // FIXME: This mmaLayout if should just return // getOrderForDotOperand(0, order.size(), kMajor=false) diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index 039de22cb4..b5fe1463af 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -472,10 +472,6 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef shape) const { int nIndex = 1 + hasBatchDim; (void)mIndex, (void)nIndex; - assert(((shape[mIndex] == 1 || shape[mIndex] >= getMDim()) && - (shape[nIndex] == 1 || shape[nIndex] >= getNDim())) && - "Unsupported tensor shape for given mfma layout"); - assert(((getMDim() == 32 && getNDim() == 32) || (getMDim() == 16 && getNDim() == 16)) && "Unsupported mfma type"); @@ -579,55 +575,76 @@ dotOperandMfmaToLinearLayout(DotOperandEncodingAttr dotMfmaLayout, // 4) warpsPerCTA[mDim] == 1: This guarantees that every B tensor element is // held by exactly one thread, maintaining the same number of global loads // as in a blocked layout. + // + // Other use of Linear layout is a support of rare corner cases, + // for example one instruction tile is larger than tensor auto mfmaLayout = llvm::cast(dotMfmaLayout.getParent()); - if (dotMfmaLayout.getOpIdx() == 0) { - return std::nullopt; - } auto rank = shape.size(); bool hasBatchDim = rank == 3; int mIndex = 0 + hasBatchDim; - auto kWidth = dotMfmaLayout.getKWidth(); + int32_t kWidth = dotMfmaLayout.getKWidth(); + auto kDim = dotMfmaLayout.getOpIdx() == 0 ? 
rank - 1 : rank - 2; + int32_t kSize = shape[kDim]; auto warpsPerCTA = mfmaLayout.getWarpsPerCTA(); - if (kWidth != 8 || warpsPerCTA[mIndex] != 1) { - return std::nullopt; - } - MLIRContext *ctx = dotMfmaLayout.getContext(); SmallVector outDimNames = standardOutDimNames(ctx, rank); StringAttr kRegister = S("register"); StringAttr kLane = S("lane"); + StringAttr kWarp = S("warp"); + // register order + // operand A: [1, 0] / [2, 1, 0] + // operand B: [0, 1] / [1, 2, 0] + // for both cases it is [k, nonk]/[k, nonk, batch] SmallVector order = triton::gpu::getOrder(dotMfmaLayout); - auto tileLayout = LinearLayout::empty(); + // warp order + // common for both operand A and B: [0, 1] / [0, 1, 2] + // in both cases it is [M dim, N dim]/[batch, M dim, N dim] + SmallVector warpOrder = triton::gpu::getWarpOrder(dotMfmaLayout); + + // Lane holds kWidth consecutive elements along k dimension, so + // base register vectors for one tile are initialized in following way: + // {1, 0}, {2, 0} ... {kWidth/2, 0} + std::vector> registerBase; + for (int32_t elem = 1; elem < kWidth; elem *= 2) + registerBase.emplace_back(std::vector{elem, 0}); + + std::vector> laneBase; + int32_t kTileSize = -1; if (mfmaLayout.getMDim() == 32) { - // Based on canonical MFMA linear layout, which handles 4 consecutive - // elements along the register dimension, kWidth=8 means we have 8 - // consecutive elements, so we have an additional {4, 0} base vector here. - // For lane dim, since the MFMA thread arrangement is {K, N} = {2, 32}, this - // means that mapping of first 5 base (up to thread 16) vectors will be an - // identity along N dim. Thread 32 will be mapped to element 8 in K - // dimension, because kWidth == 8. - tileLayout = LinearLayout( - {{kRegister, {{1, 0}, {2, 0}, {4, 0}}}, - {kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {8, 0}}}}, - {outDimNames[order[0]], outDimNames[order[1]]}); + // Canonical MFMA linear layout handles 4 consecutive elements along + // the register dimension. Dot operand handles varaible kWidth consecutive + // elements. For lane dim, since the MFMA thread arrangement is {K, N} = {2, + // 32}, this means that mapping of first 5 base (up to thread 16) vectors + // will be an identity along N dim. Thread 32 will be mapped to element + // kWidth in K dimension. + laneBase = {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {kWidth, 0}}; + kTileSize = kWidth * 2; } else { assert(mfmaLayout.getMDim() == 16); // For lane dim, since the MFMA thread arrangement is {K, N} = {4, 16}, this // means that mapping of first 4 base (up to thread 16) vectors will be an - // identity along N dim. Thread 16 will be mapped to element 8 in K - // dimension, because kWidth == 8. Thread 32 is mapped to element 16 as that - // is 2*kWidth in K dim. - tileLayout = LinearLayout( - {{kRegister, {{1, 0}, {2, 0}, {4, 0}}}, - {kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {8, 0}, {16, 0}}}}, - {outDimNames[order[0]], outDimNames[order[1]]}); + // identity along N dim. Thread 16 will be mapped to element kWisth in K + // dimension. Thread 32 is mapped to element 2*kWidth in K dim. + laneBase = {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {kWidth, 0}, {kWidth * 2, 0}}; + kTileSize = kWidth * 4; } + assert(kTileSize != -1); + // Add repeats of registers along K dimension to register base vectors + for (int32_t elem = kTileSize; elem < kSize; elem *= 2) + registerBase.emplace_back(std::vector{elem, 0}); + + // Base vectors above are defined in a fixed order [non-k-dim, k-dim]. + // To assign them to actual matrix dimensions `order` array is used. 
+ // For operand A: non-k-dim -> dim0, k-dim -> dim1 + // For operand B: non-k-dim -> dim1, k-dim -> dim0 + LinearLayout tileLayout({{kRegister, registerBase}, {kLane, laneBase}}, + {outDimNames[order[0]], outDimNames[order[1]]}); if (hasBatchDim) { assert(order[2] == 0); @@ -638,8 +655,10 @@ dotOperandMfmaToLinearLayout(DotOperandEncodingAttr dotMfmaLayout, } LinearLayout warpLayout = - identityND(S("warp"), warpsPerCTA, order, outDimNames); - LinearLayout ctaLayout = tileLayout * warpLayout; + identityND(kWarp, warpsPerCTA, warpOrder, outDimNames); + + LinearLayout ctaLayout = tileLayout.transposeOuts(outDimNames) * + warpLayout.transposeOuts(outDimNames); return combineCtaCgaWithShape(ctaLayout, mfmaLayout.getCTALayout(), shape); } diff --git a/unittest/Dialect/TritonGPU/DialectTest.cpp b/unittest/Dialect/TritonGPU/DialectTest.cpp index e3f521f1b3..c27c63335e 100644 --- a/unittest/Dialect/TritonGPU/DialectTest.cpp +++ b/unittest/Dialect/TritonGPU/DialectTest.cpp @@ -543,6 +543,12 @@ class AMDMfmaLayoutTest : public ::testing::Test { /*isTransposed=*/true, ctaLayout); } + triton::gpu::DotOperandEncodingAttr + createDotOperand(int idx, triton::gpu::AMDMfmaEncodingAttr parent, + int kWidth) { + return triton::gpu::DotOperandEncodingAttr::get(&ctx, idx, parent, kWidth); + } + protected: MLIRContext ctx; const SmallVector ctaPerCGA{1, 1, 1}; @@ -588,6 +594,32 @@ TEST_F(AMDMfmaLayoutTest, mfma16) { ASSERT_THAT(tmfma3d.getWarpOrder(), testing::ElementsAre(2u, 1u, 0u)); } +TEST_F(AMDMfmaLayoutTest, mfma_dot_op) { + auto mfma2d = createMFMA(32, 32, {2, 4}); + auto dot2dOp0 = createDotOperand(0, mfma2d, 4); + auto dot2dOp1 = createDotOperand(1, mfma2d, 4); + ASSERT_THAT(dot2dOp0.getWarpOrder(), mfma2d.getWarpOrder()); + ASSERT_THAT(dot2dOp1.getWarpOrder(), mfma2d.getWarpOrder()); + + auto tmfma2d = createTransposedMFMA(32, 32, {2, 4}); + auto tdot2dOp0 = createDotOperand(0, tmfma2d, 4); + auto tdot2dOp1 = createDotOperand(1, tmfma2d, 4); + ASSERT_THAT(tdot2dOp0.getWarpOrder(), tmfma2d.getWarpOrder()); + ASSERT_THAT(tdot2dOp1.getWarpOrder(), tmfma2d.getWarpOrder()); + + auto mfma3d = createMFMA(32, 32, {2, 4, 1}); + auto dot3dOp0 = createDotOperand(0, mfma3d, 4); + auto dot3dOp1 = createDotOperand(1, mfma3d, 4); + ASSERT_THAT(dot3dOp0.getWarpOrder(), mfma3d.getWarpOrder()); + ASSERT_THAT(dot3dOp1.getWarpOrder(), mfma3d.getWarpOrder()); + + auto tmfma3d = createTransposedMFMA(32, 32, {2, 4, 1}); + auto tdot3dOp0 = createDotOperand(0, tmfma3d, 4); + auto tdot3dOp1 = createDotOperand(1, tmfma3d, 4); + ASSERT_THAT(tdot3dOp0.getWarpOrder(), tmfma3d.getWarpOrder()); + ASSERT_THAT(tdot3dOp1.getWarpOrder(), tmfma3d.getWarpOrder()); +} + } // anonymous namespace } // namespace mlir::triton::gpu diff --git a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp index 76c9c44225..fd65233e5c 100644 --- a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp +++ b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp @@ -58,8 +58,8 @@ class LinearLayoutConversionsTest : public ::testing::Test { isTransposed, CTALayoutAttr::get(&ctx, cpg, cSplit, cOrd)); } - DotOperandEncodingAttr mfmaDot(AMDMfmaEncodingAttr mfma, unsigned opIdx, - unsigned kWidth) { + DotOperandEncodingAttr mfmaDotOp(AMDMfmaEncodingAttr mfma, unsigned opIdx, + unsigned kWidth) { return DotOperandEncodingAttr::get(&ctx, opIdx, mfma, kWidth); } @@ -738,10 +738,82 @@ TEST_F(LinearLayoutConversionsTest, MFMA32_2x4x1Warps) { {S("dim0"), S("dim1"), S("dim2")})); } 
+TEST_F(LinearLayoutConversionsTest, warp1onK_mfma32_lhs_kwidth8) { + auto parentMfma_1_8 = mfma(/*warps=*/{1, 8}, /*mDim=*/32, /*nDim=*/32, + /*isTransposed=*/false); + auto mfmaDot_1_8 = mfmaDotOp(parentMfma_1_8, /*opIdx=*/0, /*kWidth=*/8); + EXPECT_EQ(toLinearLayout({128, 128}, mfmaDot_1_8), + LinearLayout( + {{S("register"), + {{0, 1}, + {0, 2}, + {0, 4}, + {0, 16}, + {0, 32}, + {0, 64}, + {32, 0}, + {64, 0}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 8}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + EXPECT_EQ(toLinearLayout({128, 256}, mfmaDot_1_8), + LinearLayout( + {{S("register"), + {{0, 1}, + {0, 2}, + {0, 4}, + {0, 16}, + {0, 32}, + {0, 64}, + {0, 128}, + {32, 0}, + {64, 0}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 8}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + EXPECT_EQ(toLinearLayout({32, 64}, mfmaDot_1_8), + LinearLayout( + {{S("register"), {{0, 1}, {0, 2}, {0, 4}, {0, 16}, {0, 32}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 8}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + EXPECT_EQ(toLinearLayout({256, 256}, mfmaDot_1_8), + LinearLayout( + {{S("register"), + {{0, 1}, + {0, 2}, + {0, 4}, + {0, 16}, + {0, 32}, + {0, 64}, + {0, 128}, + {32, 0}, + {64, 0}, + {128, 0}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 8}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + EXPECT_EQ(toLinearLayout({16, 16}, mfmaDot_1_8), + LinearLayout( + {{S("register"), {{0, 1}, {0, 2}, {0, 4}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 0}, {0, 8}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + TEST_F(LinearLayoutConversionsTest, warp1onK_mfma32_rhs_kwidth8) { auto parentMfma_1_8 = mfma(/*warps=*/{1, 8}, /*mDim=*/32, /*nDim=*/32, /*isTransposed=*/false); - auto mfmaDot_1_8 = mfmaDot(parentMfma_1_8, /*opIdx=*/1, /*kWidth=*/8); + auto mfmaDot_1_8 = mfmaDotOp(parentMfma_1_8, /*opIdx=*/1, /*kWidth=*/8); EXPECT_EQ( toLinearLayout({128, 128}, mfmaDot_1_8), LinearLayout( @@ -778,9 +850,17 @@ TEST_F(LinearLayoutConversionsTest, warp1onK_mfma32_rhs_kwidth8) { {S("block"), {}}}, {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({16, 16}, mfmaDot_1_8), + LinearLayout( + {{S("register"), {{1, 0}, {2, 0}, {4, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 0}, {8, 0}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + auto parentMfma_1_4 = mfma(/*warps=*/{1, 4}, /*mDim=*/32, /*nDim=*/32, /*isTransposed=*/false); - auto mfmaDot_1_4 = mfmaDot(parentMfma_1_4, /*opIdx=*/1, /*kWidth=*/8); + auto mfmaDot_1_4 = mfmaDotOp(parentMfma_1_4, /*opIdx=*/1, /*kWidth=*/8); EXPECT_EQ(toLinearLayout({256, 256}, mfmaDot_1_4), LinearLayout( {{S("register"), @@ -798,10 +878,129 @@ TEST_F(LinearLayoutConversionsTest, warp1onK_mfma32_rhs_kwidth8) { {S("dim0"), S("dim1")})); } +TEST_F(LinearLayoutConversionsTest, warp1onK_mfma16_lhs_kwidth8) { + auto parentMfma_1_4 = mfma(/*warps=*/{1, 4}, /*mDim=*/16, /*nDim=*/16, + /*isTransposed=*/false); + auto mfmaDot_1_4 = mfmaDotOp(parentMfma_1_4, /*opIdx=*/0, /*kWidth=*/8); + EXPECT_EQ(toLinearLayout({128, 128}, mfmaDot_1_4), + LinearLayout( + {{S("register"), + {{0, 1}, + {0, 2}, + {0, 4}, + {0, 32}, + {0, 64}, + {16, 0}, + {32, 0}, + {64, 0}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 8}, {0, 16}}}, + 
{S("warp"), {{0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + EXPECT_EQ(toLinearLayout({1, 128}, mfmaDot_1_4), + LinearLayout( + {{S("register"), + { + {0, 1}, + {0, 2}, + {0, 4}, + {0, 32}, + {0, 64}, + }}, + {S("lane"), {{0, 0}, {0, 0}, {0, 0}, {0, 0}, {0, 8}, {0, 16}}}, + {S("warp"), {{0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + EXPECT_EQ( + toLinearLayout({128, 1}, mfmaDot_1_4), + LinearLayout( + {{S("register"), {{0, 0}, {0, 0}, {0, 0}, {16, 0}, {32, 0}, {64, 0}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 0}, {0, 0}}}, + {S("warp"), {{0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + EXPECT_EQ(toLinearLayout({256, 256}, mfmaDot_1_4), + LinearLayout( + {{S("register"), + {{0, 1}, + {0, 2}, + {0, 4}, + {0, 32}, + {0, 64}, + {0, 128}, + {16, 0}, + {32, 0}, + {64, 0}, + {128, 0}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 8}, {0, 16}}}, + {S("warp"), {{0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + EXPECT_EQ(toLinearLayout({16, 16}, mfmaDot_1_4), + LinearLayout( + {{S("register"), {{0, 1}, {0, 2}, {0, 4}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 8}, {0, 0}}}, + {S("warp"), {{0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + auto parentMfma_1_8 = mfma(/*warps=*/{1, 8}, /*mDim=*/16, /*nDim=*/16, + /*isTransposed=*/false); + auto mfmaDot_1_8 = mfmaDotOp(parentMfma_1_8, /*opIdx=*/0, /*kWidth=*/8); + EXPECT_EQ(toLinearLayout({256, 256}, mfmaDot_1_8), + LinearLayout( + {{S("register"), + {{0, 1}, + {0, 2}, + {0, 4}, + {0, 32}, + {0, 64}, + {0, 128}, + {16, 0}, + {32, 0}, + {64, 0}, + {128, 0}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 8}, {0, 16}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + auto parentMfma_1_8_1 = mfma(/*warps=*/{1, 1, 8}, /*mDim=*/16, /*nDim=*/16, + /*isTransposed=*/false); + auto mfmaDot_1_8_1 = mfmaDotOp(parentMfma_1_8_1, /*opIdx=*/0, /*kWidth=*/8); + + EXPECT_EQ(toLinearLayout({1, 256, 256}, mfmaDot_1_8_1), + LinearLayout({{S("register"), + {{0, 0, 1}, + {0, 0, 2}, + {0, 0, 4}, + {0, 0, 32}, + {0, 0, 64}, + {0, 0, 128}, + {0, 16, 0}, + {0, 32, 0}, + {0, 64, 0}, + {0, 128, 0}}}, + {S("lane"), + {{0, 1, 0}, + {0, 2, 0}, + {0, 4, 0}, + {0, 8, 0}, + {0, 0, 8}, + {0, 0, 16}}}, + {S("warp"), {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2")})); +} + TEST_F(LinearLayoutConversionsTest, warp1onK_mfma16_rhs_kwidth8) { auto parentMfma_1_4 = mfma(/*warps=*/{1, 4}, /*mDim=*/16, /*nDim=*/16, /*isTransposed=*/false); - auto mfmaDot_1_4 = mfmaDot(parentMfma_1_4, /*opIdx=*/1, /*kWidth=*/8); + auto mfmaDot_1_4 = mfmaDotOp(parentMfma_1_4, /*opIdx=*/1, /*kWidth=*/8); EXPECT_EQ( toLinearLayout({128, 128}, mfmaDot_1_4), LinearLayout( @@ -843,9 +1042,17 @@ TEST_F(LinearLayoutConversionsTest, warp1onK_mfma16_rhs_kwidth8) { {S("block"), {}}}, {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({16, 16}, mfmaDot_1_4), + LinearLayout( + {{S("register"), {{1, 0}, {2, 0}, {4, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {8, 0}, {0, 0}}}, + {S("warp"), {{0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + auto parentMfma_1_8 = mfma(/*warps=*/{1, 8}, /*mDim=*/16, /*nDim=*/16, /*isTransposed=*/false); - auto mfmaDot_1_8 = mfmaDot(parentMfma_1_8, /*opIdx=*/1, /*kWidth=*/8); + auto mfmaDot_1_8 = mfmaDotOp(parentMfma_1_8, /*opIdx=*/1, /*kWidth=*/8); EXPECT_EQ( toLinearLayout({256, 256}, mfmaDot_1_8), LinearLayout( @@ -858,7 +1065,7 @@ 
TEST_F(LinearLayoutConversionsTest, warp1onK_mfma16_rhs_kwidth8) { auto parentMfma_1_8_1 = mfma(/*warps=*/{1, 1, 8}, /*mDim=*/16, /*nDim=*/16, /*isTransposed=*/false); - auto mfmaDot_1_8_1 = mfmaDot(parentMfma_1_8_1, /*opIdx=*/1, /*kWidth=*/8); + auto mfmaDot_1_8_1 = mfmaDotOp(parentMfma_1_8_1, /*opIdx=*/1, /*kWidth=*/8); EXPECT_EQ(toLinearLayout({1, 256, 256}, mfmaDot_1_8_1), LinearLayout({{S("register"), @@ -881,6 +1088,167 @@ TEST_F(LinearLayoutConversionsTest, warp1onK_mfma16_rhs_kwidth8) { {S("dim0"), S("dim1"), S("dim2")})); } +TEST_F(LinearLayoutConversionsTest, mfma32_dot_op_lhs_kwidth4) { + auto parentMfma32 = mfma(/*warps=*/{2, 4}, /*mDim=*/32, /*nDim=*/32, + /*isTransposed=*/false); + auto mfmaDotOp0_32 = mfmaDotOp(parentMfma32, /*opIdx=*/0, /*kWidth=*/4); + EXPECT_EQ(toLinearLayout({128, 128}, mfmaDotOp0_32), + LinearLayout( + {{S("register"), + {{0, 1}, {0, 2}, {0, 8}, {0, 16}, {0, 32}, {0, 64}, {64, 0}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 4}}}, + {S("warp"), {{0, 0}, {0, 0}, {32, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({64, 32}, mfmaDotOp0_32), + LinearLayout( + {{S("register"), {{0, 1}, {0, 2}, {0, 8}, {0, 16}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 4}}}, + {S("warp"), {{0, 0}, {0, 0}, {32, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({16, 16}, mfmaDotOp0_32), + LinearLayout( + {{S("register"), {{0, 1}, {0, 2}, {0, 8}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 0}, {0, 4}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + // Dot operand based on transposed mfma layout has same layout as ordinary + auto parentTMfma32 = mfma(/*warps=*/{2, 4}, /*mDim=*/32, /*nDim=*/32, + /*isTransposed=*/true); + auto tmfmaDotOp0_32 = mfmaDotOp(parentTMfma32, /*opIdx=*/0, /*kWidth=*/4); + + EXPECT_EQ(toLinearLayout({128, 128}, tmfmaDotOp0_32), + toLinearLayout({128, 128}, mfmaDotOp0_32)); + EXPECT_EQ(toLinearLayout({64, 32}, tmfmaDotOp0_32), + toLinearLayout({64, 32}, mfmaDotOp0_32)); + EXPECT_EQ(toLinearLayout({16, 16}, tmfmaDotOp0_32), + toLinearLayout({16, 16}, mfmaDotOp0_32)); +} + +TEST_F(LinearLayoutConversionsTest, mfma16_dot_op_lhs_kwidth4) { + auto parentMfma16 = mfma(/*warps=*/{2, 4}, /*mDim=*/16, /*nDim=*/16, + /*isTransposed=*/false); + auto mfmaDotOp0_16 = mfmaDotOp(parentMfma16, /*opIdx=*/0, /*kWidth=*/4); + EXPECT_EQ( + toLinearLayout({128, 128}, mfmaDotOp0_16), + LinearLayout( + {{S("register"), + {{0, 1}, {0, 2}, {0, 16}, {0, 32}, {0, 64}, {32, 0}, {64, 0}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 4}, {0, 8}}}, + {S("warp"), {{0, 0}, {0, 0}, {16, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({64, 32}, mfmaDotOp0_16), + LinearLayout( + {{S("register"), {{0, 1}, {0, 2}, {0, 16}, {32, 0}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 4}, {0, 8}}}, + {S("warp"), {{0, 0}, {0, 0}, {16, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({16, 16}, mfmaDotOp0_16), + LinearLayout( + {{S("register"), {{0, 1}, {0, 2}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 4}, {0, 8}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + // Dot operand based on transposed mfma layout has same layout as ordinary + auto parentTMfma16 = mfma(/*warps=*/{2, 4}, /*mDim=*/16, /*nDim=*/16, + /*isTransposed=*/true); + auto tmfmaDotOp0_16 = mfmaDotOp(parentTMfma16, /*opIdx=*/0, /*kWidth=*/4); + + 
EXPECT_EQ(toLinearLayout({128, 128}, tmfmaDotOp0_16), + toLinearLayout({128, 128}, mfmaDotOp0_16)); + EXPECT_EQ(toLinearLayout({64, 32}, tmfmaDotOp0_16), + toLinearLayout({64, 32}, mfmaDotOp0_16)); + EXPECT_EQ(toLinearLayout({16, 16}, tmfmaDotOp0_16), + toLinearLayout({16, 16}, mfmaDotOp0_16)); +} + +TEST_F(LinearLayoutConversionsTest, mfma32_dot_op_rhs_kwidth4) { + auto parentMfma32 = mfma(/*warps=*/{2, 4}, /*mDim=*/32, /*nDim=*/32, + /*isTransposed=*/false); + auto mfmaDotOp1_32 = mfmaDotOp(parentMfma32, /*opIdx=*/1, /*kWidth=*/4); + EXPECT_EQ( + toLinearLayout({128, 128}, mfmaDotOp1_32), + LinearLayout( + {{S("register"), {{1, 0}, {2, 0}, {8, 0}, {16, 0}, {32, 0}, {64, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {4, 0}}}, + {S("warp"), {{0, 32}, {0, 64}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({32, 64}, mfmaDotOp1_32), + LinearLayout( + {{S("register"), {{1, 0}, {2, 0}, {8, 0}, {16, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {4, 0}}}, + {S("warp"), {{0, 32}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({16, 16}, mfmaDotOp1_32), + LinearLayout( + {{S("register"), {{1, 0}, {2, 0}, {8, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 0}, {4, 0}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + // Dot operand based on transposed mfma layout has same layout as ordinary + auto parentTMfma32 = mfma(/*warps=*/{2, 4}, /*mDim=*/32, /*nDim=*/32, + /*isTransposed=*/true); + auto tmfmaDotOp1_32 = mfmaDotOp(parentTMfma32, /*opIdx=*/1, /*kWidth=*/4); + + EXPECT_EQ(toLinearLayout({128, 128}, tmfmaDotOp1_32), + toLinearLayout({128, 128}, mfmaDotOp1_32)); + EXPECT_EQ(toLinearLayout({64, 32}, tmfmaDotOp1_32), + toLinearLayout({64, 32}, mfmaDotOp1_32)); + EXPECT_EQ(toLinearLayout({16, 16}, tmfmaDotOp1_32), + toLinearLayout({16, 16}, mfmaDotOp1_32)); +} + +TEST_F(LinearLayoutConversionsTest, mfma16_dot_op_rhs_kwidth4) { + auto parentMfma16 = mfma(/*warps=*/{2, 4}, /*mDim=*/16, /*nDim=*/16, + /*isTransposed=*/false); + auto mfmaDotOp1_16 = mfmaDotOp(parentMfma16, /*opIdx=*/1, /*kWidth=*/4); + EXPECT_EQ(toLinearLayout({128, 128}, mfmaDotOp1_16), + LinearLayout( + {{S("register"), + {{1, 0}, {2, 0}, {16, 0}, {32, 0}, {64, 0}, {0, 64}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {4, 0}, {8, 0}}}, + {S("warp"), {{0, 16}, {0, 32}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({32, 64}, mfmaDotOp1_16), + LinearLayout( + {{S("register"), {{1, 0}, {2, 0}, {16, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {4, 0}, {8, 0}}}, + {S("warp"), {{0, 16}, {0, 32}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({16, 16}, mfmaDotOp1_16), + LinearLayout( + {{S("register"), {{1, 0}, {2, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {4, 0}, {8, 0}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + // Dot operand based on transposed mfma layout has same layout as ordinary + auto parentTMfma16 = mfma(/*warps=*/{2, 4}, /*mDim=*/16, /*nDim=*/16, + /*isTransposed=*/true); + auto tmfmaDotOp1_16 = mfmaDotOp(parentTMfma16, /*opIdx=*/1, /*kWidth=*/4); + + EXPECT_EQ(toLinearLayout({128, 128}, tmfmaDotOp1_16), + toLinearLayout({128, 128}, mfmaDotOp1_16)); + EXPECT_EQ(toLinearLayout({64, 32}, tmfmaDotOp1_16), + toLinearLayout({64, 32}, mfmaDotOp1_16)); + EXPECT_EQ(toLinearLayout({16, 16}, tmfmaDotOp1_16), + 
toLinearLayout({16, 16}, mfmaDotOp1_16)); +} + TEST_F(LinearLayoutConversionsTest, WMMA_2x4Warps) { auto legacy = wmma(/*warps=*/{2, 4}); From a1aa58b441748e3beff21471a1e68f12a36bde68 Mon Sep 17 00:00:00 2001 From: David Berard Date: Wed, 23 Oct 2024 10:00:00 -0700 Subject: [PATCH 3/7] [BACKEND] Use vectorized atomics on Hopper (#4971) Hopper supports vectorized atomics for add, max, and min. This PR adds support for generating these instructions. Note: atomic add/min/max also have packed instructions for f16x2 and bf16x2. Packed instructions were used prior to this PR, but vectorized instructions weren't. When vectorized instructions are available, this PR switches to using vectorized instructions (like .v2.f16 instead of .f16x2, or .v8.f16 instead of .v4.f16x2). When vectorized instructions aren't available, packed instructions will be used instead. This PR also adds a check for mask alignment, which wasn't previously checked. --- test/Conversion/tritongpu_to_llvm.mlir | 34 ++++- test/Conversion/tritongpu_to_llvm_hopper.mlir | 38 +++++ .../LoadStoreOpToLLVM.cpp | 144 ++++++++++++++---- 3 files changed, 184 insertions(+), 32 deletions(-) diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index e2f43f4ba6..e1a2ec68bd 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -1034,7 +1034,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.target" = "cuda:80"} { // CHECK-LABEL: atomic_add_f32 tt.func @atomic_add_f32(%arg0 : tensor<256x!tt.ptr, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) { // CHECK: llvm.inline_asm @@ -1048,7 +1048,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.target" = "cuda:80"} { // CHECK-LABEL: atomic_add_f32_scalar tt.func @atomic_add_f32_scalar(%arg0 : !tt.ptr, %arg1 : i1, %arg2 : f32) { // CHECK: llvm.icmp "eq" @@ -1062,7 +1062,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.target" = "cuda:80"} { // CHECK-LABEL: atomic_add_f32 tt.func @atomic_add_f32_sys_scope(%arg0 : tensor<256x!tt.ptr, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) { // CHECK: llvm.inline_asm @@ -1076,6 +1076,34 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- +#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, 
"triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @atomic_add_f16_nomask(%dest_ptrs: tensor<256x!tt.ptr, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf16, #blocked>) attributes {noinline = false} { + // CHECK-LABEL: atomic_add_f16_nomask + // CHECK: atom.global.gpu.acq_rel.add.noftz.f16x2 + // CHECK: atom.global.gpu.acq_rel.add.noftz.f16x2 + %0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data : (tensor<256x!tt.ptr, #blocked>, tensor<256xf16, #blocked>) -> tensor<256xf16, #blocked> + tt.return + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @atomic_add_f16_withmask(%dest_ptrs: tensor<256x!tt.ptr, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf16, #blocked>, %mask: tensor<256xi1, #blocked>) attributes {noinline = false} { + // CHECK-LABEL: atomic_add_f16_withmask + // CHECK: atom.global.gpu.acq_rel.add.noftz.f16 + // CHECK: atom.global.gpu.acq_rel.add.noftz.f16 + // CHECK: atom.global.gpu.acq_rel.add.noftz.f16 + // CHECK: atom.global.gpu.acq_rel.add.noftz.f16 + %0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data, %mask : (tensor<256x!tt.ptr, #blocked>, tensor<256xf16, #blocked>, tensor<256xi1, #blocked>) -> tensor<256xf16, #blocked> + tt.return + } +} + +// ----- + #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: store_f32 diff --git a/test/Conversion/tritongpu_to_llvm_hopper.mlir b/test/Conversion/tritongpu_to_llvm_hopper.mlir index d445299662..83653d57b6 100644 --- a/test/Conversion/tritongpu_to_llvm_hopper.mlir +++ b/test/Conversion/tritongpu_to_llvm_hopper.mlir @@ -241,3 +241,41 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : tt.return } } + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @atomic_add_f32_nomask(%dest_ptrs: tensor<256x!tt.ptr, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf32, #blocked>) attributes {noinline = false} { + // CHECK-LABEL: atomic_add_f32_nomask + // CHECK: atom.global.gpu.acq_rel.add.v4.f32 + %0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data : (tensor<256x!tt.ptr, #blocked>, tensor<256xf32, #blocked>) -> tensor<256xf32, #blocked> + tt.return + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @atomic_add_f32_withmask(%dest_ptrs: tensor<256x!tt.ptr, #blocked> {tt.divisibility = 16 : i32, 
tt.contiguity = 16 : i32}, %data: tensor<256xf32, #blocked>, %mask: tensor<256xi1, #blocked> {tt.constancy = 2 : i32}) attributes {noinline = false} { + // CHECK-LABEL: atomic_add_f32_withmask + // CHECK: atom.global.gpu.acq_rel.add.v2.f32 + // CHECK: atom.global.gpu.acq_rel.add.v2.f32 + %0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data, %mask : (tensor<256x!tt.ptr, #blocked>, tensor<256xf32, #blocked>, tensor<256xi1, #blocked>) -> tensor<256xf32, #blocked> + tt.return + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @atomic_add_f16_withmask(%dest_ptrs: tensor<256x!tt.ptr, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf16, #blocked>, %mask: tensor<256xi1, #blocked> {tt.constancy = 4 : i32}) attributes {noinline = false} { + // CHECK-LABEL: atomic_add_f16_withmask + // CHECK: atom.global.gpu.acq_rel.add.noftz.v4.f16 + // CHECK: atom.global.gpu.acq_rel.add.noftz.v4.f16 + %0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data, %mask : (tensor<256x!tt.ptr, #blocked>, tensor<256xf16, #blocked>, tensor<256xi1, #blocked>) -> tensor<256xf16, #blocked> + tt.return + } +} diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp index b19f3ac88e..760ba75d98 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -98,6 +98,23 @@ Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter, return mask; } +std::string getRegisterSizeCode(int size, bool is_float) { + switch (size) { + case 1: + return "b"; + case 16: + return "h"; + case 32: + return is_float ? "f" : "r"; + case 64: + return is_float ? "d" : "l"; + case 128: + return "q"; + default: + llvm_unreachable("Unsupported register size"); + } +} + // Contains some helper functions for both Load and Store conversions. struct LoadStoreConversionBase { explicit LoadStoreConversionBase(const NVIDIA::TargetInfo &targetInfo, @@ -632,6 +649,20 @@ struct AtomicRMWOpConversion : ConvertOpToLLVMPattern(converter, benefit), LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} + bool supportsVectorized(Operation *moduleOp, RMWOp opType, + Type elementType) const { + // vectorized atomics are only supported on hopper, + // and only for specific atomic ops (add, min, max). + // Note that "packed types" like f16x2 are supported sm60+. + auto computeCapability = getNVIDIAComputeCapability(moduleOp); + if (computeCapability < 90) { + return false; + } + + return opType == RMWOp::FADD && + (elementType.isF16() || elementType.isBF16() || elementType.isF32()); + } + LogicalResult matchAndRewrite(triton::AtomicRMWOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -664,45 +695,82 @@ struct AtomicRMWOpConversion : valueTy; const size_t valueElemNBits = valueElemTy.getIntOrFloatBitWidth(); auto elemsPerThread = getTotalElemsPerThread(val.getType()); - // vec = 1, numElements = 1 for scalar - auto vec = getVectorSize(ptr); - auto vecOrig = vec; - int numElems = 1; - // tensor + // packed: e.g. packed=2 for f16x2 + // vec: e.g. .v2, .v4, .v8 version of atom instruction. 
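+    // For example, on Hopper an f16 atomic add can be emitted in the
+    // vectorized form atom.global.gpu.acq_rel.add.noftz.v4.f16, whereas
+    // pre-Hopper targets fall back to the packed form
+    // atom.global.gpu.acq_rel.add.noftz.f16x2 (see the lit tests above).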
+ unsigned vec, vecOrig; + int numElems, packed; if (tensorTy) { + vec = getVectorSize(ptr); + if (llMask) { + vec = std::min(vec, getMaskAlignment(op.getMask())); + } + vecOrig = vec; + packed = 1; auto valTy = cast(val.getType()); - vec = std::min(vec, valTy.getElementType().isF16() ? 2 : 1); - // mask + if (!supportsVectorized(moduleOp, atomicRmwAttr, + valTy.getElementType())) { + packed = + std::min(vecOrig, valTy.getElementType().isF16() ? 2 : 1); + vec = 1; + } numElems = tensorTy.getNumElements(); + } else { + // scalar + vec = 1; + vecOrig = 1; + numElems = 1; + packed = 1; } + assert((packed == 1 || vec == 1) && "packed or vec must be 1"); - if (vec == 1 && numElems > 1) + if (vec * packed == 1 && numElems > 1) op->emitRemark() << "Warning: vectorization fails vec = " << vec - << " origin vec = " << vecOrig + << " packed = " << packed << " origin vec = " << vecOrig << " numElems = " << numElems; Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo); - auto vecTy = vec_ty(valueElemTy, vec); + auto packedTy = vec_ty(valueElemTy, packed); SmallVector resultVals(elemsPerThread); - for (size_t i = 0; i < elemsPerThread; i += vec) { - Value rmwVal = undef(vecTy); - for (int ii = 0; ii < vec; ++ii) { - Value iiVal = createIndexAttrConstant( - rewriter, loc, getTypeConverter()->getIndexType(), ii); - rmwVal = insert_element(vecTy, rmwVal, valElements[i + ii], iiVal); - } - + for (size_t i = 0; i < elemsPerThread; i += vec * packed) { Value rmwPtr = ptrElements[i]; Value rmwMask = llMask ? and_(mask, maskElements[i]) : mask; std::string sTy; PTXBuilder ptxBuilderAtomicRMW; - std::string tyId = valueElemNBits * vec == 64 - ? "l" - : (valueElemNBits * vec == 32 ? "r" : "h"); - auto *dstOpr = ptxBuilderAtomicRMW.newOperand("=" + tyId, /*init=*/true); + // 16-bit -> "h", 32-bit -> "r", 64-bit -> "l" + std::string tyId = + getRegisterSizeCode(valueElemNBits * packed, /*is_float=*/false); + + PTXBuilder::Operand *dstOpr; + if (vec > 1) { + dstOpr = ptxBuilderAtomicRMW.newListOperand(); + for (unsigned ii = 0; ii < vec; ++ii) { + dstOpr->listAppend( + ptxBuilderAtomicRMW.newOperand("=" + tyId, /*init=*/true)); + } + } else { + dstOpr = ptxBuilderAtomicRMW.newOperand("=" + tyId, /*init=*/true); + } + auto *ptrOpr = ptxBuilderAtomicRMW.newAddrOperand(rmwPtr, "l"); - auto *valOpr = ptxBuilderAtomicRMW.newOperand(rmwVal, tyId); + + PTXBuilder::Operand *valOpr; + if (vec > 1) { + valOpr = ptxBuilderAtomicRMW.newListOperand(); + for (unsigned ii = 0; ii < vec; ++ii) { + valOpr->listAppend( + ptxBuilderAtomicRMW.newOperand(valElements[i + ii], tyId)); + } + } else if (packed > 1) { + Value rmwVal = undef(packedTy); + for (int ii = 0; ii < packed; ++ii) { + rmwVal = insert_element(packedTy, rmwVal, valElements[i + ii], + i32_val(ii)); + } + valOpr = ptxBuilderAtomicRMW.newOperand(rmwVal, tyId); + } else { + valOpr = ptxBuilderAtomicRMW.newOperand(valElements[i], tyId); + } auto scope = stringifyMemSyncScope(op.getScope()).str(); auto &atom = ptxBuilderAtomicRMW.create<>("atom")->global().o(scope); @@ -725,7 +793,7 @@ struct AtomicRMWOpConversion rmwOp = "add"; rmwOp += (valueElemNBits == 16 ? ".noftz" : ""); sTy = "f" + sBits; - sTy += (vec == 2 && valueElemNBits == 16) ? "x2" : ""; + sTy += (packed == 2 && valueElemNBits == 16) ? 
"x2" : ""; break; case RMWOp::MAX: sTy = "s" + sBits; @@ -750,15 +818,33 @@ struct AtomicRMWOpConversion std::string semStr; llvm::raw_string_ostream os(semStr); os << op.getSem(); - atom.o(semStr).o(rmwOp).o(sTy); + atom.o(semStr).o(rmwOp).v(vec).o(sTy); if (tensorTy) { atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask); - auto retType = vec == 1 ? valueElemTy : vecTy; + Type retType; + if (vec > 1) { + SmallVector retTys(vec, valueElemTy); + retType = struct_ty(retTys); + } else if (packed > 1) { + retType = packedTy; + } else { + retType = valueElemTy; + } + auto ret = ptxBuilderAtomicRMW.launch(rewriter, loc, retType); - for (int ii = 0; ii < vec; ++ii) { - resultVals[i + ii] = - vec == 1 ? ret : extract_element(valueElemTy, ret, i32_val(ii)); + + if (vec > 1) { + for (unsigned ii = 0; ii < vec; ++ii) { + resultVals[i + ii] = extract_val(valueElemTy, ret, ii); + } + } else if (packed > 1) { + for (unsigned ii = 0; ii < packed; ++ii) { + resultVals[i + ii] = extract_element(valueElemTy, ret, i32_val(ii)); + } + } else { + resultVals[i] = ret; } + } else { auto ASMReturnTy = void_ty(ctx); atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask); From 6ad95ee4fd9b1e172717323460fd54c250dd7d65 Mon Sep 17 00:00:00 2001 From: Minjang Kim Date: Wed, 23 Oct 2024 10:03:07 -0700 Subject: [PATCH 4/7] [AUTOTUNER] A quick follow-up for more device-independent do_bench (#4974) This is a quick follow-up for the recent autotuner/testing changes as in https://github.com/triton-lang/triton/pull/4496. This PR moves the empty cache creation into the driver code to make the code more device independent. --- python/triton/testing.py | 8 ++------ third_party/amd/backend/driver.py | 7 +++++++ third_party/nvidia/backend/driver.py | 9 +++++++++ 3 files changed, 18 insertions(+), 6 deletions(-) diff --git a/python/triton/testing.py b/python/triton/testing.py index c6bee23863..71cb8ab1ea 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -92,7 +92,7 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mod return _summarize_statistics(torch.tensor(ret), quantiles, return_mode) -def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean", device_type="cuda"): +def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean"): """ Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with the 20-th and 80-th performance percentile. @@ -117,11 +117,7 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_m fn() di.synchronize() - # We maintain a buffer of 256 MB that we clear - # before each kernel call to make sure that the L2 cache - # doesn't contain any input data before the run - cache_size = 256 * 1024 * 1024 - cache = torch.empty(int(cache_size // 4), dtype=torch.int, device='cuda') + cache = runtime.driver.active.get_empty_cache_for_benchmark() # Estimate the runtime of the function start_event = di.Event(enable_timing=True) diff --git a/third_party/amd/backend/driver.py b/third_party/amd/backend/driver.py index 86c9dd4339..6e1a368bf8 100644 --- a/third_party/amd/backend/driver.py +++ b/third_party/amd/backend/driver.py @@ -503,3 +503,10 @@ def get_current_target(self): def get_benchmarker(self): from triton.testing import do_bench return do_bench + + def get_empty_cache_for_benchmark(self): + import torch + + # It's the same as the Nvidia backend. 
+ cache_size = 256 * 1024 * 1024 + return torch.empty(int(cache_size // 4), dtype=torch.int, device='cuda') diff --git a/third_party/nvidia/backend/driver.py b/third_party/nvidia/backend/driver.py index 286f8cb52a..38ce62b0c2 100644 --- a/third_party/nvidia/backend/driver.py +++ b/third_party/nvidia/backend/driver.py @@ -452,3 +452,12 @@ def is_active(): def get_benchmarker(self): from triton.testing import do_bench return do_bench + + def get_empty_cache_for_benchmark(self): + import torch + + # We maintain a buffer of 256 MB that we clear + # before each kernel call to make sure that the L2 cache + # doesn't contain any input data before the run + cache_size = 256 * 1024 * 1024 + return torch.empty(int(cache_size // 4), dtype=torch.int, device='cuda') From 4a5431159a11bcf288c4bf682093d4a6e3320228 Mon Sep 17 00:00:00 2001 From: Mario Lezcano Casado <3291265+lezcano@users.noreply.github.com> Date: Wed, 23 Oct 2024 19:55:51 +0100 Subject: [PATCH 5/7] [BACKEND] Fix when trying to convert an mma> into blocked (#4980) The bitwidth is unimplemented in LLVM for pointer types so it throws an exception when evaluating the condition `tensorTy.getElementType().getIntOrFloatBitWidth()` --- lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index b5fe1463af..9bc3be036c 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -1016,6 +1016,8 @@ bool canUseStMatrix(RankedTensorType tensorTy, ArrayRef repShape, mlir::dyn_cast(tensorTy.getEncoding()); if (!mmaLayout || !mmaLayout.isHopper()) return false; + if (isa(tensorTy.getElementType())) + return false; if (tensorTy.getElementType().getIntOrFloatBitWidth() != 16) return false; if (order[0] != 1) From 3c13f09ef992594f0cee020b23bc41d45cd87fde Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Wed, 23 Oct 2024 20:23:31 -0700 Subject: [PATCH 6/7] [AMD] NFC: Refactor AccelerateAMDMatmul patterns (#4985) This commit refactors the AccelerateAMDMatmul patterns in prep for mxfp support. 
--- .../AccelerateAMDMatmul.cpp | 211 +++++++++--------- 1 file changed, 109 insertions(+), 102 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp index a26a18ed96..dbf5be6932 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp @@ -5,6 +5,8 @@ #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "triton/Analysis/Utility.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include @@ -36,16 +38,15 @@ int getWmmaVersion(StringRef archGen) { return 0; } -SmallVector warpsPerTile(tt::DotOp dotOp, - const ArrayRef shape, - int numWarps, - SmallVector shapePerWarp) { +SmallVector +warpsPerTile(Operation *dotOp, ArrayRef shape, int numWarps, + std::pair shapePerWarp) { auto rank = shape.size(); // Early exit for batched matmul if (rank == 3) return {(unsigned)numWarps, 1, 1}; - auto filter = [&dotOp](Operation *op) { + auto filter = [dotOp](Operation *op) { return op->getParentRegion() == dotOp->getParentRegion(); }; ForwardSliceOptions fwdOpt; @@ -55,7 +56,7 @@ SmallVector warpsPerTile(tt::DotOp dotOp, bwdOpt.filter = filter; auto slices = getSlice(dotOp, bwdOpt, fwdOpt); for (Operation *op : slices) - if (isa(op) && (op != dotOp)) + if (op->hasTrait() && (op != dotOp)) return {(unsigned)numWarps, 1}; SmallVector tensorShape = {shape[0], shape[1]}; @@ -63,9 +64,9 @@ SmallVector warpsPerTile(tt::DotOp dotOp, do { if (ret[0] * ret[1] >= numWarps) break; - if (tensorShape[0] / (shapePerWarp[0] * 2) / ret[0] >= - tensorShape[1] / shapePerWarp[1] / ret[1]) { - if (ret[0] < tensorShape[0] / shapePerWarp[0]) { + if (tensorShape[0] / (shapePerWarp.first * 2) / ret[0] >= + tensorShape[1] / shapePerWarp.second / ret[1]) { + if (ret[0] < tensorShape[0] / shapePerWarp.first) { ret[0] *= 2; } else ret[1] *= 2; @@ -74,24 +75,89 @@ SmallVector warpsPerTile(tt::DotOp dotOp, } } while (true); - if (ret[1] * shapePerWarp[1] > tensorShape[1]) { + if (ret[1] * shapePerWarp.second > tensorShape[1]) { return {ret[1], ret[0]}; } return ret; } -SmallVector -warpsPerTileMFMA(tt::DotOp dotOp, const ArrayRef shape, int numWarps, - SmallVector shapePerWarp) { +SmallVector +warpsPerTileMFMA(Operation *dotOp, ArrayRef shape, int numWarps, + std::pair shapePerWarp) { return warpsPerTile(dotOp, shape, numWarps, shapePerWarp); } -SmallVector -warpsPerTileWMMA(tt::DotOp dotOp, const ArrayRef shape, int numWarps) { - return warpsPerTile(dotOp, shape, numWarps, - {ttg::AMDWmmaEncodingAttr::getMNKDimPerInstr()[0], - ttg::AMDWmmaEncodingAttr::getMNKDimPerInstr()[1]}); +SmallVector +warpsPerTileWMMA(Operation *dotOp, ArrayRef shape, int numWarps) { + auto mnk = ttg::AMDWmmaEncodingAttr::getMNKDimPerInstr(); + return warpsPerTile(dotOp, shape, numWarps, {mnk[0], mnk[1]}); +} + +// Chooses a proper MFMA instruction that can used to compute the given dot op. +// If enforcedNonKDim is not zero, it will be used to overwrite the default +// logic to chose a MFMA with matching M/N dim. 
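+// With the default logic, a result whose smaller dimension is at least 32
+// selects a 32x32 intrinsic, sizes in [16, 32) select 16x16, and smaller
+// shapes fall back to the 4x64 / 64x4 / 4x4 variants, as implemented below.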
+FailureOr chooseMfmaInstruction(RankedTensorType cType, + Type aElemType, Type bElemType, + int inputKSize, int mfmaVersion, + int enforcedNonKDim) { + // number of matrix elements along k dim per one MFMA intruction + unsigned kDim = 0; + + auto resShape = cType.getShape(); + auto rank = resShape.size(); + auto M = resShape[rank - 2]; + auto N = resShape[rank - 1]; + + unsigned mDim = 0; + unsigned nDim = 0; + if (enforcedNonKDim != 0) { + mDim = nDim = enforcedNonKDim; + } else { + int minSize = std::min(M, N); + if (minSize >= 32) { + mDim = 32; + nDim = 32; + } + if (minSize >= 16 && minSize < 32) { + mDim = 16; + nDim = 16; + } + if (minSize < 16) { + if (M < 16 && N >= 64) { + mDim = 4; + nDim = 64; + } else if (M >= 64 && N < 16) { + mDim = 64; + nDim = 4; + } else { + assert(inputKSize >= 64 && + "k should be at least 64 to use this layout"); + mDim = 4; + nDim = 4; + } + } + } + assert(mDim != 0 && nDim != 0); + + auto maybeMfmaInsn = + MfmaInsn::selectMfma(mDim, nDim, aElemType, bElemType, mfmaVersion); + if (failed(maybeMfmaInsn)) + llvm::report_fatal_error("No match found in MFMA database\n"); + + kDim = maybeMfmaInsn->getKDim(); + assert(kDim != 0); + assert(M % mDim == 0 && N % nDim == 0); + assert(inputKSize % kDim == 0); + return maybeMfmaInsn; +} + +FailureOr chooseMfmaInstruction(tt::DotOp dot, int mfmaVersion, + int nonKDim) { + RankedTensorType aType = dot.getA().getType(); + return chooseMfmaInstruction(dot.getC().getType(), aType.getElementType(), + dot.getB().getType().getElementType(), + aType.getShape().back(), mfmaVersion, nonKDim); } using OperandTypesVector = SmallVector; @@ -259,15 +325,16 @@ Value convertAndCastTensor(PatternRewriter &rewriter, Value value, return castedTensor; } -class BlockedToMFMA : public RewritePattern { +class BlockedToMFMA : public OpRewritePattern { int mfmaVersion; - int enforcedNonKDim; + int nonKDim; int kPack; public: - BlockedToMFMA(MLIRContext *context, int mfmaVersion, int nonKDim, int kPack) - : RewritePattern(tt::DotOp::getOperationName(), 2, context), - mfmaVersion(mfmaVersion), enforcedNonKDim(nonKDim), kPack(kPack) {} + BlockedToMFMA(MLIRContext *context, int mfmaVersion, int nonKDim, int kPack, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), mfmaVersion(mfmaVersion), + nonKDim(nonKDim), kPack(kPack) {} bool isSecondDot(tt::DotOp &dotOp) const { auto filter = [&dotOp](Operation *op) { @@ -285,75 +352,15 @@ class BlockedToMFMA : public RewritePattern { return false; } - /// @brief Choose MFMA instruction parameters - /// @param dot target dot operation - /// @return MfmaInsn or failure - FailureOr chooseMfmaInstruction(tt::DotOp dot) const { - // number of matrix elements along k dim per one MFMA intruction - unsigned kDim = 0; - auto opType = cast(dot.getA().getType()); - auto dataTypeA = opType.getElementType(); - auto dataTypeB = - cast(dot.getB().getType()).getElementType(); - - auto resType = cast(dot.getD().getType()); - auto resShape = resType.getShape(); - auto rank = resShape.size(); - auto M = resShape[rank - 2]; - auto N = resShape[rank - 1]; - - unsigned mDim = 0; - unsigned nDim = 0; - if (enforcedNonKDim != 0) { - mDim = enforcedNonKDim; - nDim = enforcedNonKDim; - } else { - int minSize = std::min(M, N); - if (minSize >= 32) { - mDim = 32; - nDim = 32; - } - if (minSize >= 16 && minSize < 32) { - mDim = 16; - nDim = 16; - } - if (minSize < 16) { - if (M < 16 && N >= 64) { - mDim = 4; - nDim = 64; - } else if (M >= 64 && N < 16) { - mDim = 64; - nDim = 4; - } else { - 
assert(opType.getShape()[rank - 1] >= 64 && - "k should be at least 64 to use this layout"); - mDim = 4; - nDim = 4; - } - } - } - assert(mDim != 0 && nDim != 0); - - auto maybeMfmaInsn = - MfmaInsn::selectMfma(mDim, nDim, dataTypeA, dataTypeB, mfmaVersion); - if (failed(maybeMfmaInsn)) - llvm::report_fatal_error("No match found in MFMA database\n"); - - kDim = maybeMfmaInsn->getKDim(); - assert(kDim != 0); - assert(M % mDim == 0 && N % nDim == 0); - assert(opType.getShape()[rank - 1] % kDim == 0); - return maybeMfmaInsn; - } - - LogicalResult matchAndRewrite(Operation *op, + LogicalResult matchAndRewrite(tt::DotOp dotOp, PatternRewriter &rewriter) const override { - auto dotOp = cast(op); - RankedTensorType oldRetType = dotOp.getType(); if (!oldRetType.getEncoding() || !isa(oldRetType.getEncoding())) return failure(); + if (!isa_and_nonnull(dotOp.getType().getEncoding())) + return rewriter.notifyMatchFailure( + dotOp, "expected blocked encoding result tensor"); if (!supportMFMA(dotOp)) return failure(); @@ -362,7 +369,7 @@ class BlockedToMFMA : public RewritePattern { // get MFMA encoding for the given number of warps auto retShape = oldRetType.getShape(); - auto mod = op->getParentOfType(); + auto mod = dotOp->getParentOfType(); int numWarps = ttg::TritonGPUDialect::getNumWarps(mod); // operands @@ -374,7 +381,7 @@ class BlockedToMFMA : public RewritePattern { ttg::AMDMfmaEncodingAttr mfmaEnc; - auto mfmaInstr = chooseMfmaInstruction(dotOp); + auto mfmaInstr = chooseMfmaInstruction(dotOp, mfmaVersion, nonKDim); auto mDim = mfmaInstr.value().getMDim(); auto nDim = mfmaInstr.value().getNDim(); auto kDim = mfmaInstr.value().getKDim(); @@ -397,7 +404,7 @@ class BlockedToMFMA : public RewritePattern { mfmaAccType = rewriter.getF32Type(); // convert accumulator - auto oldAcc = dotOp.getOperand(2); + auto oldAcc = dotOp.getC(); auto newAcc = convertAndCastTensor(rewriter, oldAcc, mfmaEnc, mfmaAccType); // Here is a brief explanation of kWidth, kBase, and kDim @@ -456,11 +463,12 @@ class BlockedToMFMA : public RewritePattern { convertAndCastTensor(rewriter, newDot, oldRetType.getEncoding(), oldRetType.getElementType()); - rewriter.replaceOp(op, dotOutput); + rewriter.replaceOp(dotOp, dotOutput); return success(); } }; + static Value promoteOperand(OpBuilder &builder, Location loc, Value operand, Type promotedType) { Type tensorPromotedType = cast(operand.getType()) @@ -566,18 +574,17 @@ static void decomposeMixedModeDotOp(ModuleOp mod) { }); } -class BlockedToWMMA : public RewritePattern { +class BlockedToWMMA : public OpRewritePattern { int wmmaVersion; public: - BlockedToWMMA(MLIRContext *context, int wmmaVersion) - : RewritePattern(tt::DotOp::getOperationName(), 2, context), - wmmaVersion(wmmaVersion) {} + BlockedToWMMA(MLIRContext *context, int wmmaVersion, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), wmmaVersion(wmmaVersion) {} - LogicalResult matchAndRewrite(Operation *op, + LogicalResult matchAndRewrite(tt::DotOp dotOp, PatternRewriter &rewriter) const override { - auto ctx = op->getContext(); - auto dotOp = cast(op); + auto ctx = dotOp->getContext(); Value a = dotOp.getA(); Value b = dotOp.getB(); @@ -603,7 +610,7 @@ class BlockedToWMMA : public RewritePattern { if (wmmaVersion == 2 && llvm::isa(oldAType) && oldAType.getIntOrFloatBitWidth() == 8) { - return rewriter.notifyMatchFailure(op, "not supported yet"); + return rewriter.notifyMatchFailure(dotOp, "not supported yet"); } // get operand types @@ -612,7 +619,7 @@ class BlockedToWMMA : public RewritePattern { 
return failure(); // get WMMA encoding for the given number of warps - auto mod = op->getParentOfType(); + auto mod = dotOp->getParentOfType(); int numWarps = ttg::TritonGPUDialect::getNumWarps(mod); ttg::AMDWmmaEncodingAttr wmmaEnc; @@ -626,7 +633,7 @@ class BlockedToWMMA : public RewritePattern { auto newRetType = RankedTensorType::get(retShape, operandTypes[3], wmmaEnc); // convert accumulator - auto oldAcc = dotOp.getOperand(2); + auto oldAcc = dotOp.getC(); auto newAcc = convertAndCastTensor(rewriter, oldAcc, wmmaEnc, operandTypes[2]); @@ -653,7 +660,7 @@ class BlockedToWMMA : public RewritePattern { Value dotOutput = convertAndCastTensor(rewriter, newDot, oldRetEncoding, oldRetType.getElementType()); - rewriter.replaceOp(op, dotOutput); + rewriter.replaceOp(dotOp, dotOutput); return success(); } }; From 3613bf40d90a38766ec65a250aeadb391f9f7fc9 Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Thu, 24 Oct 2024 00:49:19 -0700 Subject: [PATCH 7/7] [BACKEND] Fix the register accessing order of dot operands of mmav2 (#4979) --- .../TritonGPUToLLVM/ElementwiseOpToLLVM.cpp | 44 ++++++++++++++----- .../ConvertLayoutOpToLLVM.cpp | 26 ++--------- .../SharedToDotOperandMMAv2.cpp | 2 +- .../DotOpToLLVM/MMAv2.cpp | 36 +++++++++++++-- .../UpcastMXFPToLLVM.cpp | 13 ------ 5 files changed, 72 insertions(+), 49 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp index 8762942c31..8ee1668669 100644 --- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -41,36 +41,60 @@ SmallVector reorderValues(const SmallVector &values, Type inType, if (inBitWidth == ouBitWidth) return values; if (inBitWidth == 16 && ouBitWidth == 32) { + // Register layout conversion: + // + // [0, 1], [4, 5] ⟶ [0], [1], [4], [5] + // [2, 3], [6, 7] [2], [3], [6], [7] + // + // Original access order: + // + // [0, 1], [2, 3], [4, 5], [6, 7] + // + // Transformed access order: + // + // [0], [2], [1], [3], [4], [6], [5], [7] SmallVector ret; for (unsigned i = 0; i < values.size(); i += 8) { ret.push_back(values[i]); - ret.push_back(values[i + 1]); - ret.push_back(values[i + 4]); - ret.push_back(values[i + 5]); ret.push_back(values[i + 2]); + ret.push_back(values[i + 1]); ret.push_back(values[i + 3]); + ret.push_back(values[i + 4]); ret.push_back(values[i + 6]); + ret.push_back(values[i + 5]); ret.push_back(values[i + 7]); } return ret; } if (inBitWidth == 8 && ouBitWidth == 16) { + // Register layout conversion: + // + // [0, 1, 2, 3], [8, 9, 10, 11] ⟶ [0, 1], [2, 3], [8, 9], [10, 11] + // [4, 5, 6, 7], [12, 13, 14, 15] [4, 5], [6, 7], [12, 13], [14, 15] + // + // Original access order: + // + // [0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15] + // + // Transformed access order: + // + // [0, 1], [4, 5], [2, 3], [6, 7], [8, 9], [12, 13], [10, 11], [14, 15] SmallVector ret; for (unsigned i = 0; i < values.size(); i += 16) { - ret.push_back(values[i + 0]); + ret.push_back(values[i]); ret.push_back(values[i + 1]); - ret.push_back(values[i + 2]); - ret.push_back(values[i + 3]); - ret.push_back(values[i + 8]); - ret.push_back(values[i + 9]); - ret.push_back(values[i + 10]); - ret.push_back(values[i + 11]); ret.push_back(values[i + 4]); ret.push_back(values[i + 5]); + ret.push_back(values[i + 2]); + ret.push_back(values[i + 3]); ret.push_back(values[i + 6]); ret.push_back(values[i + 7]); + ret.push_back(values[i + 8]); + ret.push_back(values[i + 9]); ret.push_back(values[i + 
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp
index 54371d063f..71fd3c0cd4 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -641,7 +641,6 @@ struct ConvertLayoutOpConversion
     // for the destination type, we need to pack values together
     // so they can be consumed by tensor core operations
     SmallVector vecVals;
-    SmallVector types;
     // For some reasons, LLVM's NVPTX backend inserts unnecessary (?) integer
     // instructions to pack & unpack sub-word integers. A workaround is to
     // store the results of ldmatrix in i32
@@ -655,37 +654,20 @@ struct ConvertLayoutOpConversion
               shl(i32_ty, zext(i32_ty, vals[i + j]), i32_val(elemSize * j));
           val = or_(i32_ty, val, ext);
         }
-        vecVals.push_back(val);
+        vecVals.push_back(bitcast(val, i32_ty));
       }
-      elems = elems / (32 / elemSize);
-      types = SmallVector(elems, i32_ty);
     } else {
       unsigned vecSize = std::max(32 / elemSize, 1);
       Type vecTy = vec_ty(elemTy, vecSize);
-      types = SmallVector(elems / vecSize, vecTy);
       for (unsigned i = 0; i < elems; i += vecSize) {
        Value packed = rewriter.create(loc, vecTy);
         for (unsigned j = 0; j < vecSize; j++)
           packed = insert_element(vecTy, packed, vals[i + j], i32_val(j));
-        vecVals.push_back(packed);
+        vecVals.push_back(bitcast(packed, i32_ty));
       }
     }
-
-    // This needs to be ordered the same way that
-    // ldmatrix.x4 would order it
-    // TODO: this needs to be refactor so we don't
-    // implicitly depends on how emitOffsetsForMMAV2
-    // is implemented
-    SmallVector reorderedVals;
-    for (unsigned i = 0; i < vecVals.size(); i += 4) {
-      reorderedVals.push_back(bitcast(vecVals[i], i32_ty));
-      reorderedVals.push_back(bitcast(vecVals[i + 2], i32_ty));
-      reorderedVals.push_back(bitcast(vecVals[i + 1], i32_ty));
-      reorderedVals.push_back(bitcast(vecVals[i + 3], i32_ty));
-    }
-
-    Value view = packLLElements(loc, getTypeConverter(), reorderedVals,
-                                rewriter, dstTy);
+    Value view =
+        packLLElements(loc, getTypeConverter(), vecVals, rewriter, dstTy);
     rewriter.replaceOp(op, view);
     return success();
   }
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp
index 73c21cae6d..21c2bee584 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp
@@ -513,8 +513,8 @@ Value composeValuesToDotOperandLayoutStruct(
     for (int m = 0; m < n0; ++m)
       for (int k = 0; k < n1; ++k) {
        elems.push_back(vals.at({b, 2 * m, 2 * k}));
-        elems.push_back(vals.at({b, 2 * m, 2 * k + 1}));
         elems.push_back(vals.at({b, 2 * m + 1, 2 * k}));
+        elems.push_back(vals.at({b, 2 * m, 2 * k + 1}));
         elems.push_back(vals.at({b, 2 * m + 1, 2 * k + 1}));
       }
   assert(!elems.empty());
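
The one-line swap above changes how each 2x2 quad of the value table is flattened into the dot-operand struct: the second and third pushes are exchanged, so the quad is now walked along the non-k index before the k index. A minimal index-only sketch of the two orderings (the helper, the pair representation, and the driver are illustrative, not part of the patch):

// ---- standalone sketch, not part of the patch ----
#include <cstdio>
#include <utility>
#include <vector>

// Enumerates the (2m/2m+1, 2k/2k+1) offsets pushed for each quad, before and
// after the patch.
std::vector<std::pair<int, int>> quadOrder(int n0, int n1, bool patched) {
  std::vector<std::pair<int, int>> order;
  for (int m = 0; m < n0; ++m)
    for (int k = 0; k < n1; ++k) {
      order.push_back({2 * m, 2 * k});
      if (patched) {
        order.push_back({2 * m + 1, 2 * k});
        order.push_back({2 * m, 2 * k + 1});
      } else {
        order.push_back({2 * m, 2 * k + 1});
        order.push_back({2 * m + 1, 2 * k});
      }
      order.push_back({2 * m + 1, 2 * k + 1});
    }
  return order;
}

int main() {
  for (auto [r, c] : quadOrder(1, 1, /*patched=*/true))
    std::printf("(%d,%d) ", r, c); // (0,0) (1,0) (0,1) (1,1)
  std::printf("\n");
}
// ---- end sketch ----
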
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp
index 79ccb57206..c2940a0438 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp
@@ -75,9 +75,39 @@ ValueTableV2 getValuesFromDotOperandLayoutStruct(
 
   // For kWidth = 8, split the mma into 4 mmas with "stride 4" along K
   if (dot.getOpIdx() == 0) {
-    si = llvm::SmallVector{0, 8, 4, 12, 1, 9, 5, 13,
-                           2, 10, 6, 14, 3, 11, 7, 15};
+    // Original register layout:
+    //
+    //   [0, 1, 2, 3], [8, 9, 10, 11]
+    //   [4, 5, 6, 7], [12, 13, 14, 15]
+    //
+    // Each element in the layout consists of two bf16 values.
+    // For example, the row [0, 1, 2, 3] expands to:
+    //
+    //   [[0/0, 0/1], [1/0, 1/1], [2/0, 2/1], [3/0, 3/1]]
+    //
+    // Here, 0/0 refers to the first half of element 0, and 0/1 refers to the
+    // second half, matching kWidth = 8.
+    //
+    // To derive four independent MMA operations, a stride of 4 is applied to
+    // the original register layout:
+    //
+    //   1st MMA: [0, 4, 8, 12]
+    //   2nd MMA: [1, 5, 9, 13]
+    //   3rd MMA: [2, 6, 10, 14]
+    //   4th MMA: [3, 7, 11, 15]
+    si = llvm::SmallVector{0, 4, 8, 12, 1, 5, 9, 13,
+                           2, 6, 10, 14, 3, 7, 11, 15};
   } else {
+    // Original register layout:
+    //
+    //   [0, 1, 2, 3]^T, [4, 5, 6, 7]^T
+    //
+    // A stride of 4 is applied to derive four independent MMA operations:
+    //
+    //   1st MMA: [0, 4]
+    //   2nd MMA: [1, 5]
+    //   3rd MMA: [2, 6]
+    //   4th MMA: [3, 7]
     si = llvm::SmallVector{0, 4, 1, 5, 2, 6, 3, 7};
   }
 
@@ -112,8 +142,8 @@ ValueTableV2 getValuesFromDotOperandLayoutStruct(
   for (auto i = 0; i < n0; ++i) {
     for (auto j = 0; j < n1; j++) {
       vals[{b, 2 * i, 2 * j}] = elems[offset++];
-      vals[{b, 2 * i, 2 * j + 1}] = elems[offset++];
       vals[{b, 2 * i + 1, 2 * j}] = elems[offset++];
+      vals[{b, 2 * i, 2 * j + 1}] = elems[offset++];
       vals[{b, 2 * i + 1, 2 * j + 1}] = elems[offset++];
     }
   }
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp
index 9404bb4474..722bf56cd0 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp
@@ -80,19 +80,6 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern {
         ret.push_back(v);
       }
     }
-    // FIXME [Dot LL]
-    // The DotOperandEncodingAttr without LLs encodes the
-    // layout as
-    //    e0 e1
-    //    e2 e3
-    // rather than transposed that, as the PTX docs say
-    // We transpose every block of 4 elements (kWidth = 8 -> 4 bf16x2)
-    assert(ret.size() % 16 == 0);
-    for (int i = 0; i < ret.size() / 16; ++i) {
-      for (int j = 0; j < 4; ++j) {
-        std::swap(ret[16 * i + j + 4], ret[16 * i + j + 8]);
-      }
-    }
     return ret;
   }
 
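
The comments added above describe both index tables as a stride-4 regrouping of the operand registers. A small self-contained check that the hard-coded `si` vectors are exactly that regrouping (the helper name is illustrative; the group sizes and the expected index lists come from the diff):

// ---- standalone sketch, not part of the patch ----
#include <cassert>
#include <vector>

// Builds the access order in which every 4th register is grouped together,
// one group per derived MMA.
std::vector<int> strideFourOrder(int groupSize) {
  std::vector<int> si;
  for (int start = 0; start < 4; ++start)   // 4 derived MMAs
    for (int idx = start; idx < groupSize; idx += 4)
      si.push_back(idx);
  return si;
}

int main() {
  // Operand A (opIdx == 0): the 16-entry table above.
  assert((strideFourOrder(16) ==
          std::vector<int>{0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15}));
  // Operand B (opIdx == 1): the 8-entry table above.
  assert((strideFourOrder(8) == std::vector<int>{0, 4, 1, 5, 2, 6, 3, 7}));
  return 0;
}
// ---- end sketch ----
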