Skip to content

Commit 9afbfb1

Browse files
committed
Merge commit '13594bbf4571fbee066fe45fb8b1874690598e97'
2 parents 5feeb96 + 13594bb commit 9afbfb1

File tree

9 files changed

+288
-464
lines changed

9 files changed

+288
-464
lines changed

include/triton/Analysis/Utility.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "mlir/Support/LLVM.h"
77
#include "triton/Dialect/Triton/IR/Dialect.h"
88
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
9+
#include "triton/Tools/LinearLayout.h"
910

1011
namespace mlir {
1112

@@ -189,6 +190,14 @@ bool supportMMA(triton::DotOp op, int version);
189190

190191
bool supportMMA(Value value, int version);
191192

193+
// Conversion from `srcTy` to `dstTy` involving the minimum amount of data
194+
// transfer provided that both types can be converted to LL (if it can't it'll
195+
// return nullopt). The output will be such that layout.getInDimNames() ==
196+
// layout.getOutDimNames() and the conversion will not include kBlock (resp.
197+
// kWarp or kLane) if it can be avoided
198+
std::optional<mlir::triton::LinearLayout>
199+
minimalCvtLayout(RankedTensorType srcTy, RankedTensorType dstTy);
200+
192201
// Conversion from `srcTy` to `dstTy` only involves reordering of registers.
193202
// There is no need for data exchange across threads, warps, or blocks.
194203
bool cvtReordersRegisters(RankedTensorType srcTy, RankedTensorType dstTy);

include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,6 @@ std::optional<LinearLayout>
4444
toLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
4545
std::optional<int32_t> elemBitWidth = std::nullopt);
4646

47-
// Given a linear layout with input dims and output dims containing a "block"
48-
// dimension, determines if the layout moves data across block boundaries.
49-
bool isCrossCTAConversion(const LinearLayout &layout);
50-
5147
// Given a linear layout where the input dimensions contain a "block" dimension,
5248
// this method sets the "block" dimension to 0 and removes the corresponding
5349
// output dimensions.

include/triton/Tools/LinearLayout.h

Lines changed: 18 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -575,29 +575,20 @@ class LinearLayout {
575575
return *this;
576576
}
577577

578-
// divideLeft and divideRight are the inverses of operator*.
579-
//
580-
// Consider `a = c.divideRight(b)`, where `a` is a linear layout with
581-
// `in-dims(a) == in-dims(b)` and `out-dims(a) == out-dims(c)`. We may remove
582-
// some empty dimensions from `a` to form `a'` and still have `a' * b == c`.
583-
// Therefore, there are multiple possible values that we could return for
584-
// `(a * b).divideRight(b)` which would satisfy
585-
// `((a * b).divideRight(b)) * b == a * b`.
586-
//
587-
// In the following example, we have `a * b == a' * b` when "in1" is an empty
588-
// dimension that maps everything to 0:
589-
//
590-
// a = L("in1", "in2") -> ("out1", "out2")
591-
// a' = L("in1") -> ("out1")
592-
// b = L("in2") -> ("out2")
593-
//
594-
// divideLeft and divideRight resolve this ambiguity by always returning the
595-
// "canonical" quotient, namely the one with the fewest possible size-zero
596-
// input and output dimensions.
597-
//
598-
// TODO(jlebar): Implement divideLeft.
599-
// std::optional<LinearLayout> divideLeft(const LinearLayout &divisor);
600-
std::optional<LinearLayout> divideRight(const LinearLayout &divisor) const;
578+
// Returns true if this layout acts trivially (as the identity) on the given
579+
// dimensions. This means that it's the identity on those dimensions, and it
580+
// does not map other dimensions onto those or these onto other dimensions.
581+
bool isTrivialOver(ArrayRef<StringAttr> dimNames) const;
582+
583+
// For an endomorphism on dimNames (linear map that maps dimNames to dimNames)
584+
// checks whether it is the identity map on these dimensions (i.e
585+
// LinearLayouts::isTrivialOver) and if so, returns the sublayout of the
586+
// remaining dimensions.
587+
// nb. The isTrivialOver condition is more restrictive than the usual
588+
// "leaves the subspace invariant" condition in maths.
589+
// We can always relax it if we know how to take advantage of a conversion
590+
// layout being block-diagonal in the future.
591+
std::optional<LinearLayout> quotient(ArrayRef<StringAttr> dimNames) const;
601592

