
Commit e1bb2cc

lezcano authored and makslevental committed
[LAYOUTS] Choose between wgmma RS and SS within AccelerateMatmul (triton-lang#5798)
Doing so is rather natural: we choose to keep the lhs in registers if it does not come from a load. We also implement a generic transpose-inputs-and-outputs transformation for when this is beneficial for MMAv3.
1 parent 2b0fbd0 commit e1bb2cc
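
In outline, the placement heuristic introduced here can be summarized with the sketch below. This is an editor's illustration, not code from the patch: Wgmma, Choice, and chooseWgmma are hypothetical names, aFromLoad/bFromLoad stand for the result of the patch's comesFromLoadOrBlockArg check, and the MMA-version guards are omitted for brevity.

    // Hypothetical summary of the RS/SS decision made in BlockedToMMA
    // (illustrative only; not part of the commit).
    enum class Wgmma { SS, RS };

    struct Choice {
      Wgmma kind;      // RS: lhs stays in registers; SS: both operands in smem
      bool transposed; // rewrite D = A*B as trans(trans(B) * trans(A)) first
    };

    // aFromLoad / bFromLoad: whether the operand chain bottoms out in a load
    // (or a block argument).
    inline Choice chooseWgmma(bool aFromLoad, bool bFromLoad) {
      if (!aFromLoad)
        return {Wgmma::RS, false}; // lhs already in registers: use wgmma RS
      if (!bFromLoad)
        return {Wgmma::RS, true};  // transpose so the register operand is lhs
      return {Wgmma::SS, false};   // both come from loads: wgmma SS
    }

On the MMAv2 path the choice is moot: both operands are converted to dot-operand register layouts either way (see getDotOperand in the diff below).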

6 files changed: +182, −404 lines


lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 102 additions & 30 deletions
@@ -195,6 +195,28 @@ static bool bwdFilter(Operation *op) {
               mlir::TypeID::get<arith::ArithDialect>());
 }
 
+static SmallVector<int, 2> getTransposeOrder(int rank) {
+  assert(rank >= 2);
+  auto transOrder = llvm::to_vector<2>(llvm::seq<int>(rank - 2));
+  transOrder.push_back(rank - 1);
+  transOrder.push_back(rank - 2);
+  return transOrder;
+}
+
+static DotOp transposeDotOp(PatternRewriter &rewriter, DotOp dotOp) {
+  auto rank = dotOp.getResult().getType().getRank();
+  Value a = dotOp.getA();
+  Value b = dotOp.getB();
+  Value c = dotOp.getC();
+  auto transOrder = getTransposeOrder(rank);
+  a = rewriter.create<TransOp>(a.getLoc(), a, transOrder);
+  b = rewriter.create<TransOp>(b.getLoc(), b, transOrder);
+  c = rewriter.create<TransOp>(c.getLoc(), c, transOrder);
+  return rewriter.create<DotOp>(dotOp.getLoc(), c.getType(), b, a, c,
+                                dotOp.getInputPrecision(),
+                                dotOp.getMaxNumImpreciseAcc());
+}
+
 // Finds the first different bitwidth in the chain of shape-preserving
 // unary ops that x depends on.
 // There are two primary scenarios:
@@ -249,29 +271,69 @@ class BlockedToMMA : public mlir::OpRewritePattern<DotOp> {
       return failure();
     }
     // TODO: Check data-types and SM compatibility
-    RankedTensorType oldRetType = dotOp.getType();
-    if (!oldRetType.getEncoding() ||
-        mlir::isa<NvidiaMmaEncodingAttr>(oldRetType.getEncoding()))
+    if (!dotOp.getType().getEncoding() ||
+        mlir::isa<NvidiaMmaEncodingAttr>(dotOp.getType().getEncoding()))
       return failure();
 
-    // get MMA encoding for the given number of warps
-    auto retShapePerCTA = getShapePerCTA(oldRetType);
     auto mod = dotOp->getParentOfType<mlir::ModuleOp>();
     int numWarps = TritonGPUDialect::getNumWarps(mod);
-    auto CTALayout = getCTALayout(oldRetType.getEncoding());
-
     int versionMajor = getMMAVersionSafe(computeCapability, dotOp);
     if (!(versionMajor >= 1 && versionMajor <= 3))
       return failure();
 
-    auto instrShape = mmaVersionToInstrShape(
-        versionMajor, retShapePerCTA, dotOp.getA().getType().getElementType(),
-        numWarps);
-    // operands
+    // If both of the operands are not loads, we fall back to MMAv2;
+    // otherwise the reg-smem roundtrip will tank the MMAv3 performance
+    auto comesFromLoadOrBlockArg = [](Value v) -> bool {
+      // Peel out the original cvt dot_op<..., #blocked>
+      // and any other potential cvt/trans ops
+      while (true) {
+        if (auto cvtOp = v.getDefiningOp<ConvertLayoutOp>()) {
+          v = cvtOp.getSrc();
+          continue;
+        }
+        if (auto transOp = v.getDefiningOp<TransOp>()) {
+          v = transOp.getSrc();
+          continue;
+        }
+        break;
+      }
+      // We also accept block arguments as they appear in many MLIR tests
+      // If this is problematic we can totally drop them
+      return isa<BlockArgument>(v) ||
+             (v.getDefiningOp() &&
+              isa<LoadOp, ExperimentalDescriptorLoadOp>(v.getDefiningOp()));
+    };
+
+    bool aFromLoad = comesFromLoadOrBlockArg(dotOp.getA());
+    bool bFromLoad = comesFromLoadOrBlockArg(dotOp.getB());
+    bool transpose = false;
+    auto origDotOp = dotOp;
+    if (aFromLoad && !bFromLoad) {
+      // If the lhs comes from a load and the rhs does not, we transpose the
+      // inputs and the result, provided this allows us to use mmav3
+      // We transpose the result at the end of the rewrite
+      DotOp transDot = transposeDotOp(rewriter, dotOp);
+      if (getMMAVersionSafe(computeCapability, transDot) == 3) {
+        dotOp = transDot;
+        versionMajor = 3;
+        transpose = true;
+      }
+      std::swap(aFromLoad, bFromLoad);
+    }
+    // If !aFromLoad && !bFromLoad, we just accept a shmem roundtrip
+    // for versionMajor == 3
+
     Value a = dotOp.getA();
     Value b = dotOp.getB();
-    auto oldAType = dotOp.getA().getType();
-    auto oldBType = dotOp.getB().getType();
+    auto oldAType = cast<RankedTensorType>(a.getType());
+    auto oldBType = cast<RankedTensorType>(b.getType());
+    auto oldRetType = cast<RankedTensorType>(dotOp.getType());
+
+    // get MMA encoding for the given number of warps
+    auto CTALayout = getCTALayout(oldRetType.getEncoding());
+    auto retShapePerCTA = getShapePerCTA(oldRetType);
+    auto instrShape = mmaVersionToInstrShape(
+        versionMajor, retShapePerCTA, oldAType.getElementType(), numWarps);
 
     assert(versionMajor == 2 || versionMajor == 3);
     int versionMinor = computeCapability == 75 ? 1 : 0;
@@ -287,12 +349,28 @@ class BlockedToMMA : public mlir::OpRewritePattern<DotOp> {
     auto newAcc =
         rewriter.create<ConvertLayoutOp>(oldAcc.getLoc(), newRetType, oldAcc);
 
+    auto getDotOperand = [&](Value v, int opIdx, int bitwidth) {
+      auto minType =
+          bitwidth > 0 ? rewriter.getIntegerType(bitwidth) : v.getType();
+      auto vType = cast<RankedTensorType>(v.getType());
+      auto newVEncoding = DotOperandEncodingAttr::get(
+          v.getContext(), opIdx, newRetType.getEncoding(), minType);
+      auto newVType = RankedTensorType::get(
+          vType.getShape(), vType.getElementType(), newVEncoding);
+      return rewriter.create<ConvertLayoutOp>(v.getLoc(), newVType, v);
+    };
+
     Operation *newDot = nullptr;
     if (versionMajor == 3) {
       auto eltType = dotOp.getA().getType().getElementType();
       // In MMAV3 transpose is only supported for f16 and bf16.
       bool allowTranspose = eltType.isF16() || eltType.isBF16();
-      a = getSharedMemoryMMAOperand(a, rewriter, 0, allowTranspose);
+      if (!aFromLoad) {
+        int bitwidth = getElementTypeOrSelf(a).getIntOrFloatBitWidth();
+        a = getDotOperand(a, 0, bitwidth);
+      } else {
+        a = getSharedMemoryMMAOperand(a, rewriter, 0, allowTranspose);
+      }
       b = getSharedMemoryMMAOperand(b, rewriter, 1, allowTranspose);
       newDot = rewriter.create<triton::nvidia_gpu::WarpGroupDotOp>(
           dotOp.getLoc(), newRetType, a, b, newAcc, nullptr,
@@ -301,27 +379,21 @@ class BlockedToMMA : public mlir::OpRewritePattern<DotOp> {
       // convert operands
       int minBitwidth =
           std::min(computeOrigBitWidth(a), computeOrigBitWidth(b));
-      Type minType = rewriter.getIntegerType(minBitwidth);
-      // convert A operand
-      auto newAEncoding = DotOperandEncodingAttr::get(
-          oldAType.getContext(), 0, newRetType.getEncoding(),
-          minBitwidth > 0 ? minType : oldAType.getElementType());
-      auto newAType = RankedTensorType::get(
-          oldAType.getShape(), oldAType.getElementType(), newAEncoding);
-      a = rewriter.create<ConvertLayoutOp>(a.getLoc(), newAType, a);
-      // convert B operand
-      auto newBEncoding = DotOperandEncodingAttr::get(
-          oldBType.getContext(), 1, newRetType.getEncoding(),
-          minBitwidth > 0 ? minType : oldBType.getElementType());
-      auto newBType = RankedTensorType::get(
-          oldBType.getShape(), oldBType.getElementType(), newBEncoding);
-      b = rewriter.create<ConvertLayoutOp>(b.getLoc(), newBType, b);
+
+      a = getDotOperand(a, 0, minBitwidth);
+      b = getDotOperand(b, 1, minBitwidth);
       newDot = rewriter.create<DotOp>(dotOp.getLoc(), newRetType, a, b, newAcc,
                                       dotOp.getInputPrecision(),
                                       dotOp.getMaxNumImpreciseAcc());
     }
+    if (transpose) {
+      auto rank = dotOp.getResult().getType().getRank();
+      auto transOrder = getTransposeOrder(rank);
+      newDot = rewriter.create<TransOp>(newDot->getLoc(), newDot->getResult(0),
+                                        transOrder);
+    }
     // convert dot instruction
-    rewriter.replaceOpWithNewOp<ConvertLayoutOp>(dotOp, oldRetType,
+    rewriter.replaceOpWithNewOp<ConvertLayoutOp>(origDotOp, origDotOp.getType(),
                                                  newDot->getResult(0));
     return success();
   }
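
A note on why the transpose path is correct (our summary, not text from the commit): transposeDotOp rewrites dot(a, b, c) as dot(trans(b), trans(a), trans(c)), which relies on the standard identity

    \[ D = A B + C \;\Longrightarrow\; D^{\top} = B^{\top} A^{\top} + C^{\top} \]

so the rewritten dot computes the transposed result, and the trailing TransOp created under if (transpose) restores the original orientation before the final ConvertLayoutOp.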
