Skip to content

Commit 3a7a236

Browse files
Jokeren authored and liuyunqi20 committed
[BACKEND] Support convert_layout with num_ctas > 1 Using Linear Layout (#4782)
Particularly, this PR implements layout conversion when a CGA contains more than one CTA. In such cases, a Triton tensor is split into multiple blocks, with each block being handled by a CTA. ``` block0 | block1 ---------------- block2 | block3 ``` If data transfer is required from block0 to block3, this PR cannot handle it, and we use `isCrossCTAConversion` to check this condition.
1 parent 5c30929 commit 3a7a236

File tree

5 files changed

+49
-21
lines changed

5 files changed

+49
-21
lines changed

include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,15 @@ toLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
4848
// dimension, determines if the layout moves data across block boundaries.
4949
bool isCrossCTAConversion(const LinearLayout &layout);
5050

51+
// Given a linear layout where the input dimensions contain a "block" dimension,
52+
// this method sets the "block" dimension to 0 and removes the corresponding
53+
// output dimensions.
54+
//
55+
// Note that this behavior differs from calling
56+
// `LinearLayout::sublayout(inDimNames, outDimNames)` when "block" is not in
57+
// `inDimNames`. The latter does not modify the output sizes.
58+
LinearLayout getLayoutWithinBlock(const LinearLayout &layout);
59+
5160
// In this function, we construct a linear layout representing the
5261
// <shared memory offset, iteration, block> -> <tensor element index> mapping
5362
// for entire `src` and `dst` tensors. We determine the shape of the

include/triton/Tools/LinearLayout.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -597,7 +597,7 @@ class LinearLayout {
597597
//
598598
// TODO(jlebar): Implement divideLeft.
599599
// std::optional<LinearLayout> divideLeft(const LinearLayout &divisor);
600-
std::optional<LinearLayout> divideRight(const LinearLayout &divisor);
600+
std::optional<LinearLayout> divideRight(const LinearLayout &divisor) const;
601601

602602
// Gets a layout with only these in/out dimensions.
603603
//

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -367,12 +367,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
367367
// The following tasks must be completed before we can remove the layoutIsOK
368368
// check:
369369
// 1. Support for AMD's MFMA and WMMA
370-
// 2. Handling NVIDIA's MMA layout when CTA per CGA > 1
371370
std::function<bool(Attribute)> layoutIsOK = [&](Attribute layout) {
372371
if (auto nvidiaMma = dyn_cast<NvidiaMmaEncodingAttr>(layout)) {
373-
if (product(getCTAsPerCGA(nvidiaMma)) > 1) {
374-
return false;
375-
}
376372
if (useLegacyMMAConversion) {
377373
return false;
378374
}
@@ -419,8 +415,11 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
419415
}
420416
}
421417

422-
SmallVector<Value> outVals = transferWithinBlockOrGroupImpl(
423-
inVals, conversion, op, srcLayout, dstLayout, adaptor, rewriter);
418+
auto srcLayoutWithinBlock = getLayoutWithinBlock(srcLayout);
419+
auto dstLayoutWithinBlock = getLayoutWithinBlock(dstLayout);
420+
SmallVector<Value> outVals =
421+
transferWithinBlock(inVals, op, srcLayoutWithinBlock,
422+
dstLayoutWithinBlock, adaptor, rewriter);
424423

425424
// Unmunge output values
426425
for (const auto &it : llvm::enumerate(outVals)) {
@@ -437,11 +436,11 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
437436
return success();
438437
}
439438

440-
SmallVector<Value> transferWithinBlockOrGroupImpl(
441-
ArrayRef<Value> inVals, const LinearLayout &conversion,
442-
ConvertLayoutOp op, const LinearLayout &srcLayout,
443-
const LinearLayout &dstLayout, OpAdaptor adaptor,
444-
ConversionPatternRewriter &rewriter) const {
439+
SmallVector<Value>
440+
transferWithinBlock(ArrayRef<Value> inVals, ConvertLayoutOp op,
441+
const LinearLayout &srcLayout,
442+
const LinearLayout &dstLayout, OpAdaptor adaptor,
443+
ConversionPatternRewriter &rewriter) const {
445444
MLIRContext *ctx = op.getContext();
446445
auto loc = op.getLoc();
447446

@@ -459,11 +458,12 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
459458

460459
auto scratchConfig =
461460
getScratchConfigForCvt(op.getSrc().getType(), op.getType());
462-
auto tensorShape = convertType<unsigned, int64_t>(op.getType().getShape());
461+
auto tensorShapePerCTA = convertType<unsigned, int64_t>(getShapePerCTA(
462+
op.getSrc().getType().getEncoding(), op.getType().getShape()));
463463
// Input dims: [offset, iteration, block]
464464
// Output dims: dimN-1, dimN-2, ..., dim0, where N is obtained from repShape
465465
LinearLayout sharedLayout = chooseShemLayoutForRegToRegConversion(
466-
ctx, tensorShape, scratchConfig.repShape, scratchConfig.order);
466+
ctx, tensorShapePerCTA, scratchConfig.repShape, scratchConfig.order);
467467

468468
// Layout for the store from registers to shared memory.
469469
//

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -869,6 +869,17 @@ bool isCrossCTAConversion(const LinearLayout &layout) {
869869
!layout.sublayoutIsIdentity({kBlock}, {kBlock});
870870
}
871871

872+
LinearLayout getLayoutWithinBlock(const LinearLayout &layout) {
873+
assert(!layout.getInDimNames().empty());
874+
MLIRContext *ctx = layout.getInDimNames().begin()->getContext();
875+
876+
StringAttr kBlock = S("block");
877+
assert(layout.hasInDim(kBlock));
878+
auto bases = layout.getBases();
879+
bases[kBlock] = {};
880+
return LinearLayout(bases, llvm::to_vector<4>(layout.getOutDimNames()));
881+
}
882+
872883
LinearLayout chooseShemLayoutForRegToRegConversion(
873884
MLIRContext *ctx, ArrayRef<unsigned> tensorShape,
874885
ArrayRef<unsigned> repShape, ArrayRef<unsigned> order) {
@@ -925,11 +936,11 @@ bool canUseStMatrix(RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
925936
if (order[0] != 1)
926937
return false;
927938

928-
auto tensorShape = tensorTy.getShape();
929-
if (tensorShape.size() != 2)
939+
auto tensorShapePerCTA = getShapePerCTA(mmaLayout, tensorTy.getShape());
940+
if (tensorShapePerCTA.size() != 2)
930941
return false;
931-
auto numIterations = ceil<unsigned>(tensorShape[1], repShape[1]) *
932-
ceil<unsigned>(tensorShape[0], repShape[0]);
942+
auto numIterations = ceil<unsigned>(tensorShapePerCTA[1], repShape[1]) *
943+
ceil<unsigned>(tensorShapePerCTA[0], repShape[0]);
933944
if (numIterations > 1)
934945
return false;
935946
if (paddedRepShape[1] % 8 != 0)
@@ -1020,6 +1031,7 @@ std::optional<LinearLayout> chooseStMatrixLayoutNoLeadingOffset(
10201031
StringAttr kWarp = S("warp");
10211032
StringAttr kCol = S("dim1");
10221033
StringAttr kRow = S("dim0");
1034+
StringAttr kBlock = S("block");
10231035

10241036
std::vector<std::vector<int>> basesReg = {{1, 0}, {2, 0}, {4, 0}};
10251037
std::vector<std::vector<int>> basesLane = {
@@ -1039,9 +1051,16 @@ std::optional<LinearLayout> chooseStMatrixLayoutNoLeadingOffset(
10391051
.transposeOuts(llvm::to_vector(layout.getOutDimNames()));
10401052
auto ret =
10411053
combineCtaCgaWithShape(layout, mma.getCTALayout(), tensorTy.getShape());
1054+
auto tensorShapePerCTA = getShapePerCTA(mma, tensorTy.getShape());
1055+
llvm::SmallDenseMap<StringAttr, int64_t> namedTensorShape;
1056+
namedTensorShape[kRow] = tensorShapePerCTA[0];
1057+
namedTensorShape[kCol] = tensorShapePerCTA[1];
1058+
ret = ensureLayoutNotSmallerThan(ret, namedTensorShape);
1059+
ret = ensureLayoutNotLargerThan(ret, namedTensorShape);
10421060
return ret.transposeOuts(llvm::to_vector(layout.getOutDimNames()))
1043-
.reshapeOuts(
1044-
{{S("offset"), ret.getTotalOutDimSize()}, {S("iteration"), 1}});
1061+
.reshapeOuts({{S("offset"), ret.getTotalOutDimSize()},
1062+
{S("iteration"), 1}}) *
1063+
identityND(kBlock, {1, 1}, {0, 1}, {S("offset"), S("iteration")});
10451064
}
10461065

10471066
} // anonymous namespace

lib/Tools/LinearLayout.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -640,7 +640,7 @@ LinearLayout operator*(LinearLayout inner, LinearLayout outer) {
640640
}
641641

642642
std::optional<LinearLayout>
643-
LinearLayout::divideRight(const LinearLayout &divisor) {
643+
LinearLayout::divideRight(const LinearLayout &divisor) const {
644644
assertCommonDimsSameOrder(getOutDimNames(), divisor.getOutDimNames());
645645
assertCommonDimsSameOrder(getInDimNames(), divisor.getInDimNames());
646646

0 commit comments

Comments (0)