diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index 378a501175..084bcc1583 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -1044,16 +1044,16 @@ SmallVector<unsigned> DotOperandEncodingAttr::getCTASplitNum() const {
   return res;
 }
 SmallVector<unsigned> DotOperandEncodingAttr::getWarpsPerCTA() const {
-  auto parentLayout = getParent();
-  assert(parentLayout && "DotOperandEncodingAttr must have a parent");
-  if (auto distributedLayout =
-          mlir::dyn_cast<DistributedEncodingTrait>(parentLayout)) {
-    return distributedLayout.getWarpsPerCTA();
-  } else {
-    llvm::report_fatal_error(
-        "DotOperandEncodingAttr non-DistributedEncodingAttr parent not "
-        "supported yet");
-  }
+  auto distributedLayout = mlir::cast<DistributedEncodingTrait>(getParent());
+  auto warps = distributedLayout.getWarpsPerCTA();
+  // FIXME: This is a temporary workaround to avoid a distribute-to-warps.mlir
+  // failure.
+  if (mlir::triton::tools::getBoolEnv("TRITON_INTEL_ADVANCED_PATH"))
+    return warps;
+  auto rank = warps.size();
+  auto kDim = getOpIdx() == 0 ? rank - 1 : rank - 2;
+  warps[kDim] = 1;
+  return warps;
 }
 SmallVector<unsigned> DotOperandEncodingAttr::getWarpOrder() const {
   return ::getWarpOrder(*this);
diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
index 7d508f234f..7e6327e3c5 100644
--- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
+++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
@@ -5,6 +5,7 @@
 #include "triton/Dialect/TritonGPU/IR/Attributes.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
+#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"
 #include "triton/Tools/LinearLayout.h"
 #include "triton/Tools/StrUtil.h"
 #include "llvm/ADT/DenseMap.h"
@@ -822,9 +823,68 @@ SliceEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   return ret;
 }
 
+LinearLayout ampereDotToLinearLayout(ArrayRef<int64_t> shape,
+                                     DotOperandEncodingAttr dot) {
+  // TODO(BE): Implement ampereMMA in terms of this one.
+  int rank = shape.size();
+  auto mma = cast<NvidiaMmaEncodingAttr>(dot.getParent());
+  int kWidth = dot.getKWidth();
+  bool isA = dot.getOpIdx() == 0;
+
+  assert(mma.isAmpere());
+  assert((rank == 2 && mma.getInstrShape() == ArrayRef<unsigned>({16, 8})) ||
+         (rank == 3 && mma.getInstrShape() == ArrayRef<unsigned>({1, 16, 8})));
+
+  MLIRContext *ctx = mma.getContext();
+  SmallVector<StringAttr> dimNames = standardOutDimNames(ctx, rank);
+
+  // Build the layout for operand A; for B, transpose at the end.
+  std::vector<std::vector<int32_t>> registers;
+  std::vector<std::vector<int32_t>> lanes;
+  int32_t i = 1;
+  // kWidth contiguous elements
+  while (i < kWidth) {
+    registers.push_back({i, 0});
+    i *= 2;
+  }
+  // 4 threads per chunk
+  for (int j = 0; j < 2; j++) {
+    lanes.push_back({i, 0});
+    i *= 2;
+  }
+  // 8 threads going down
+  lanes.push_back({0, 1});
+  lanes.push_back({0, 2});
+  lanes.push_back({0, 4});
+  // 2 tiles in column-major order
+  // Just one if it's the B operand
+  if (isA) {
+    registers.push_back({0, 8});
+  }
+  registers.push_back({i, 0});
+
+  if (!isA) {
+    for (auto &r : registers) {
+      std::swap(r[0], r[1]);
+    }
+    for (auto &l : lanes) {
+      std::swap(l[0], l[1]);
+    }
+  }
+
+  LinearLayout ctaLayout(
+      {{S("register"), registers}, {S("lane"), lanes}},
+      llvm::to_vector(llvm::reverse(ArrayRef(dimNames).take_back(2))));
+
+  auto order = dot.getCTAOrder();
+  assert(order[0] == 1 && order[1] == 0);
+  ctaLayout *= identityND(S("warp"), dot.getWarpsPerCTA(), order, dimNames);
+
+  return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape);
+}
+
 std::optional<LinearLayout>
 DotOperandEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
-
   if (auto mfmaLayout = llvm::dyn_cast<AMDMfmaEncodingAttr>(getParent())) {
     return dotOperandMfmaToLinearLayout(*this, shape);
   }
@@ -832,6 +892,12 @@ DotOperandEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
     return dotOperandDpasToLinearLayout(*this, shape);
   }
 
+  // TODO: Activate in a follow-up PR.
+  // else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(getParent())) {
+  //   if (mma.isAmpere()) {
+  //     return ampereDotToLinearLayout(shape, *this);
+  //   }
+  // }
   return std::nullopt;
 }
 
diff --git a/test/TritonIntelGPU/distribute-to-warps.mlir b/test/TritonIntelGPU/distribute-to-warps.mlir
index 7d36c35b34..8f5ed6db34 100644
--- a/test/TritonIntelGPU/distribute-to-warps.mlir
+++ b/test/TritonIntelGPU/distribute-to-warps.mlir
@@ -1,4 +1,5 @@
-// RUN: triton-opt %s -split-input-file -tritonintelgpu-distribute-to-warps | FileCheck %s
+// FIXME: Remove the environment variable once issue #2529 is fixed.
+// RUN: env TRITON_INTEL_ADVANCED_PATH=1 triton-opt %s -split-input-file -tritonintelgpu-distribute-to-warps | FileCheck %s
 
 #blocked1 = #triton_gpu.blocked<{sizePerThread = [32, 32], threadsPerWarp = [1, 1], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
 #blocked2 = #triton_gpu.blocked<{sizePerThread = [32, 32], threadsPerWarp = [1, 1], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
diff --git a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp
index 60ccb6c5ca..894d78e1b4 100644
--- a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp
+++ b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp
@@ -4,6 +4,7 @@
 #include "triton/Dialect/TritonGPU/IR/Attributes.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Tools/StrUtil.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/Support/Signals.h"
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
@@ -40,6 +41,12 @@ class LinearLayoutConversionsTest : public ::testing::Test {
                      CTALayoutAttr::get(&ctx, cpg, cSplit, cOrd), instrShape);
   }
 
+  DotOperandEncodingAttr dotMMAv2(int idx, int kWidth, ArrayRef<unsigned> warps,
+                                  ArrayRef<unsigned> order) {
+    auto mmaLayout = mma(2, 0, {16, 8}, warps, {1, 1}, {1, 1}, order);
+    return DotOperandEncodingAttr::get(&ctx, idx, mmaLayout, /*kWidth=*/kWidth);
+  }
+
   AMDMfmaEncodingAttr mfma(ArrayRef<unsigned> warps, unsigned mDim,
                            unsigned nDim, bool isTransposed) {
     SmallVector<unsigned> cpg(warps.size(), 1u);
@@ -494,6 +501,81 @@ TEST_F(LinearLayoutConversionsTest, MMAv3_4x4Warps) {
                         {S("dim0"), S("dim1")}));
 }
 
+TEST_F(LinearLayoutConversionsTest, DotMMAv2_tile_kwidth8) {
+  EXPECT_EQ(ampereDotToLinearLayout({16, 64}, dotMMAv2(0, 8, {1, 1}, {1, 0})),
+            LinearLayout(
+                {
+                    {S("register"), {{0, 1}, {0, 2}, {0, 4}, {8, 0}, {0, 32}}},
+                    {S("lane"), {{0, 8}, {0, 16}, {1, 0}, {2, 0}, {4, 0}}},
+                    {S("warp"), {}},
+                    {S("block"), {}},
+                },
+                {S("dim0"), S("dim1")}));
+  EXPECT_EQ(ampereDotToLinearLayout({64, 8}, dotMMAv2(1, 8, {1, 1}, {1, 0})),
+            LinearLayout(
+                {
+                    {S("register"), {{1, 0}, {2, 0}, {4, 0}, {32, 0}}},
+                    {S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}},
+                    {S("warp"), {}},
+                    {S("block"), {}},
+                },
+                {S("dim0"), S("dim1")}));
+}
+
+TEST_F(LinearLayoutConversionsTest, DotMMAv2_large_warp4_kwidth8) {
+  EXPECT_EQ(
+      ampereDotToLinearLayout({128, 128}, dotMMAv2(0, 8, {4, 1}, {1, 0})),
+      LinearLayout(
+          {
+              {S("register"),
+               {{0, 1}, {0, 2}, {0, 4}, {8, 0}, {0, 32}, {0, 64}, {64, 0}}},
+              {S("lane"), {{0, 8}, {0, 16}, {1, 0}, {2, 0}, {4, 0}}},
+              {S("warp"), {{16, 0}, {32, 0}}},
+              {S("block"), {}},
+          },
+          {S("dim0"), S("dim1")}));
+  EXPECT_EQ(ampereDotToLinearLayout({128, 64}, dotMMAv2(1, 8, {4, 1}, {1, 0})),
+            LinearLayout(
+                {
+                    {S("register"),
+                     {{1, 0},
+                      {2, 0},
+                      {4, 0},
+                      {32, 0},
+                      {0, 8},
+                      {0, 16},
+                      {0, 32},
+                      {64, 0}}},
+                    {S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}},
+                    {
+                        S("warp"),
+                        {},
+                    },
+                    {S("block"), {}},
+                },
+                {S("dim0"), S("dim1")}));
+  EXPECT_EQ(ampereDotToLinearLayout({64, 128}, dotMMAv2(1, 8, {4, 1}, {1, 0})),
+            LinearLayout(
+                {
+                    {S("register"),
+                     {{1, 0},
+                      {2, 0},
+                      {4, 0},
+                      {32, 0},
+                      {0, 8},
+                      {0, 16},
+                      {0, 32},
+                      {0, 64}}},
+                    {S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}},
+                    {
+                        S("warp"),
+                        {},
+                    },
+                    {S("block"), {}},
+                },
+                {S("dim0"), S("dim1")}));
+}
+
 TEST_F(LinearLayoutConversionsTest, MFMA32_2x4Warps) {
   auto mfmaNT = mfma(/*warps=*/{2, 4}, /*mDim=*/32, /*nDim=*/32,
                      /*isTransposed=*/false);
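Reviewer note: the basis construction in `ampereDotToLinearLayout` can be sanity-checked without building Triton. Below is a minimal standalone sketch (an assumption of this note, not part of the patch: plain C++17, no MLIR; `dumpBases` and `main` are illustrative scaffolding) that mirrors the register/lane loop above and prints each basis in `(dim0, dim1)` order, i.e. after the `reverse` of the output dimensions that the real code applies when constructing `ctaLayout`.

```cpp
// Illustrative sketch only: replays the register/lane basis construction of
// ampereDotToLinearLayout for rank 2, outside of MLIR.
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

static void dumpBases(const char *name,
                      const std::vector<std::vector<int32_t>> &bases) {
  std::cout << name << ":";
  // The algorithm stores bases in (dim1, dim0) order (reversed out-dims),
  // so print {b[1], b[0]} to get the (dim0, dim1) order used by the tests.
  for (const auto &b : bases)
    std::cout << " {" << b[1] << ", " << b[0] << "}";
  std::cout << "\n";
}

int main() {
  const int32_t kWidth = 8; // contiguous elements per thread
  for (bool isA : {true, false}) {
    std::vector<std::vector<int32_t>> registers, lanes;
    int32_t i = 1;
    // kWidth contiguous elements held in registers
    while (i < kWidth) {
      registers.push_back({i, 0});
      i *= 2;
    }
    // 4 threads per chunk along K
    for (int j = 0; j < 2; j++) {
      lanes.push_back({i, 0});
      i *= 2;
    }
    // 8 threads going down the non-K dimension
    lanes.push_back({0, 1});
    lanes.push_back({0, 2});
    lanes.push_back({0, 4});
    // Second tile along M exists only for the A operand
    if (isA)
      registers.push_back({0, 8});
    // Second tile along K
    registers.push_back({i, 0});

    // B reuses the A layout with the two dimensions swapped
    if (!isA) {
      for (auto &r : registers)
        std::swap(r[0], r[1]);
      for (auto &l : lanes)
        std::swap(l[0], l[1]);
    }

    std::cout << (isA ? "A" : "B") << " operand, kWidth=8\n";
    dumpBases("  register", registers);
    dumpBases("  lane", lanes);
  }
  return 0;
}
```

For `kWidth = 8` the printed bases match the expectations asserted in `DotMMAv2_tile_kwidth8` verbatim, because the test shapes `{16, 64}` (A) and `{64, 8}` (B) coincide with a single CTA tile, so `combineCtaCgaWithShape` contributes no additional register bases.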