20 changes: 10 additions & 10 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -1044,16 +1044,16 @@ SmallVector<unsigned> DotOperandEncodingAttr::getCTASplitNum() const {
return res;
}
SmallVector<unsigned> DotOperandEncodingAttr::getWarpsPerCTA() const {
auto parentLayout = getParent();
assert(parentLayout && "DotOperandEncodingAttr must have a parent");
if (auto distributedLayout =
mlir::dyn_cast<DistributedEncodingTrait>(parentLayout)) {
return distributedLayout.getWarpsPerCTA();
} else {
llvm::report_fatal_error(
"DotOperandEncodingAttr non-DistributedEncodingAttr parent not "
"supported yet");
}
auto distributedLayout = mlir::cast<DistributedEncodingTrait>(getParent());
auto warps = distributedLayout.getWarpsPerCTA();
// FIXME: This is a temporary solution to avoid distribute-to-warps.mlir
// failure.

[Contributor Author commented on the FIXME above: Created #2529 to have it properly fixed.]

if (mlir::triton::tools::getBoolEnv("TRITON_INTEL_ADVANCED_PATH"))
return warps;
auto rank = warps.size();
auto kDim = getOpIdx() == 0 ? rank - 1 : rank - 2;
warps[kDim] = 1;
return warps;
}
SmallVector<unsigned> DotOperandEncodingAttr::getWarpOrder() const {
return ::getWarpOrder(*this);
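As a quick sanity check on the new getWarpsPerCTA behavior for dot operands: the warp count along the K dimension (the last dimension for operand A, the second-to-last for operand B) is forced to 1, unless TRITON_INTEL_ADVANCED_PATH is set, in which case the parent's warpsPerCTA is returned unchanged. The following is a minimal, dependency-free sketch of that index arithmetic; it assumes plain C++17, uses no MLIR/Triton headers, and the helper name dotOperandWarpsPerCTA is invented for illustration.

#include <cstddef>
#include <cstdio>
#include <vector>

// Hypothetical helper mirroring DotOperandEncodingAttr::getWarpsPerCTA():
// take the parent layout's warpsPerCTA and set the K dimension to 1.
std::vector<unsigned> dotOperandWarpsPerCTA(std::vector<unsigned> parentWarps,
                                            unsigned opIdx) {
  const std::size_t rank = parentWarps.size();
  // Operand A (opIdx == 0): K is the last dimension; operand B: the one before.
  const std::size_t kDim = (opIdx == 0) ? rank - 1 : rank - 2;
  parentWarps[kDim] = 1;
  return parentWarps;
}

int main() {
  // Parent layout with warpsPerCTA = {4, 2}.
  auto a = dotOperandWarpsPerCTA({4, 2}, /*opIdx=*/0); // -> {4, 1}
  auto b = dotOperandWarpsPerCTA({4, 2}, /*opIdx=*/1); // -> {1, 2}
  std::printf("A: {%u, %u}  B: {%u, %u}\n", a[0], a[1], b[0], b[1]);
  return 0;
}

So a parent layout with warpsPerCTA = {4, 2} yields {4, 1} for operand A and {1, 2} for operand B.
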
68 changes: 67 additions & 1 deletion lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
@@ -5,6 +5,7 @@
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"
#include "triton/Tools/LinearLayout.h"
#include "triton/Tools/StrUtil.h"
#include "llvm/ADT/DenseMap.h"
@@ -822,16 +823,81 @@ SliceEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
return ret;
}

LinearLayout ampereDotToLinearLayout(ArrayRef<int64_t> shape,
DotOperandEncodingAttr dot) {
// TODO,BE. Implement ampereMMA in terms of this one
int rank = shape.size();
auto mma = cast<NvidiaMmaEncodingAttr>(dot.getParent());
int kWidth = dot.getKWidth();
bool isA = dot.getOpIdx() == 0;

assert(mma.isAmpere());
assert((rank == 2 && mma.getInstrShape() == ArrayRef<unsigned>({16, 8})) ||
(rank == 3 && mma.getInstrShape() == ArrayRef<unsigned>({1, 16, 8})));

MLIRContext *ctx = mma.getContext();
SmallVector<StringAttr> dimNames = standardOutDimNames(ctx, rank);

// Implement A. For B transpose in the end
std::vector<std::vector<int32_t>> registers;
std::vector<std::vector<int32_t>> lanes;
int32_t i = 1;
// kWidth contiguous elements
while (i < kWidth) {
registers.push_back({i, 0});
i *= 2;
}
// 4 threads per chunk
for (int j = 0; j < 2; j++) {
lanes.push_back({i, 0});
i *= 2;
}
// 8 threads going down
lanes.push_back({0, 1});
lanes.push_back({0, 2});
lanes.push_back({0, 4});
// 2 tiles in column-major order
// Just one if it's the B operand
if (isA) {
registers.push_back({0, 8});
}
registers.push_back({i, 0});

if (!isA) {
for (auto &r : registers) {
std::swap(r[0], r[1]);
}
for (auto &l : lanes) {
std::swap(l[0], l[1]);
}
}

LinearLayout ctaLayout(
{{S("register"), registers}, {S("lane"), lanes}},
llvm::to_vector(llvm::reverse(ArrayRef(dimNames).take_back(2))));

auto order = dot.getCTAOrder();
assert(order[0] == 1 && order[1] == 0);
ctaLayout *= identityND(S("warp"), dot.getWarpsPerCTA(), order, dimNames);

return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape);
}

std::optional<LinearLayout>
DotOperandEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {

if (auto mfmaLayout = llvm::dyn_cast<AMDMfmaEncodingAttr>(getParent())) {
return dotOperandMfmaToLinearLayout(*this, shape);
}
if (auto dpasLayout = llvm::dyn_cast<intel::DpasEncodingAttr>(getParent())) {
return dotOperandDpasToLinearLayout(*this, shape);
}

// TODO Activate in a follow-up PR
// else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(getParent())) {
// if (mma.isAmpere()) {
// return ampereDotToLinearLayout(shape, *this);
// }
//}
return std::nullopt;
}

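To make the bases produced by the new ampereDotToLinearLayout easier to check against the unit tests added below, here is a standalone sketch that replays the same register/lane basis construction for a single warp and prints the vectors in (dim0, dim1) order. It is an illustration only: it assumes plain C++17 with no MLIR/Triton headers, it omits the warp and block bases, and the final coordinate swap is a stand-in for the out-dimension reversal the real code applies when building the LinearLayout.

#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

static void printBases(const char *name,
                       const std::vector<std::vector<int32_t>> &bases) {
  std::printf("%s:", name);
  for (const auto &b : bases)
    std::printf(" {%d, %d}", b[0], b[1]);
  std::printf("\n");
}

int main() {
  const int kWidth = 8;  // contiguous elements per thread along K
  const bool isA = true; // opIdx == 0; set to false for the B operand

  std::vector<std::vector<int32_t>> registers, lanes;
  int32_t i = 1;
  // kWidth contiguous elements along K held in registers.
  while (i < kWidth) {
    registers.push_back({i, 0});
    i *= 2;
  }
  // 4 threads per chunk along K.
  for (int j = 0; j < 2; ++j) {
    lanes.push_back({i, 0});
    i *= 2;
  }
  // 8 threads going down the non-K dimension.
  lanes.push_back({0, 1});
  lanes.push_back({0, 2});
  lanes.push_back({0, 4});
  // Two tiles for operand A, only one for operand B.
  if (isA)
    registers.push_back({0, 8});
  registers.push_back({i, 0});

  // Operand B is the transpose of A.
  if (!isA) {
    for (auto &r : registers)
      std::swap(r[0], r[1]);
    for (auto &l : lanes)
      std::swap(l[0], l[1]);
  }

  // Swap coordinates so the bases read in (dim0, dim1) order, emulating the
  // out-dimension reversal applied when the real LinearLayout is built.
  for (auto *basis : {&registers, &lanes})
    for (auto &b : *basis)
      std::swap(b[0], b[1]);

  printBases("register", registers); // A, kWidth=8: {0,1} {0,2} {0,4} {8,0} {0,32}
  printBases("lane", lanes);         // A, kWidth=8: {0,8} {0,16} {1,0} {2,0} {4,0}
  return 0;
}

With kWidth = 8 and isA = true this prints the operand-A bases expected by DotMMAv2_tile_kwidth8 for shape {16, 64}; flipping isA to false reproduces the operand-B bases expected for {64, 8}.
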
3 changes: 2 additions & 1 deletion test/TritonIntelGPU/distribute-to-warps.mlir
@@ -1,4 +1,5 @@
// RUN: triton-opt %s -split-input-file -tritonintelgpu-distribute-to-warps | FileCheck %s
// FIXME: Remove the env. variable once issue #2529 is fixed.
// RUN: env TRITON_INTEL_ADVANCED_PATH=1 triton-opt %s -split-input-file -tritonintelgpu-distribute-to-warps | FileCheck %s

#blocked1 = #triton_gpu.blocked<{sizePerThread = [32, 32], threadsPerWarp = [1, 1], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [32, 32], threadsPerWarp = [1, 1], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
82 changes: 82 additions & 0 deletions unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp
@@ -4,6 +4,7 @@
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Tools/StrUtil.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/Support/Signals.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
@@ -40,6 +41,12 @@ class LinearLayoutConversionsTest : public ::testing::Test {
CTALayoutAttr::get(&ctx, cpg, cSplit, cOrd), instrShape);
}

DotOperandEncodingAttr dotMMAv2(int idx, int kWidth, ArrayRef<unsigned> warps,
ArrayRef<unsigned> order) {
auto mmaLayout = mma(2, 0, {16, 8}, warps, {1, 1}, {1, 1}, order);
return DotOperandEncodingAttr::get(&ctx, idx, mmaLayout, /*kWidth=*/kWidth);
}

AMDMfmaEncodingAttr mfma(ArrayRef<unsigned> warps, unsigned mDim,
unsigned nDim, bool isTransposed) {
SmallVector<unsigned> cpg(warps.size(), 1u);
@@ -494,6 +501,81 @@ TEST_F(LinearLayoutConversionsTest, MMAv3_4x4Warps) {
{S("dim0"), S("dim1")}));
}

TEST_F(LinearLayoutConversionsTest, DotMMAv2_tile_kwidth8) {
EXPECT_EQ(ampereDotToLinearLayout({16, 64}, dotMMAv2(0, 8, {1, 1}, {1, 0})),
LinearLayout(
{
{S("register"), {{0, 1}, {0, 2}, {0, 4}, {8, 0}, {0, 32}}},
{S("lane"), {{0, 8}, {0, 16}, {1, 0}, {2, 0}, {4, 0}}},
{S("warp"), {}},
{S("block"), {}},
},
{S("dim0"), S("dim1")}));
EXPECT_EQ(ampereDotToLinearLayout({64, 8}, dotMMAv2(1, 8, {1, 1}, {1, 0})),
LinearLayout(
{
{S("register"), {{1, 0}, {2, 0}, {4, 0}, {32, 0}}},
{S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}},
{S("warp"), {}},
{S("block"), {}},
},
{S("dim0"), S("dim1")}));
}

TEST_F(LinearLayoutConversionsTest, DotMMAv2_large_warp4_kwidth8) {
EXPECT_EQ(
ampereDotToLinearLayout({128, 128}, dotMMAv2(0, 8, {4, 1}, {1, 0})),
LinearLayout(
{
{S("register"),
{{0, 1}, {0, 2}, {0, 4}, {8, 0}, {0, 32}, {0, 64}, {64, 0}}},
{S("lane"), {{0, 8}, {0, 16}, {1, 0}, {2, 0}, {4, 0}}},
{S("warp"), {{16, 0}, {32, 0}}},
{S("block"), {}},
},
{S("dim0"), S("dim1")}));
EXPECT_EQ(ampereDotToLinearLayout({128, 64}, dotMMAv2(1, 8, {4, 1}, {1, 0})),
LinearLayout(
{
{S("register"),
{{1, 0},
{2, 0},
{4, 0},
{32, 0},
{0, 8},
{0, 16},
{0, 32},
{64, 0}}},
{S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}},
{
S("warp"),
{},
},
{S("block"), {}},
},
{S("dim0"), S("dim1")}));
EXPECT_EQ(ampereDotToLinearLayout({64, 128}, dotMMAv2(1, 8, {4, 1}, {1, 0})),
LinearLayout(
{
{S("register"),
{{1, 0},
{2, 0},
{4, 0},
{32, 0},
{0, 8},
{0, 16},
{0, 32},
{0, 64}}},
{S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}},
{
S("warp"),
{},
},
{S("block"), {}},
},
{S("dim0"), S("dim1")}));
}

TEST_F(LinearLayoutConversionsTest, MFMA32_2x4Warps) {
auto mfmaNT = mfma(/*warps=*/{2, 4}, /*mDim=*/32, /*nDim=*/32,
/*isTransposed=*/false);