Skip to content

Commit c4cc78f

Browse files
Merge commit '5f77e8c6ead532086590c3b93e03b6b824c65e69'
2 parents e4416dd + 5f77e8c commit c4cc78f

File tree

21 files changed

+1127
-118
lines changed

21 files changed

+1127
-118
lines changed

include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,15 @@ toLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
4848
// dimension, determines if the layout moves data across block boundaries.
4949
bool isCrossCTAConversion(const LinearLayout &layout);
5050

51+
// Given a linear layout where the input dimensions contain a "block" dimension,
52+
// this method sets the "block" dimension to 0 and removes the corresponding
53+
// output dimensions.
54+
//
55+
// Note that this behavior differs from calling
56+
// `LinearLayout::sublayout(inDimNames, outDimNames)` when "block" is not in
57+
// `inDimNames`. The latter does not modify the output sizes.
58+
LinearLayout getLayoutWithinBlock(const LinearLayout &layout);
59+
5160
// In this function, we construct a linear layout representing the
5261
// <shared memory offset, iteration, block> -> <tensor element index> mapping
5362
// for entire `src` and `dst` tensors. We determine the shape of the

include/triton/Tools/LinearLayout.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -597,7 +597,7 @@ class LinearLayout {
597597
//
598598
// TODO(jlebar): Implement divideLeft.
599599
// std::optional<LinearLayout> divideLeft(const LinearLayout &divisor);
600-
std::optional<LinearLayout> divideRight(const LinearLayout &divisor);
600+
std::optional<LinearLayout> divideRight(const LinearLayout &divisor) const;
601601

602602
// Gets a layout with only these in/out dimensions.
603603
//

lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ struct AssertOpConversion : public ConvertOpToLLVMPattern<triton::AssertOp> {
5252
int line = 0;
5353
int col = 0;
5454

55+
while (auto callLoc = dyn_cast<CallSiteLoc>(loc))
56+
loc = callLoc.getCallee();
57+
5558
if (auto fileLineColLoc = dyn_cast<FileLineColLoc>(loc)) {
5659
file = fileLineColLoc.getFilename();
5760
line = fileLineColLoc.getLine();

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -367,12 +367,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
367367
// The following tasks must be completed before we can remove the layoutIsOK
368368
// check:
369369
// 1. Support for AMD's MFMA and WMMA
370-
// 2. Handling NVIDIA's MMA layout when CTA per CGA > 1
371370
std::function<bool(Attribute)> layoutIsOK = [&](Attribute layout) {
372371
if (auto nvidiaMma = dyn_cast<NvidiaMmaEncodingAttr>(layout)) {
373-
if (product(getCTAsPerCGA(nvidiaMma)) > 1) {
374-
return false;
375-
}
376372
if (useLegacyMMAConversion) {
377373
return false;
378374
}
@@ -419,8 +415,11 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
419415
}
420416
}
421417

422-
SmallVector<Value> outVals = transferWithinBlockOrGroupImpl(
423-
inVals, conversion, op, srcLayout, dstLayout, adaptor, rewriter);
418+
auto srcLayoutWithinBlock = getLayoutWithinBlock(srcLayout);
419+
auto dstLayoutWithinBlock = getLayoutWithinBlock(dstLayout);
420+
SmallVector<Value> outVals =
421+
transferWithinBlock(inVals, op, srcLayoutWithinBlock,
422+
dstLayoutWithinBlock, adaptor, rewriter);
424423

425424
// Unmunge output values
426425
for (const auto &it : llvm::enumerate(outVals)) {
@@ -437,11 +436,11 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
437436
return success();
438437
}
439438

440-
SmallVector<Value> transferWithinBlockOrGroupImpl(
441-
ArrayRef<Value> inVals, const LinearLayout &conversion,
442-
ConvertLayoutOp op, const LinearLayout &srcLayout,
443-
const LinearLayout &dstLayout, OpAdaptor adaptor,
444-
ConversionPatternRewriter &rewriter) const {
439+
SmallVector<Value>
440+
transferWithinBlock(ArrayRef<Value> inVals, ConvertLayoutOp op,
441+
const LinearLayout &srcLayout,
442+
const LinearLayout &dstLayout, OpAdaptor adaptor,
443+
ConversionPatternRewriter &rewriter) const {
445444
MLIRContext *ctx = op.getContext();
446445
auto loc = op.getLoc();
447446

@@ -459,11 +458,12 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
459458

460459
auto scratchConfig =
461460
getScratchConfigForCvt(op.getSrc().getType(), op.getType());
462-
auto tensorShape = convertType<unsigned, int64_t>(op.getType().getShape());
461+
auto tensorShapePerCTA = convertType<unsigned, int64_t>(getShapePerCTA(
462+
op.getSrc().getType().getEncoding(), op.getType().getShape()));
463463
// Input dims: [offset, iteration, block]
464464
// Output dims: dimN-1, dimN-2, ..., dim0, where N is obtained from repShape
465465
LinearLayout sharedLayout = chooseShemLayoutForRegToRegConversion(
466-
ctx, tensorShape, scratchConfig.repShape, scratchConfig.order);
466+
ctx, tensorShapePerCTA, scratchConfig.repShape, scratchConfig.order);
467467

468468
// Layout for the store from registers to shared memory.
469469
//

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 188 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "triton/Dialect/Triton/IR/Dialect.h"
22

3+
#include <cstdint>
34
#include <numeric>
45

56
#include "mlir/IR/DialectImplementation.h"
@@ -250,6 +251,30 @@ SmallVector<unsigned> getWarpOrder(Attribute layout) {
250251
return order;
251252
}
252253

254+
// Returns the dimension `order` for a dot operand given its operand index
// (`opIdx`: 0 for A, 1 for B) and tensor rank.
//
// The 'order' field typically represents a descending sorted array of
// dimensions based on contiguity: order[0] is the most contiguous dimension
// (this is what axisInfo utilities assume).
//
// The relation between contiguity and order is only relevant if the layout
// interfaces with HBM, as is the case when we load a tensor from HBM straight
// into registers in the dot layout to bypass LDS. When bypassing LDS, we make
// the following assumptions about tensor layouts:
// - Tensor A (opIdx == 0) is row-major    -> order [1, 0]
// - Tensor B (opIdx == 1) is column-major -> order [0, 1]
SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank) {
  assert((opIdx == 0 || opIdx == 1) && "dot operand index must be 0 or 1");
  // Guard the swap below: a dot operand is at least 2-D, and for rank < 2
  // order[1] would be out of bounds.
  assert(rank >= 2 && "dot operands must be at least 2-D");
  SmallVector<unsigned> order(rank);
  // Default row-major order: [rank-1, rank-2, ..., 0].
  std::iota(order.rbegin(), order.rend(), 0);
  // For operand B, make the K dimension (rank-2) the most contiguous one by
  // swapping the two fastest-varying dimensions.
  if (opIdx == 1) {
    std::swap(order[0], order[1]);
  }
  return order;
}
277+
253278
SmallVector<unsigned> getOrder(Attribute layout) {
254279
if (auto blockedLayout = dyn_cast<BlockedEncodingAttr>(layout)) {
255280
return llvm::to_vector(blockedLayout.getOrder());
@@ -264,7 +289,11 @@ SmallVector<unsigned> getOrder(Attribute layout) {
264289
if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
265290
auto rank = getWarpsPerCTA(dotLayout.getParent()).size();
266291
SmallVector<unsigned> order(rank);
267-
std::iota(order.rbegin(), order.rend(), 0);
292+
if (isa<AMDMfmaEncodingAttr>(dotLayout.getParent())) {
293+
return getOrderForDotOperand(dotLayout.getOpIdx(), rank);
294+
} else {
295+
std::iota(order.rbegin(), order.rend(), 0);
296+
}
268297
return order;
269298
}
270299
if (auto sliceLayout = dyn_cast<SliceEncodingAttr>(layout)) {
@@ -928,6 +957,27 @@ unsigned SharedEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
928957
SmallVector<unsigned>
929958
DotOperandEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
930959
Type eltTy) const {
960+
961+
if (auto parent = mlir::dyn_cast<AMDMfmaEncodingAttr>(getParent())) {
962+
auto rank = shape.size();
963+
assert(rank == 2 || rank == 3);
964+
965+
auto idx = getOpIdx();
966+
assert(idx == 0 || idx == 1);
967+
968+
SmallVector<unsigned> elemsPerThread(rank);
969+
970+
auto kWidth = getKWidth();
971+
auto rep = parent.getMFMARepForOperands(shape, kWidth, idx);
972+
973+
if (rank == 3)
974+
elemsPerThread[0] = rep[0];
975+
elemsPerThread[rank - 2] = (idx == 0) ? rep[1] : rep[1] * kWidth;
976+
elemsPerThread[rank - 1] = (idx == 0) ? rep[2] * kWidth : rep[2];
977+
978+
return elemsPerThread;
979+
}
980+
931981
if (auto mmaParent = mlir::dyn_cast<MmaEncodingTrait>(getParent())) {
932982
return mmaParent.getElemsPerThreadForOperands(shape, eltTy, getOpIdx());
933983
}
@@ -3107,8 +3157,124 @@ static std::string paddedString(int value, int max) {
31073157
return str;
31083158
}
31093159

3110-
std::string mlir::triton::gpu::getLayoutStr(RankedTensorType tensorType,
3111-
bool useHWPointOfView) {
3160+
// Pretty-prints a shared-memory layout. Shared layouts are modeled as a
// linear map (block, offset) -> tensor indices; depending on
// `useHWPointOfView` we either print, per tensor element, which
// (block:offset) slots hold it, or per (block, offset), which element it maps
// to. Returns "" when the tensor has no encoding.
std::string getSharedLayoutStr(RankedTensorType tensorType,
                               bool useHWPointOfView) {
  Attribute layout = tensorType.getEncoding();
  if (!layout)
    return "";

  std::optional<LinearLayout> ll =
      triton::gpu::toLinearLayout(tensorType.getShape(), layout);
  if (!ll.has_value())
    llvm::report_fatal_error("Failed to convert layout to linear layout");

  StringAttr kOffset = StringAttr::get(tensorType.getContext(), "offset");
  StringAttr kBlock = StringAttr::get(tensorType.getContext(), "block");
  int64_t tensorSize = product(tensorType.getShape());
  unsigned numBlocks = getNumCTAs(layout);
  int32_t blockSize = tensorSize / numBlocks;

  // elementMapping is for the non-hw layout, offsetMapping for the hw layout.
  // A single running counter is enough to index elementMapping because the
  // 'swizzle' operation rearranges the indices -- and we want to keep it
  // that way.
  std::vector<std::string> elementMapping(tensorSize);
  std::vector<std::string> offsetMapping;

  // Enumerate every offset of every block and build both representations in
  // one pass; they differ only in the separator ("," for hw, ":" otherwise).
  int32_t linearIdx = 0;
  for (int32_t block = 0; block < numBlocks; ++block) {
    for (int32_t offset = 0; offset < blockSize; ++offset, ++linearIdx) {
      SmallVector<std::pair<StringAttr, int32_t>> inputs = {
          {kBlock, block},
          {kOffset, offset},
      };
      SmallVector<std::pair<StringAttr, int32_t>> outputs = ll->apply(inputs);

      std::string hwEntry = "(";
      std::string &elemEntry = elementMapping[linearIdx];
      // Several (block, offset) pairs may alias the same element; join them.
      if (!elemEntry.empty())
        elemEntry += "|";
      elemEntry += "(";
      for (size_t i = 0; i != outputs.size(); ++i) {
        if (i != 0) {
          hwEntry += ",";
          elemEntry += ":";
        }
        std::string index =
            paddedString(outputs[i].second, tensorType.getDimSize(i));
        hwEntry += index;
        elemEntry += index;
      }
      elemEntry += ")";
      hwEntry += ")";

      offsetMapping.push_back(hwEntry);
    }
  }

  std::string layoutStr;

  if (!useHWPointOfView) {
    // Tensor point of view: nested brackets per dimension, one tensor row per
    // output line.
    int rank = tensorType.getRank();
    bool atLineStart = true;
    for (int i = 0; i < tensorSize; ++i) {
      auto indices = delinearizeIndex(i, tensorType.getShape());
      // Open a bracket for every dimension whose index just wrapped to 0.
      int openBrackets = 0;
      for (int j = rank - 1; j >= 0; --j) {
        if (indices[j] % tensorType.getDimSize(j) != 0)
          break;
        layoutStr += "[";
        ++openBrackets;
      }
      if (atLineStart) {
        // Pad so the cells line up under the opening brackets.
        layoutStr.append(rank - openBrackets, ' ');
        atLineStart = false;
      }

      layoutStr += elementMapping[i];
      // Close a bracket for every dimension that ends after this element.
      auto nextIndices = delinearizeIndex(i + 1, tensorType.getShape());
      for (int j = rank - 1; j >= 0; --j) {
        if (nextIndices[j] % tensorType.getDimSize(j) != 0)
          break;
        layoutStr += "]";
      }
      if (nextIndices.back() % tensorType.getShape().back() == 0) {
        layoutStr += "\n";
        atLineStart = true;
      } else {
        layoutStr += ",";
      }
    }
  } else {
    // HW point of view: print the (block, offset) --> (r,c) mapping directly.
    uint32_t mappingIdx = 0;
    for (int32_t block = 0; block < numBlocks; ++block) {
      layoutStr += "Block: " + std::to_string(block) + ":\n";
      for (int32_t offset = 0; offset < (tensorSize / numBlocks); ++offset) {
        layoutStr += "Offset: " + std::to_string(offset) + " -> ";
        layoutStr += offsetMapping[mappingIdx];
        layoutStr += "\n";
        ++mappingIdx;
      }
    }
  }

  return layoutStr;
}
3275+
3276+
std::string getDistributedLayoutStr(RankedTensorType tensorType,
3277+
bool useHWPointOfView) {
31123278
auto layout = tensorType.getEncoding();
31133279
if (!layout)
31143280
return "";
@@ -3175,7 +3341,7 @@ std::string mlir::triton::gpu::getLayoutStr(RankedTensorType tensorType,
31753341
}
31763342
std::string layoutStr;
31773343
if (!useHWPointOfView) {
3178-
// Printing the threads containning each elements of the tensor.
3344+
// Printing the threads containing each element of the tensor.
31793345
int rank = tensorType.getRank();
31803346
bool newLine = true;
31813347
for (int i = 0; i < tensorSize; i++) {
@@ -3233,6 +3399,24 @@ std::string mlir::triton::gpu::getLayoutStr(RankedTensorType tensorType,
32333399
return layoutStr;
32343400
}
32353401

3402+
// Dispatches layout printing on the encoding kind: shared encodings get
// getSharedLayoutStr, distributed encodings get getDistributedLayoutStr.
// Returns "" when the tensor has no encoding; aborts on an unsupported one.
std::string mlir::triton::gpu::getLayoutStr(RankedTensorType tensorType,
                                            bool useHWPointOfView) {
  Attribute layout = tensorType.getEncoding();

  // Guard the casts below: dyn_cast/isa on a null Attribute asserts. The
  // helpers return "" for a missing encoding, so do the same here.
  if (!layout)
    return "";

  // tensorType is needed later on (e.g., getDimSize(j)), so we still have to
  // pass it as a param
  if (mlir::isa<SharedEncodingAttr>(layout)) {
    return getSharedLayoutStr(tensorType, useHWPointOfView);
  } else if (mlir::isa<DistributedEncodingTrait>(layout)) {
    return getDistributedLayoutStr(tensorType, useHWPointOfView);
  }

  // else unimplemented, return error
  llvm::report_fatal_error("Unimplemented usage of getLayoutStr");
  return "";
}
3419+
32363420
void mlir::triton::gpu::dumpLayout(RankedTensorType tensorType) {
32373421
llvm::errs() << getLayoutStr(tensorType, /*useHWPointOfView=*/false);
32383422
}

0 commit comments

Comments
 (0)