Skip to content

Commit 7d89248

Browse files
authored
[BACKEND] Support Hopper MMA to MMA convert_layout ops (intel#4492)
This PR enables mma to mma conversion on the Hopper architecture. We also replace the previous `isMmaToMmaShortcut` check with `cvtReordersRegisters` in several places. Note that mma to mma conversion using shared memory still goes through the legacy `ConvertLayoutOpConversion` function; we will deprecate it in a follow-up PR.
1 parent 6a9a0a6 commit 7d89248

File tree

11 files changed

+214
-250
lines changed

11 files changed

+214
-250
lines changed

include/triton/Analysis/Utility.h

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -189,18 +189,24 @@ bool supportMMA(triton::DotOp op, int version);
189189

190190
bool supportMMA(Value value, int version);
191191

192+
// Conversion from `srcTy` to `dstTy` only involves reordering of registers.
193+
// There is no need for data exchange across threads, warps, or blocks.
194+
bool cvtReordersRegisters(RankedTensorType srcTy, RankedTensorType dstTy);
195+
196+
// Conversion from `srcTy` to `dstTy` involves data exchange across threads
197+
// within a warp. No data exchange across warps or blocks is needed.
198+
bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy);
199+
200+
// Conversion from `srcTy` to `dstTy` involves data exchange across threads,
201+
// warps, and possibly blocks.
192202
bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy);
193203

194204
bool atomicNeedsSharedMemory(Value result);
195205

196-
bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy);
206+
bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
197207

198208
bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
199209

200-
// TODO(jlebar): Remove this function; it's subsumed by the linear-layout case
201-
// in cvtNeedsSharedMemory.
202-
bool isMmaToMmaShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
203-
204210
// Return true if the src and dst layout match.
205211
bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
206212
RankedTensorType dstTy);

include/triton/Conversion/TritonGPUToLLVM/Patterns.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ void decomposeSplatOpToSharedLayoutConversion(ModuleOp module);
2020
/// Replaces `mma/mfma -> dot_op` with `mma/mfma -> blocked -> dot_op` in the
2121
/// given |module| op, but bypass the decomposition if |shortcutFn| returns
2222
/// true.
23-
using ShortcutFn = std::function<bool(RankedTensorType &, RankedTensorType &)>;
23+
using ShortcutFn = std::function<bool(RankedTensorType, RankedTensorType)>;
2424
template <typename TensorCoreEncodingAttr>
2525
void decomposeTensorCoreToDotLayoutConversion(ModuleOp module,
2626
ShortcutFn shortcutFn);

lib/Analysis/Allocation.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ getCvtOrder(Attribute srcLayout, Attribute dstLayout) {
4444
auto dstMmaLayout = mlir::dyn_cast<NvidiaMmaEncodingAttr>(dstLayout);
4545
auto dstDotLayout = mlir::dyn_cast<DotOperandEncodingAttr>(dstLayout);
4646

47-
assert(!(srcMmaLayout && dstMmaLayout && !srcMmaLayout.isAmpere()) &&
47+
assert(!(srcMmaLayout && dstMmaLayout && !srcMmaLayout.isAmpere() &&
48+
!srcMmaLayout.isHopper()) &&
4849
"mma -> mma layout conversion is only supported on Ampere and Hopper");
4950

5051
// mma or dot layout does not have an order, so the order depends on the

lib/Analysis/Utility.cpp

Lines changed: 46 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -527,7 +527,7 @@ bool supportMMA(Value value, int version) {
527527
(elemTy.isInteger(8) && version >= 2);
528528
}
529529

530-
bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
530+
bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
531531
auto mfmaLayout = dyn_cast<AMDMfmaEncodingAttr>(srcTy.getEncoding());
532532
auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
533533
if (mfmaLayout == nullptr || dotOperandLayout == nullptr)
@@ -543,21 +543,6 @@ bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
543543
(srcTy.getElementType().isF16() || srcTy.getElementType().isBF16());
544544
}
545545

546-
static bool isMmaToMmaShortcut(Attribute srcEncoding, Attribute dstEncoding) {
547-
auto src = dyn_cast<NvidiaMmaEncodingAttr>(srcEncoding);
548-
auto dst = dyn_cast<NvidiaMmaEncodingAttr>(dstEncoding);
549-
if (!src || !dst)
550-
return false;
551-
// when #mma = MmaEncoding<version=3, warpsPerCTA=[..., 1]>
552-
return src && dst && src.getVersionMajor() == 3 &&
553-
src.getWarpsPerCTA()[1] == 1 && dst.getVersionMajor() == 3 &&
554-
dst.getWarpsPerCTA()[1] == 1;
555-
}
556-
557-
bool isMmaToMmaShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
558-
return isMmaToMmaShortcut(srcTy.getEncoding(), dstTy.getEncoding());
559-
}
560-
561546
// For MMAV3 dotOperand layout matches mma operand for f16 and bf16 cases.
562547
bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
563548
RankedTensorType dstTy) {
@@ -567,14 +552,16 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
567552
return false;
568553
}
569554
int elementTypeSize = srcTy.getElementType().getIntOrFloatBitWidth();
570-
auto ans =
571-
mmaLayout.getVersionMajor() == 3 && dotOperandLayout.getOpIdx() == 0 &&
572-
isMmaToMmaShortcut(dotOperandLayout.getParent(), srcTy.getEncoding()) &&
573-
(elementTypeSize == 16 || elementTypeSize == 8);
555+
auto parentTy = RankedTensorType::get(
556+
srcTy.getShape(), srcTy.getElementType(), dotOperandLayout.getParent());
557+
auto ans = mmaLayout.getVersionMajor() == 3 &&
558+
dotOperandLayout.getOpIdx() == 0 &&
559+
!cvtNeedsSharedMemory(parentTy, srcTy) &&
560+
(elementTypeSize == 16 || elementTypeSize == 8);
574561
return ans;
575562
}
576563