602593
// Gets a layout with only these in/out dimensions.
603594
//
@@ -614,10 +605,10 @@ class LinearLayout {
614605
bool sublayoutIsZero(ArrayRef<StringAttr> inDimNames,
615606
ArrayRef<StringAttr> outDimNames) const;
616607

617-
// Is the sublayout restricted to inDimNames + outDimNames and then flattened
618-
// to 1D the identity layout (ignoring out-dim sizes)?
619-
bool sublayoutIsIdentity(ArrayRef<StringAttr> inDimNames,
620-
ArrayRef<StringAttr> outDimNames) const;
608+
// Is the sublayout defined from dimNames to dimNames the identity?
609+
// In particular, is the input and output size in these dimensions
610+
// the same, and are the bases the identity?
611+
bool squareSublayoutIsIdentity(ArrayRef<StringAttr> dimNames) const;
621612

622613
// Computes and returns L(x, y, z).
623614
//

lib/Analysis/Utility.cpp

Lines changed: 38 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -647,57 +647,56 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
647647
return ans;
648648
}
649649

650-
bool cvtReordersRegisters(RankedTensorType srcTy, RankedTensorType dstTy) {
650+
// We get the smallest submap of srcTy^{-1} * dstTy that is not the identity
651+
// under kBlock, kWarp or kLane (in that order). The idea here is that if we
652+
// have a transformation that's the identity on kBlock, we don't need to use
653+
// distributed shared memory. If it's also the identity on kWarp, we can
654+
// transfer via warp-shuffles, and if it's the identity on kLane just have to
655+
// reorder the registers
656+
std::optional<LinearLayout> minimalCvtLayout(RankedTensorType srcTy,
657+
RankedTensorType dstTy) {
651658
MLIRContext *ctx = srcTy.getContext();
652659
std::optional<LinearLayout> srcLayout =
653660
toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
654661
std::optional<LinearLayout> dstLayout =
655662
toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
656-
if (srcLayout.has_value() && dstLayout.has_value()) {
657-
// comp describes the layout function for converting from src to dst.
658-
LinearLayout comp = srcLayout->invertAndCompose(*dstLayout);
659-
StringAttr kLane = StringAttr::get(ctx, "lane");
660-
StringAttr kWarp = StringAttr::get(ctx, "warp");
661-
StringAttr kBlock = StringAttr::get(ctx, "block");
662-
// TODO(jlebar): These checks are overly-restrictive. For example, we can
663-
// transfer by shuffling registers (case 1) if and only if all of the bases
664-
// for `register` have 0s for lane, warp, and block. But the check below is
665-
// stronger than this, checking also that the choice of lane/warp/block does
666-
// not affect the permutation of registers. If we allow different
667-
// lane/warp/blocks to have different permutations, we can generalize this.
668-
if (comp.divideRight(LinearLayout::identity1D(comp.getInDimSize(kLane),
669-
kLane, kLane) *
670-
LinearLayout::identity1D(comp.getInDimSize(kWarp),
671-
kWarp, kWarp) *
672-
LinearLayout::identity1D(comp.getInDimSize(kBlock),
673-
kBlock, kBlock))
674-
.has_value()) {
675-
return true;
663+
if (!(srcLayout.has_value() && dstLayout.has_value()))
664+
return std::nullopt;
665+
// comp describes the layout function to create dst from src.
666+
LinearLayout comp = dstLayout->invertAndCompose(*srcLayout);
667+
// We try to quotient by the largest subspace first
668+
auto dims = SmallVector<StringRef>{"block", "warp", "lane", "register"};
669+
for (auto dim : dims) {
670+
auto quotient = comp.quotient(StringAttr::get(ctx, dim));
671+
if (!quotient.has_value()) {
672+
break;
676673
}
674+
comp = *quotient;
677675
}
678-
return false;
676+
return comp;
677+
}
678+
679+
bool cvtReordersRegisters(RankedTensorType srcTy, RankedTensorType dstTy) {
680+
auto layout = minimalCvtLayout(srcTy, dstTy);
681+
MLIRContext *ctx = srcTy.getContext();
682+
if (!layout.has_value()) {
683+
return false;
684+
}
685+
auto kRegister = StringAttr::get(ctx, "register");
686+
auto outDims = llvm::to_vector(layout->getOutDimNames());
687+
return outDims.empty() || ArrayRef(outDims) == ArrayRef({kRegister});
679688
}
680689

681690
bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy) {
691+
auto layout = minimalCvtLayout(srcTy, dstTy);
682692
MLIRContext *ctx = srcTy.getContext();
683-
std::optional<LinearLayout> srcLayout =
684-
toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
685-
std::optional<LinearLayout> dstLayout =
686-
toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
687-
if (srcLayout.has_value() && dstLayout.has_value()) {
688-
// comp describes the layout function for converting from src to dst.
689-
LinearLayout comp = srcLayout->invertAndCompose(*dstLayout);
690-
StringAttr kWarp = StringAttr::get(ctx, "warp");
691-
StringAttr kBlock = StringAttr::get(ctx, "block");
692-
if (comp.divideRight(LinearLayout::identity1D(comp.getInDimSize(kWarp),
693-
kWarp, kWarp) *
694-
LinearLayout::identity1D(comp.getInDimSize(kBlock),
695-
kBlock, kBlock))
696-
.has_value()) {
697-
return true;
698-
}
693+
if (!layout.has_value()) {
694+
return false;
699695
}
700-
return false;
696+
auto kRegister = StringAttr::get(ctx, "register");
697+
auto kLane = StringAttr::get(ctx, "lane");
698+
return llvm::to_vector(layout->getOutDimNames()) ==
699+
llvm::SmallVector<StringAttr, 2>{kRegister, kLane};
701700
}
702701

703702
bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) {

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 60 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -282,111 +282,79 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
282282
const auto &shape = op.getType().getShape();
283283
auto srcTy = op.getSrc().getType();
284284
auto dstTy = op.getType();
285-
std::optional<LinearLayout> srcLayout =
286-
toLinearLayout(shape, srcTy.getEncoding());
287-
std::optional<LinearLayout> dstLayout =
288-
toLinearLayout(shape, dstTy.getEncoding());
289-
if (!srcLayout.has_value() || !dstLayout.has_value()) {
290-
return failure();
291-
}
292285

293-
// There are four cases to handle.
294-
//
295-
// 1. Transfer between values in the same thread, in which case we simply
296-
// reorder the elements of adaptor.getSrc().
297-
// 2. Transfer between values in the same warp, in which case we try to
298-
// move values using warp shuffles, though if the pattern is complicated
299-
// enough we may fall back to using shared memory (case 3).
300-
// 3. Transfer between values in the same CTA, in which case we move values
301-
// through shared memory.
302-
// 4. Transfer between values in different CTAs, in which case we move
303-
// values through distributed shared memory.
304-
//
305-
// We can tell which case we're in by examining `conversion`.
306-
// For example, if the block -> block mapping is an identity layout: {1, 2,
307-
// 4, ...}, then there's no movement between data in different CTAs, and we
308-
// know we're not in case 4.
309-
if (cvtReordersRegisters(srcTy, dstTy)) { // Case 1.
310-
return transferWithinThread(op, *srcLayout, *dstLayout, adaptor,
311-
rewriter);
286+
auto conversion = minimalCvtLayout(srcTy, dstTy);
287+
if (!conversion.has_value()) {
288+
return rewriter.notifyMatchFailure(
289+
op, "NYI. srcTy and/or dstTy don't implement LLs yet");
312290
}
313291

314-
if (cvtNeedsWarpShuffle(srcTy, dstTy)) { // Case 2.
315-
return transferWithinLane(op, *srcLayout, *dstLayout, adaptor, rewriter);
292+
assert(to_vector(conversion->getInDimNames()) ==
293+
to_vector(conversion->getOutDimNames()));
294+
auto dims = conversion->getInDimNames();
295+
if (llvm::is_contained(dims, str_attr("block"))) {
296+
// Case 1: Transfer between values in different CTAs.
297+
// This requires moving values through distributed shared memory.
298+
return rewriter.notifyMatchFailure(
299+
op, "NYI: Transfer between different CTAs");
300+
} else if (llvm::is_contained(dims, str_attr("warp"))) {
301+
// Case 2: Transfer between values in the same CTA, in which case we move
302+
// values through shared memory.
303+
LinearLayout srcLayout =
304+
*toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
305+
LinearLayout dstLayout =
306+
*toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
307+
return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter);
308+
} else if (llvm::is_contained(dims, str_attr("lane"))) {
309+
// Case 3. Transfer between values in the same warp, in which case we try
310+
// to move values using warp shuffles, though if the pattern is
311+
// complicated enough we may fall back to using shared memory
312+
// TODO(Keren): implement warp shuffle instead of using the general
313+
// approach that uses shared memory
314+
LinearLayout srcLayout =
315+
*toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
316+
LinearLayout dstLayout =
317+
*toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
318+
return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter);
319+
} else if (llvm::is_contained(dims, str_attr("register"))) {
320+
// Case 4. Transfer between values in the same thread, in which case we
321+
// simply reorder the elements of adaptor.getSrc().
322+
return transferWithinThread(op, *conversion, adaptor, rewriter);
323+
} else {
324+
// The two layouts are equivalent. We should probably remove these in
325+
// RemoveLayoutConversion.
326+
rewriter.replaceOp(op, adaptor.getSrc());
327+
return success();
316328
}
317-
318-
return transferWithinBlockOrGroup(op, *srcLayout, *dstLayout, adaptor,
319-
rewriter); // Case 3 and 4
320329
}
321330

322331
LogicalResult
323-
transferWithinThread(ConvertLayoutOp op, const LinearLayout &srcLayout,
324-
const LinearLayout &dstLayout, OpAdaptor adaptor,
332+
transferWithinThread(ConvertLayoutOp op, const LinearLayout &conversion,
333+
OpAdaptor adaptor,
325334
ConversionPatternRewriter &rewriter) const {
326335
MLIRContext *ctx = op.getContext();
327336
auto loc = op.getLoc();
328337
StringAttr kRegister = str_attr("register");
329-
StringAttr kLane = str_attr("lane");
330-
StringAttr kWarp = str_attr("warp");
331-
StringAttr kBlock = str_attr("block");
332-
333-
// There are three possible cases:
334-
//
335-
// 1. `srcLayout` has the same number of registers as `dstLayout`.
336-
// 2. `srcLayout` has fewer registers than `dstLayout`.
337-
// 3. `srcLayout` has more registers than `dstLayout`.
338-
//
339-
// In the second case `srcLayout . dstLayout^-1` is not surjective
340-
// because not all destination registers are covered.
341-
// Since the goal is to cover all of the destination
342-
// registers, we can instead use `dstLayout . srcLayout^-1`.
343-
LinearLayout conversion = dstLayout.invertAndCompose(srcLayout);
344-
auto dstToSrc = conversion.divideRight(
345-
LinearLayout::identity1D(conversion.getInDimSize(kLane), kLane, kLane) *
346-
LinearLayout::identity1D(conversion.getInDimSize(kWarp), kWarp, kWarp) *
347-
LinearLayout::identity1D(conversion.getInDimSize(kBlock), kBlock,
348-
kBlock));
349-
350338
assert(!cvtNeedsSharedMemory(op.getSrc().getType(), op.getType()));
351-
assert(ArrayRef(to_vector(dstToSrc->getInDimNames())) ==
352-
ArrayRef{kRegister});
353-
assert(ArrayRef(to_vector(dstToSrc->getOutDimNames())) ==
354-
ArrayRef{kRegister});
355339

356340
auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
357341
SmallVector<Value> outVals;
358-
outVals.resize(dstToSrc->getInDimSize(kRegister));
359-
for (int i = 0; i < dstToSrc->getInDimSize(kRegister); i++) {
360-
auto srcIdx = dstToSrc->apply({{kRegister, i}});
361-
outVals[i] = inVals[srcIdx.begin()->second];
342+
outVals.resize(conversion.getInDimSize(kRegister));
343+
for (int i = 0; i < conversion.getInDimSize(kRegister); i++) {
344+
auto srcIdx = conversion.apply({{kRegister, i}}).begin()->second;
345+
outVals[i] = inVals[srcIdx];
362346
}
363347
Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter,
364348
op.getType());
365349
rewriter.replaceOp(op, result);
366350
return success();
367351
}
368352

369-
LogicalResult transferWithinLane(ConvertLayoutOp op,
370-
const LinearLayout &srcLayout,
371-
const LinearLayout &dstLayout,
372-
OpAdaptor adaptor,
373-
ConversionPatternRewriter &rewriter) const {
374-
// TODO(Keren): implement warp shuffle instead of using the general approach
375-
// that uses shared memory
376-
return transferWithinBlockOrGroup(op, srcLayout, dstLayout, adaptor,
377-
rewriter);
378-
}
379-
380-
LogicalResult
381-
transferWithinBlockOrGroup(ConvertLayoutOp op, const LinearLayout &srcLayout,
382-
const LinearLayout &dstLayout, OpAdaptor adaptor,
383-
ConversionPatternRewriter &rewriter) const {
384-
LinearLayout conversion = srcLayout.invertAndCompose(dstLayout);
385-
386-
// TODO(Keren): LLs support cross-CTA conversions, this function does not
387-
if (isCrossCTAConversion(conversion))
388-
return failure();
389-
353+
LogicalResult transferWithinBlock(ConvertLayoutOp op,
354+
const LinearLayout &srcLayout,
355+
const LinearLayout &dstLayout,
356+
OpAdaptor adaptor,
357+
ConversionPatternRewriter &rewriter) const {
390358
MLIRContext *ctx = op.getContext();
391359
auto loc = op.getLoc();
392360
auto srcTy = op.getSrc().getType();
@@ -445,11 +413,14 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
445413
}
446414
}
447415

416+
// Pretty sure this is the identity function ATM
417+
// It'd be better to simply call `quotient({kBlock})` and
418+
// remove kBlock from transferWithinBlockImpl
448419
auto srcLayoutWithinBlock = getLayoutWithinBlock(srcLayout);
449420
auto dstLayoutWithinBlock = getLayoutWithinBlock(dstLayout);
450421
SmallVector<Value> outVals =
451-
transferWithinBlock(inVals, op, srcLayoutWithinBlock,
452-
dstLayoutWithinBlock, adaptor, rewriter);
422+
transferWithinBlockImpl(inVals, op, srcLayoutWithinBlock,
423+
dstLayoutWithinBlock, adaptor, rewriter);
453424

454425
// Unmunge output values
455426
for (const auto &it : llvm::enumerate(outVals)) {
@@ -467,10 +438,10 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
467438
}
468439

469440
SmallVector<Value>
470-
transferWithinBlock(ArrayRef<Value> inVals, ConvertLayoutOp op,
471-
const LinearLayout &srcLayout,
472-
const LinearLayout &dstLayout, OpAdaptor adaptor,
473-
ConversionPatternRewriter &rewriter) const {
441+
transferWithinBlockImpl(ArrayRef<Value> inVals, ConvertLayoutOp op,
442+
const LinearLayout &srcLayout,
443+
const LinearLayout &dstLayout, OpAdaptor adaptor,
444+
ConversionPatternRewriter &rewriter) const {
474445
MLIRContext *ctx = op.getContext();
475446
auto loc = op.getLoc();
476447

0 commit comments

Comments
 (0)