Skip to content

Commit 6f06595

Browse files
authored
[AMD] Use permlane_swap for layout conversions between two dot operations (triton-lang#7947)
1 parent 72ec661 commit 6f06595

File tree

5 files changed

+51
-9
lines changed

5 files changed

+51
-9
lines changed

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -691,18 +691,19 @@ LinearLayout mfmaDotToLinearLayout(DotOperandEncodingAttr dotMfmaLayout,
691691
int mIndex = 0 + hasBatchDim;
692692

693693
int32_t kWidth = dotMfmaLayout.getKWidth();
694-
auto kDimIndex = dotMfmaLayout.getOpIdx() == 0 ? rank - 1 : rank - 2;
694+
auto nonKDimIndex = dotMfmaLayout.getOpIdx() == 0 ? rank - 2 : rank - 1;
695695

696696
auto warpsPerCTA = mfmaLayout.getWarpsPerCTA();
697697
auto tilesPerWarp = mfmaLayout.getTilesPerWarp();
698-
auto tilePerWarpNonK = tilesPerWarp[kDimIndex];
698+
auto tilePerWarpNonK = tilesPerWarp[nonKDimIndex];
699699

700700
auto mDim = mfmaLayout.getMDim();
701701
auto nDim = mfmaLayout.getNDim();
702702
auto opIdx = dotMfmaLayout.getOpIdx();
703703
auto nonKDim = opIdx == 0 ? mDim : nDim;
704704
constexpr int warpSize = 64;
705705

706+
auto kDimIndex = dotMfmaLayout.getOpIdx() == 0 ? rank - 1 : rank - 2;
706707
int32_t kSize = shape[kDimIndex];
707708

708709
MLIRContext *ctx = dotMfmaLayout.getContext();

test/Conversion/amd/mfma-shortcut.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,17 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32}
9494
tt.return
9595
}
9696
}
97+
98+
// -----
99+
100+
#mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
101+
#mma1 = #ttg.amd_mfma<{version = 4, warpsPerCTA = [4, 1], tilesPerWarp = [2, 1], instrShape = [16, 16], isTransposed = true}>
102+
module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
103+
// GFX950-LABEL: mfma_dotop_permlane_swap
104+
tt.func public @mfma_dotop_permlane_swap(%arg0: tensor<128x16xf16, #mma1>) {
105+
// GFX950-NOT: load
106+
// GFX950-COUNT-2: llvm.call_intrinsic "llvm.amdgcn.permlane16.swap"
107+
%1 = ttg.convert_layout %arg0: tensor<128x16xf16, #mma1> -> tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
108+
tt.return
109+
}
110+
}

third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -686,7 +686,7 @@ bool isUsedByDotScaledOp(Operation *op) {
686686
});
687687
}
688688

689-
bool isChainDotHead(tt::DotOpInterface dotOp) {
689+
bool isChainDotHead(tt::DotOpInterface dotOp, unsigned opIdx) {
690690
auto isInSameRegion = [&dotOp](Operation *op) {
691691
return op->getParentRegion() == dotOp->getParentRegion();
692692
};
@@ -697,8 +697,9 @@ bool isChainDotHead(tt::DotOpInterface dotOp) {
697697
for (Operation *op : fwdSlices) {
698698
if (auto dOp = dyn_cast<tt::DotOpInterface>(op)) {
699699
assert(dOp != dotOp);
700-
auto opA = dOp.getA().getDefiningOp();
701-
if (opA && fwdSlices.contains(opA)) {
700+
Operation *dotOperand = (opIdx == 0) ? dOp.getA().getDefiningOp()
701+
: dOp.getB().getDefiningOp();
702+
if (dotOperand && fwdSlices.contains(dotOperand)) {
702703
return true;
703704
}
704705
}

third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,9 @@ bool doesSwizzleInsideWarp(RewriterBase &rewriter,
111111
// Return true if op is used by DotScaledOp or UpcastMXFPOp ops.
112112
bool isUsedByDotScaledOp(Operation *op);
113113

114-
// Check if the result of this tl.dot is used as opA of another tl.dot
114+
// Check if the result of this tl.dot is used as opA or opB of another tl.dot
115115
// in the same region
116-
bool isChainDotHead(mlir::triton::DotOpInterface dotOp);
116+
bool isChainDotHead(mlir::triton::DotOpInterface dotOp, unsigned opIdx = 0);
117117

118118
// Check if given operand of this tt.dot is the result of a tt.trans
119119
// in the same region

third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -475,9 +475,36 @@ class BlockedToMFMA : public OpRewritePattern<tt::DotOp> {
475475
// requires to broadcast the operand A.
476476
bool isTransposed = !(mDim == 4 && nDim == 64);
477477
auto aElemTy = mfmaInstr->aElementType;
478+
auto is16BitElemTy = (aElemTy.isF16() || aElemTy.isBF16());
479+
480+
unsigned rank = oldRetType.getRank();
481+
SmallVector<unsigned, 2> tilesPerWarp = {1, 1};
482+
483+
// Set tilesPerWarp and isTransposed to enable intra warp conversion for the
484+
// mfma16x16 layout of a dot op, depending on whether
485+
// its result is used by operand 0 or operand 1 of another dot op.
486+
if (mfmaVersion == 4 && is16BitElemTy && mDim == 16 && nDim == 16 &&
487+
rank == 2) {
488+
if (isChainDotHead(dotOp, 0u) &&
489+
retShape.front() >= 16 * 2 * warpsPerTile.front() &&
490+
retShape.back() == 16 && warpsPerTile.back() == 1) {
491+
isTransposed = true;
492+
tilesPerWarp = {2, 1};
493+
} else if (isChainDotHead(dotOp, 1u) && retShape.front() == 16 &&
494+
retShape.back() >= 16 * 2 * warpsPerTile.back() &&
495+
warpsPerTile.front() == 1) {
496+
isTransposed = false;
497+
tilesPerWarp = {1, 2};
498+
}
499+
}
500+
501+
if (rank == 3) {
502+
tilesPerWarp.insert(tilesPerWarp.begin(), 1);
503+
}
504+
478505
ttg::AMDMfmaEncodingAttr mfmaEnc = ttg::AMDMfmaEncodingAttr::get(
479506
oldRetType.getContext(),
480-
/*version*/ mfmaVersion, warpsPerTile,
507+
/*version*/ mfmaVersion, warpsPerTile, tilesPerWarp,
481508
/*instrShape*/ mDim, nDim, /*isTransposed=*/isTransposed, CTALayout,
482509
mfmaAccType);
483510

@@ -524,7 +551,6 @@ class BlockedToMFMA : public OpRewritePattern<tt::DotOp> {
524551
// kWidth = 4 so that the conversion from #mma (result of 1st dot)
525552
// to #dotOp (operand 0 of 2nd dot) is a no-op.
526553
// TODO (lixun): relax the condition for 8-bit elementTy.
527-
auto is16BitElemTy = (aElemTy.isF16() || aElemTy.isBF16());
528554
if (is16BitElemTy && isDotChainTail) {
529555
kWidth = 4;
530556
}

0 commit comments

Comments
 (0)