577-
bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) {
564+
bool cvtReordersRegisters(RankedTensorType srcTy, RankedTensorType dstTy) {
578565
MLIRContext *ctx = srcTy.getContext();
579566
std::optional<LinearLayout> srcLayout =
580567
toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
@@ -586,26 +573,54 @@ bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) {
586573
StringAttr kLane = StringAttr::get(ctx, "lane");
587574
StringAttr kWarp = StringAttr::get(ctx, "warp");
588575
StringAttr kBlock = StringAttr::get(ctx, "block");
589-
// In principle, there's no need for shared memory if there's no
590-
// communication between warps. However, right now we only have implemented
591-
// the shortcut case where there's no communication between *threads*.
592-
//
593-
// TODO(jlebar): Remove the kLane layout once we add support for
594-
// shuffle-based layout conversions in ConvertLayoutToLLVM.
576+
// TODO(jlebar): These checks are overly-restrictive. For example, we can
577+
// transfer by shuffling registers (case 1) if and only if all of the bases
578+
// for `register` have 0s for lane, warp, and block. But the check below is
579+
// stronger than this, checking also that the choice of lane/warp/block does
580+
// not affect the permutation of registers. If we allow different
581+
// lane/warp/blocks to have different permutations, we can generalize this.
595582
if (comp.divideRight(LinearLayout::identity1D(comp.getInDimSize(kLane),
596583
kLane, kLane) *
597584
LinearLayout::identity1D(comp.getInDimSize(kWarp),
598585
kWarp, kWarp) *
599586
LinearLayout::identity1D(comp.getInDimSize(kBlock),
600587
kBlock, kBlock))
601588
.has_value()) {
602-
return false;
589+
return true;
603590
}
604591
}
592+
return false;
593+
}
605594

606-
// TODO(jlebar): Remove these special cases once they're fully subsumed by the
607-
// linear-layout check above.
608-
return !isMmaToMmaShortcut(srcTy, dstTy) &&
595+
bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy) {
596+
MLIRContext *ctx = srcTy.getContext();
597+
std::optional<LinearLayout> srcLayout =
598+
toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
599+
std::optional<LinearLayout> dstLayout =
600+
toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
601+
if (srcLayout.has_value() && dstLayout.has_value()) {
602+
// comp describes the layout function for converting from src to dst.
603+
LinearLayout comp = srcLayout->invertAndCompose(*dstLayout);
604+
StringAttr kWarp = StringAttr::get(ctx, "warp");
605+
StringAttr kBlock = StringAttr::get(ctx, "block");
606+
if (comp.divideRight(LinearLayout::identity1D(comp.getInDimSize(kWarp),
607+
kWarp, kWarp) *
608+
LinearLayout::identity1D(comp.getInDimSize(kBlock),
609+
kBlock, kBlock))
610+
.has_value()) {
611+
return true;
612+
}
613+
}
614+
return false;
615+
}
616+
617+
bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) {
618+
// TODO(jlebar): Remove these special cases (`isMmaToDotShortcut` and
619+
// `isMfmaToDotShortcut`) once they're fully subsumed by the linear-layout
620+
// checks.
621+
// TODO(Keren): We do not check `cvtNeedsWarpShuffle` here because it is not
622+
// supported yet in Triton's backend.
623+
return !cvtReordersRegisters(srcTy, dstTy) &&
609624
!isMmaToDotShortcut(srcTy, dstTy) &&
610625
!isMfmaToDotShortcut(srcTy, dstTy);
611626
}

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 55 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -250,10 +250,12 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
250250
MLIRContext *ctx = op.getContext();
251251

