
Commit 4d6ce4e

[AMD] Fix wmma scaled with small k dim on gfx1250 (#8487)
This PR fixes the layout and lowering for wmma scaled with a small k dim, where the tensor's k dimension is smaller than a single wmma scaled instruction's k dimension. It also adds corresponding lit tests for common cases.
1 parent c07886c commit 4d6ce4e
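
For context, the "small k dim" case shows up on the scale operands: each scale covers a group of 32 (or 16) elements along K, and a single wmma scaled instruction expects each lane to hold kWidth = 4 scale values along k. A minimal standalone sketch of the shapes involved (plain C++; the concrete sizes are illustrative assumptions, not taken from this commit):

```cpp
// Illustrative model of scaled-dot operand shapes; not Triton's API.
#include <cstdio>

int main() {
  const int M = 16, N = 16, K = 64; // assumed problem sizes
  const int groupSize = 32;         // one scale per 32 (or 16) elements along K
  const int scaleK = K / groupSize; // k extent of aScale [M, scaleK], bScale [N, scaleK]
  const int scaleKWidth = 4;        // scale values per lane for one wmma scaled op

  printf("aScale: [%d, %d], bScale: [%d, %d]\n", M, scaleK, N, scaleK);
  // Here scaleK = 2 < scaleKWidth = 4: the small-k case this commit fixes.
  printf("small k dim: %s\n", scaleK < scaleKWidth ? "yes" : "no");
  return 0;
}
```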

File tree: 6 files changed, +424 −86 lines


include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 3 additions & 4 deletions
```diff
@@ -137,10 +137,9 @@ LinearLayout chooseScaledMfmaScaleLayout(MLIRContext *ctx, int dotOperandIdx,
                                          ArrayRef<unsigned> tilesPerWarp,
                                          ArrayRef<unsigned> warpsPerCTA);
 
-LinearLayout chooseScaledWmmaScaleLayout(
-    MLIRContext *ctx, int dotOperandIdx,
-    const std::vector<std::vector<int32_t>> &dotOperandWarpBasis,
-    ArrayRef<int64_t> dotOperandShape);
+LinearLayout chooseScaledWmmaScaleLayout(MLIRContext *ctx, int dotOperandIdx,
+                                         ArrayRef<unsigned> warpsPerCTA,
+                                         ArrayRef<int64_t> dotOperandShape);
 
 LinearLayout getSM120DotScaledScaleLayout(MLIRContext *ctx,
                                           ArrayRef<int64_t> shape, int opIdx,
```
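
The signature change above replaces the caller-built warp basis with the plain warp grid, so the function can derive the warp distribution for either operand itself. A hypothetical call site under the new signature (identifiers such as `wmmaLayout`, `aScaleTy`, and `bScaleTy` are assumptions for illustration, not part of this diff):

```cpp
// Hypothetical caller: pass the WMMA encoding's warp grid directly instead of
// hand-building a per-operand warp basis.
auto warpsPerCTA = wmmaLayout.getWarpsPerCTA(); // e.g. {2, 2}
LinearLayout aScaleLL = chooseScaledWmmaScaleLayout(
    ctx, /*dotOperandIdx=*/0, warpsPerCTA, aScaleTy.getShape());
LinearLayout bScaleLL = chooseScaledWmmaScaleLayout(
    ctx, /*dotOperandIdx=*/1, warpsPerCTA, bScaleTy.getShape());
```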

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 32 additions & 42 deletions
```diff
@@ -1434,64 +1434,54 @@ chooseDsReadTrLayout(Attribute enc, ArrayRef<int64_t> shape,
   }
 }
 
-LinearLayout chooseScaledWmmaScaleLayout(
-    MLIRContext *ctx, int dotOperandIdx,
-    const std::vector<std::vector<int32_t>> &dotOperandWarpBasis,
-    ArrayRef<int64_t> dotOperandShape) {
+LinearLayout chooseScaledWmmaScaleLayout(MLIRContext *ctx, int dotOperandIdx,
+                                         ArrayRef<unsigned> warpsPerCTA,
+                                         ArrayRef<int64_t> dotOperandShape) {
   using basisT = std::vector<std::vector<int32_t>>;
   unsigned rank = dotOperandShape.size();
   auto order = mlir::triton::gpu::getMatrixOrder(rank, /*rowMajor=*/true);
-  auto standardOutDims = standardOutDimNames(ctx, rank);
+  auto outDimNames = standardOutDimNames(ctx, rank);
+
   StringAttr kRegister = StringAttr::get(ctx, "register");
   StringAttr kLane = StringAttr::get(ctx, "lane");
   StringAttr kWarp = StringAttr::get(ctx, "warp");
   StringAttr kBlock = StringAttr::get(ctx, "block");
-  unsigned int scaleKWidth = dotOperandShape[1];
-  // Init register layout. Will be adjusted later
-  auto regs =
-      mlir::triton::identityStandardND(kRegister, {1, scaleKWidth}, order);
-  LinearLayout lanes = LinearLayout::empty();
+
   // In scaled dot, the shapes of operands(without batch dimension) are,
   // respectively:
   // - A: [M, K]
   // - B: [K, N]
   // - aScale: [M, K / 32 or 16]
   // - bScale: [N, K / 32 or 16]
-  //
-  // To correctly feed A/B and its scale into instruction, we need to
-  // distribute aScale/bScale among warps in the same way as A/B. But bScale
-  // is not transposed like B. So we need to transpose the warp layout of
-  // bScale.
-  //
-  // The tricky part is, our desired outputs are [dim0, dim1], but
-  // at this position, the layouts are transposed to [dim1, dim0]. So
-  // instead of reverse bScale's layout, we need to reverse aScale's. There
-  // will be a transpose in the end to correct everything.
-  basisT warps = dotOperandWarpBasis;
-  if (dotOperandIdx == 0) {
-    for (auto &basis : warps) {
-      std::reverse(basis.begin(), basis.end());
-    }
-  }
+  auto dimK = outDimNames[order[0]];
+  auto dimNonK = outDimNames[order[1]];
 
-  lanes = LinearLayout({{kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 0}}},
-                        {kWarp, warps},
-                        {kBlock, {}}},
-                       {standardOutDims[order[0]], standardOutDims[order[1]]});
-  LinearLayout newLL = regs * lanes;
+  // Each lane holds kWidth=4 consecutive values along the k dim.
+  // The first 16 lanes are distributed along the non-k dim. We are not using
+  // the remaining 16 lanes, so just let them duplicate values of the first 16
+  // lanes. If the shape along the k dim is larger than kWidth, repeat this
+  // pattern to fill the k dim.
+  unsigned scaleKWidth = 4;
+  auto kSize = dotOperandShape[1];
+  LinearLayout tileLayout =
+      LinearLayout::identity1D(scaleKWidth, kRegister, dimK) *
+      LinearLayout::identity1D(16, kLane, dimNonK) *
+      LinearLayout::zeros1D(2, kLane, dimK) *
+      LinearLayout::identity1D(kSize / scaleKWidth, kRegister, dimK);
 
-  // Adjust register-level layout to fill the shape, at this level, both
-  // aScale and bScale should align with A operand.
-  SmallVector<int, 2> repOrder = {1, 0};
-  for (auto d : repOrder) {
-    auto outDim = standardOutDims[d];
-    auto dimSize = newLL.getOutDimSize(outDim);
-    newLL *= LinearLayout::identity1D(dotOperandShape[d] / dimSize, kRegister,
-                                      outDim);
-  }
-  newLL = newLL.transposeOuts(standardOutDims);
+  auto warpsPerCTANew = (dotOperandIdx == 1)
+                            ? SmallVector{warpsPerCTA[1], warpsPerCTA[0]}
+                            : SmallVector{warpsPerCTA[0], warpsPerCTA[1]};
+
+  auto warpOrder = (dotOperandIdx == 1) ? SmallVector<unsigned>{0, 1}
+                                        : SmallVector<unsigned>{1, 0};
+  LinearLayout warpLayout =
+      identityStandardND(kWarp, warpsPerCTANew, warpOrder);
+  LinearLayout ctaLayout = tileLayout.transposeOuts(outDimNames) *
+                           warpLayout.transposeOuts(outDimNames);
 
-  return newLL;
+  return combineCtaCgaWithShape(
+      ctaLayout, CTALayoutAttr::getDefault(ctx, /*rank=*/2), dotOperandShape);
 }
 
 // PTX ISA - Warp-level MMA Block Scaling
```
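
To make the new tile layout concrete, here is a minimal standalone model (plain C++, not the MLIR `LinearLayout` API) of the (register, lane) → (non-k, k) mapping it encodes; the k extent of 8 is an illustrative assumption:

```cpp
#include <cstdio>
#include <initializer_list>

int main() {
  const int scaleKWidth = 4; // consecutive scale values per lane along k
  const int kSize = 8;       // illustrative k extent, a multiple of scaleKWidth

  // All register bases target the k dim and compose contiguously, so register
  // r maps to k = r. Lane bits 0-3 map to the non-k dim, and lane bit 4 maps
  // to nothing (zeros1D), so lanes 16-31 duplicate lanes 0-15.
  for (int lane : {0, 1, 15, 16, 17}) {
    for (int reg = 0; reg < kSize; ++reg) {
      int nonK = lane % 16;
      int k = reg;
      if (reg < 2) // sample two registers per lane
        printf("lane %2d reg %d -> (nonK = %2d, k = %d)\n", lane, reg, nonK, k);
    }
  }
  return 0;
}
```

Lanes 16 and 17 print the same (nonK, k) pairs as lanes 0 and 1, matching the comment in the diff that the remaining 16 lanes duplicate the values of the first 16.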