252252
const auto &shape = op.getType().getShape();
253+
auto srcTy = op.getSrc().getType();
254+
auto dstTy = op.getType();
253255
std::optional<LinearLayout> srcLayout =
254-
toLinearLayout(shape, op.getSrc().getType().getEncoding());
256+
toLinearLayout(shape, srcTy.getEncoding());
255257
std::optional<LinearLayout> dstLayout =
256-
toLinearLayout(shape, op.getType().getEncoding());
258+
toLinearLayout(shape, dstTy.getEncoding());
257259
if (!srcLayout.has_value() || !dstLayout.has_value()) {
258260
return failure();
259261
}
@@ -270,93 +272,94 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
270272
// 4. Transfer between values in different CTAs, in which case we move
271273
// values through distributed shared memory.
272274
//
273-
// We can tell which case we're in by examining `conversion`. If e.g. the
274-
// block -> block mapping is {1, 2, 4, ...} then there's no movement between
275-
// data in different CTAs and we know we're not in case 4.
276-
LinearLayout conversion = srcLayout->invertAndCompose(*dstLayout);
277-
278-
int numLanes = conversion.getInDimSize(str_attr("lane"));
279-
int numWarps = conversion.getInDimSize(str_attr("warp"));
280-
int numBlocks = conversion.getInDimSize(str_attr("block"));
281-
282-
StringAttr kLane = str_attr("lane");
283-
StringAttr kWarp = str_attr("warp");
284-
StringAttr kBlock = str_attr("block");
285-
286-
// TODO(jlebar): These checks are overly-restrictive. For example, we can
287-
// transfer by shuffling registers (case 1) if and only if all of the bases
288-
// for `register` have 0s for lane, warp, and block. But the check below is
289-
// stronger than this, checking also that the choice of lane/warp/block does
290-
// not affect the permutation of registers. If we allow different
291-
// lane/warp/blocks to have different permutations, we can generalize this.
292-
if (std::optional<LinearLayout> c = conversion.divideRight(
293-
LinearLayout::identity1D(numLanes, kLane, kLane) *
294-
LinearLayout::identity1D(numWarps, kWarp, kWarp) *
295-
LinearLayout::identity1D(numBlocks, kBlock, kBlock));
296-
c.has_value()) {
297-
return transferWithinThread(*c, op, adaptor, rewriter);
275+
// We can tell which case we're in by examining `conversion`.
276+
// For example, if the block -> block mapping is an identity layout: {1, 2,
277+
// 4, ...}, then there's no movement between data in different CTAs, and we
278+
// know we're not in case 4.
279+
if (cvtReordersRegisters(srcTy, dstTy)) { // Case 1.
280+
return transferWithinThread(op, *srcLayout, *dstLayout, adaptor,
281+
rewriter);
298282
}
299283

300-
if (std::optional<LinearLayout> c = conversion.divideRight(
301-
LinearLayout::identity1D(numWarps, kWarp, kWarp) *
302-
LinearLayout::identity1D(numBlocks, kBlock, kBlock));
303-
c.has_value()) {
304-
return transferWithinLane(*c, op, adaptor, rewriter);
284+
if (cvtNeedsWarpShuffle(srcTy, dstTy)) { // Case 2.
285+
return transferWithinLane(op, *srcLayout, *dstLayout, adaptor, rewriter);
305286
}
306287

307-
return transferWithinBlockOrGroup(conversion, op, *srcLayout, *dstLayout,
308-
adaptor, rewriter);
288+
return transferWithinBlockOrGroup(op, *srcLayout, *dstLayout, adaptor,
289+
rewriter); // Cases 3 and 4
309290
}
310291

311292
LogicalResult
312-
transferWithinThread(const LinearLayout &conversion, ConvertLayoutOp op,
313-
OpAdaptor adaptor,
293+
transferWithinThread(ConvertLayoutOp op, const LinearLayout &srcLayout,
294+
const LinearLayout &dstLayout, OpAdaptor adaptor,
314295
ConversionPatternRewriter &rewriter) const {
315296
MLIRContext *ctx = op.getContext();
316297
auto loc = op.getLoc();
317298
StringAttr kRegister = str_attr("register");
299+
StringAttr kLane = str_attr("lane");
300+
StringAttr kWarp = str_attr("warp");
301+
StringAttr kBlock = str_attr("block");
302+
303+
// There are three possible cases:
304+
//
305+
// 1. `srcLayout` has the same number of registers as `dstLayout`.
306+
// 2. `srcLayout` has fewer registers than `dstLayout`.
307+
// 3. `srcLayout` has more registers than `dstLayout`.
308+
//
309+
// In the second case `srcLayout . dstLayout^-1` is not surjective
310+
// because not all destination registers are covered.
311+
// Since the goal is to cover all of the destination
312+
// registers, we can instead use `dstLayout . srcLayout^-1`.
313+
LinearLayout conversion = dstLayout.invertAndCompose(srcLayout);
314+
auto dstToSrc = conversion.divideRight(
315+
LinearLayout::identity1D(conversion.getInDimSize(kLane), kLane, kLane) *
316+
LinearLayout::identity1D(conversion.getInDimSize(kWarp), kWarp, kWarp) *
317+
LinearLayout::identity1D(conversion.getInDimSize(kBlock), kBlock,
318+
kBlock));
318319

319320
assert(!cvtNeedsSharedMemory(op.getSrc().getType(), op.getType()));
320-
assert(ArrayRef(to_vector(conversion.getInDimNames())) ==
321+
assert(ArrayRef(to_vector(dstToSrc->getInDimNames())) ==
321322
ArrayRef{kRegister});
322-
assert(ArrayRef(to_vector(conversion.getOutDimNames())) ==
323+
assert(ArrayRef(to_vector(dstToSrc->getOutDimNames())) ==
323324
ArrayRef{kRegister});
324325

325326
auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
326-
SmallVector<Value> outVals(conversion.getOutDimSize(kRegister));
327-
for (int i = 0; i < conversion.getInDimSize(kRegister); i++) {
328-
auto dstIdx = conversion.apply({{kRegister, i}});
329-
outVals[dstIdx.begin()->second] = inVals[i];
327+
SmallVector<Value> outVals;
328+
outVals.resize(dstToSrc->getInDimSize(kRegister));
329+
for (int i = 0; i < dstToSrc->getInDimSize(kRegister); i++) {
330+
auto srcIdx = dstToSrc->apply({{kRegister, i}});
331+
outVals[i] = inVals[srcIdx.begin()->second];
330332
}
331333
Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter,
332334
op.getType());
333335
rewriter.replaceOp(op, result);
334336
return success();
335337
}
336338

337-
LogicalResult transferWithinLane(const LinearLayout &conversion,
338-
ConvertLayoutOp op, OpAdaptor adaptor,
339+
LogicalResult transferWithinLane(ConvertLayoutOp op,
340+
const LinearLayout &srcLayout,
341+
const LinearLayout &dstLayout,
342+
OpAdaptor adaptor,
339343
ConversionPatternRewriter &rewriter) const {
340344
// TODO(jlebar): Implement me.
341345
return failure();
342346
}
343347

344348
LogicalResult
345-
transferWithinBlockOrGroup(const LinearLayout &conversion, ConvertLayoutOp op,
346-
const LinearLayout &srcLayout,
349+
transferWithinBlockOrGroup(ConvertLayoutOp op, const LinearLayout &srcLayout,
347350
const LinearLayout &dstLayout, OpAdaptor adaptor,
348351
ConversionPatternRewriter &rewriter) const {
352+
LinearLayout conversion = srcLayout.invertAndCompose(dstLayout);
353+
349354
// TODO(Keren): LLs support cross-CTA conversions, this function does not
350355
if (isCrossCTAConversion(conversion))
351356
return failure();
352357

353358
MLIRContext *ctx = op.getContext();
354359
auto loc = op.getLoc();
355360

356-
assert(cvtNeedsSharedMemory(op.getSrc().getType(), op.getType()));
357-
358-
// TODO(jlebar): For now we handle only blocked/slice -> blocked/slice
359-
// conversions. Once we have ldmatrix support in
361+
// TODO(jlebar): For now we handle only blocked/slice ->
362+
// blocked/slice conversions. Once we have ldmatrix support in
360363
// load/storeDistributedToShared, we can remove this constraint.
361364
std::function<bool(Attribute)> layoutIsOK = [&](Attribute layout) {
362365
if (isa<BlockedEncodingAttr>(layout)) {
@@ -372,6 +375,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
372375
return failure();
373376
}
374377

378+
assert(cvtNeedsSharedMemory(op.getSrc().getType(), op.getType()));
379+
375380
SmallVector<Value> inVals =
376381
unpackLLElements(loc, adaptor.getSrc(), rewriter);
377382
assert(!inVals.empty());

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -336,8 +336,8 @@ LogicalResult MakeRangeOp::verify() {
336336

337337
//-- ReduceOp --
338338
static LogicalResult
339-
inferReduceReturnShape(const RankedTensorType &argTy, const Type &retEltTy,
340-
int axis, SmallVectorImpl<Type> &inferredReturnTypes) {
339+
inferReduceReturnShape(RankedTensorType argTy, Type retEltTy, int axis,
340+
SmallVectorImpl<Type> &inferredReturnTypes) {
341341
auto retShape = argTy.getShape().vec();
342342
retShape.erase(retShape.begin() + axis);
343343
if (retShape.empty()) {

0 commit comments

Comments
 (0)