Commit 13594bb

Add LL::quotient and remove uses of divideRight and sublayoutIsIdentity (#4968)
We add a new abstraction, `LL::quotient`, which captures the idea of "a linear layout does not permute certain dimensions". Doing so allows us to remove `divideRight` and subsume its uses into this higher-level abstraction. We also fix a bug in `isCrossCTAConversion`, and we remove some code duplication from `transferWithinThread` and `cvtReordersRegisters` in favour of a more generic approach. We fix a bug in `sublayout` that caused it to reorder `outDims` at will because it used a set instead of a vector. Tests for `LL::quotient` are still missing and will be added shortly.
1 parent 3613bf4 commit 13594bb
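
For orientation, here is a minimal sketch (not part of this commit) of how the new `quotient` call replaces the old pattern of dividing by an identity layout on the right. It assumes the declarations introduced in `include/triton/Tools/LinearLayout.h` below; the helper name is hypothetical.

#include "triton/Tools/LinearLayout.h"
#include <optional>

// Sketch: drop the "block" dimension from a conversion layout when the layout
// acts as the identity on it. `dropTrivialBlockDim` is a hypothetical helper.
static mlir::triton::LinearLayout
dropTrivialBlockDim(mlir::MLIRContext *ctx, mlir::triton::LinearLayout comp) {
  mlir::StringAttr kBlock = mlir::StringAttr::get(ctx, "block");
  // Old pattern:
  //   comp.divideRight(LinearLayout::identity1D(comp.getInDimSize(kBlock),
  //                                             kBlock, kBlock))
  // New pattern: quotient succeeds iff `comp` is the identity on "block" and
  // does not mix "block" with any other dimension; the result drops "block".
  if (std::optional<mlir::triton::LinearLayout> q = comp.quotient({kBlock}))
    return *q;
  return comp; // not trivial on "block"; keep the full conversion
}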

9 files changed: 287 additions & 463 deletions

include/triton/Analysis/Utility.h

Lines changed: 9 additions & 0 deletions
@@ -6,6 +6,7 @@
 #include "mlir/Support/LLVM.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
+#include "triton/Tools/LinearLayout.h"
 
 namespace mlir {
 
@@ -189,6 +190,14 @@ bool supportMMA(triton::DotOp op, int version);
 
 bool supportMMA(Value value, int version);
 
+// Conversion from `srcTy` to `dstTy` involving the minimum amount of data
+// transfer provided that both types can be converted to LL (if it can't it'll
+// return nullopt). The output will be such that layout.getInDimNames() ==
+// layout.getOutDimNames() and the conversion will not include kBlock (resp.
+// kWarp or kLane) if it can be avoided
+std::optional<mlir::triton::LinearLayout>
+minimalCvtLayout(RankedTensorType srcTy, RankedTensorType dstTy);
+
 // Conversion from `srcTy` to `dstTy` only involves reordering of registers.
 // There is no need for data exchange across threads, warps, or blocks.
 bool cvtReordersRegisters(RankedTensorType srcTy, RankedTensorType dstTy);
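
A hedged usage sketch for the new declaration (the helper below is hypothetical and not in this commit): per the contract above, the dimension names that survive in the returned layout tell the caller the widest hardware scope the conversion touches.

#include "triton/Analysis/Utility.h"
#include "llvm/ADT/STLExtras.h"

// Hypothetical helper: returns true if the conversion never needs to cross
// warp or CTA boundaries, i.e. neither "warp" nor "block" survives in the
// minimal conversion layout.
static bool cvtStaysWithinWarp(mlir::RankedTensorType srcTy,
                               mlir::RankedTensorType dstTy) {
  auto cvt = mlir::minimalCvtLayout(srcTy, dstTy);
  if (!cvt)
    return false; // at least one encoding has no linear layout yet
  mlir::MLIRContext *ctx = srcTy.getContext();
  return !llvm::is_contained(cvt->getInDimNames(),
                             mlir::StringAttr::get(ctx, "warp")) &&
         !llvm::is_contained(cvt->getInDimNames(),
                             mlir::StringAttr::get(ctx, "block"));
}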

include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 0 additions & 4 deletions
@@ -44,10 +44,6 @@ std::optional<LinearLayout>
 toLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
                std::optional<int32_t> elemBitWidth = std::nullopt);
 
-// Given a linear layout with input dims and output dims containing a "block"
-// dimension, determines if the layout moves data across block boundaries.
-bool isCrossCTAConversion(const LinearLayout &layout);
-
 // Given a linear layout where the input dimensions contain a "block" dimension,
 // this method sets the "block" dimension to 0 and removes the corresponding
 // output dimensions.

include/triton/Tools/LinearLayout.h

Lines changed: 18 additions & 27 deletions
@@ -575,29 +575,20 @@ class LinearLayout {
     return *this;
   }
 
-  // divideLeft and divideRight are the inverses of operator*.
-  //
-  // Consider `a = c.divideRight(b)`, where `a` is a linear layout with
-  // `in-dims(a) == in-dims(b)` and `out-dims(a) == out-dims(c)`. We may remove
-  // some empty dimensions from `a` to form `a'` and still have `a' * b == c`.
-  // Therefore, there are multiple possible values that we could return for
-  // `(a * b).divideRight(b)` which would satisfy
-  // `((a * b).divideRight(b)) * b == a * b`.
-  //
-  // In the following example, we have `a * b == a' * b` when "in1" is an empty
-  // dimension that maps everything to 0:
-  //
-  //   a = L("in1", "in2") -> ("out1", "out2")
-  //   a' = L("in1") -> ("out1")
-  //   b = L("in2") -> ("out2")
-  //
-  // divideLeft and divideRight resolve this ambiguity by always returning the
-  // "canonical" quotient, namely the one with the fewest possible size-zero
-  // input and output dimensions.
-  //
-  // TODO(jlebar): Implement divideLeft.
-  // std::optional<LinearLayout> divideLeft(const LinearLayout &divisor);
-  std::optional<LinearLayout> divideRight(const LinearLayout &divisor) const;
+  // Returns true if this layout acts trivially (as the identity) on the given
+  // dimensions. This means that it's the identity on those dimensions, and it
+  // does not map other dimensions onto those or these onto other dimensions.
+  bool isTrivialOver(ArrayRef<StringAttr> dimNames) const;
+
+  // For an endomorphism on dimNames (linear map that maps dimNames to dimNames)
+  // checks whether it is the identity map on these dimensions (i.e
+  // LinearLayouts::isTrivialOver) and if so, returns the sublayout of the
+  // remaining dimensions.
+  // nb. The isTrivialOver condition is more restrictive than the usual
+  // "leaves the subspace invariant" condition in maths.
+  // We can always relax it if we know how to take advantage of a conversion
+  // layout being block-diagonal in the future.
+  std::optional<LinearLayout> quotient(ArrayRef<StringAttr> dimNames) const;
 
   // Gets a layout with only these in/out dimensions.
   //
@@ -614,10 +605,10 @@ class LinearLayout {
   bool sublayoutIsZero(ArrayRef<StringAttr> inDimNames,
                        ArrayRef<StringAttr> outDimNames) const;
 
-  // Is the sublayout restricted to inDimNames + outDimNames and then flattened
-  // to 1D the identity layout (ignoring out-dim sizes)?
-  bool sublayoutIsIdentity(ArrayRef<StringAttr> inDimNames,
-                           ArrayRef<StringAttr> outDimNames) const;
+  // Is the sublayout defined from dimNames to dimNames the identity?
+  // In particular, is the input and output size in these dimensions
+  // the same, and are the bases the identity?
+  bool squareSublayoutIsIdentity(ArrayRef<StringAttr> dimNames) const;
 
   // Computes and returns L(x, y, z).
   //
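
As a rough illustration of the new predicates (a sketch built only from calls that already appear in this diff; the sizes are arbitrary), a layout that is the identity on "lane" can be quotiented down to its "register" part:

#include "triton/Tools/LinearLayout.h"
#include <cassert>
#include <optional>

// Sketch: a layout that is the identity on both "register" and "lane".
static void quotientExample(mlir::MLIRContext *ctx) {
  using mlir::StringAttr;
  using mlir::triton::LinearLayout;
  StringAttr kReg = StringAttr::get(ctx, "register");
  StringAttr kLane = StringAttr::get(ctx, "lane");
  LinearLayout layout = LinearLayout::identity1D(4, kReg, kReg) *
                        LinearLayout::identity1D(32, kLane, kLane);
  // The layout acts trivially on "lane" ...
  assert(layout.isTrivialOver({kLane}));
  assert(layout.squareSublayoutIsIdentity({kLane}));
  // ... so quotienting by "lane" succeeds and leaves only the "register" part.
  std::optional<LinearLayout> q = layout.quotient({kLane});
  assert(q.has_value() && q->squareSublayoutIsIdentity({kReg}));
}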

lib/Analysis/Utility.cpp

Lines changed: 38 additions & 39 deletions
@@ -640,57 +640,56 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
   return ans;
 }
 
-bool cvtReordersRegisters(RankedTensorType srcTy, RankedTensorType dstTy) {
+// We get the smallest submap of srcTy^{-1} * dstTy that is not the identity
+// under kBlock, kWarp or kLane (in that order). The idea here is that if we
+// have a transformation that's the identity on kBlock, we don't need to use
+// distributed shared memory. If it's also the identity on kWarp, we can
+// transfer via warp-shuffles, and if it's the identity on kLane just have to
+// reorder the registers
+std::optional<LinearLayout> minimalCvtLayout(RankedTensorType srcTy,
+                                             RankedTensorType dstTy) {
   MLIRContext *ctx = srcTy.getContext();
   std::optional<LinearLayout> srcLayout =
       toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
   std::optional<LinearLayout> dstLayout =
       toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
-  if (srcLayout.has_value() && dstLayout.has_value()) {
-    // comp describes the layout function for converting from src to dst.
-    LinearLayout comp = srcLayout->invertAndCompose(*dstLayout);
-    StringAttr kLane = StringAttr::get(ctx, "lane");
-    StringAttr kWarp = StringAttr::get(ctx, "warp");
-    StringAttr kBlock = StringAttr::get(ctx, "block");
-    // TODO(jlebar): These checks are overly-restrictive. For example, we can
-    // transfer by shuffling registers (case 1) if and only if all of the bases
-    // for `register` have 0s for lane, warp, and block. But the check below is
-    // stronger than this, checking also that the choice of lane/warp/block does
-    // not affect the permutation of registers. If we allow different
-    // lane/warp/blocks to have different permutations, we can generalize this.
-    if (comp.divideRight(LinearLayout::identity1D(comp.getInDimSize(kLane),
-                                                  kLane, kLane) *
-                         LinearLayout::identity1D(comp.getInDimSize(kWarp),
-                                                  kWarp, kWarp) *
-                         LinearLayout::identity1D(comp.getInDimSize(kBlock),
-                                                  kBlock, kBlock))
-            .has_value()) {
-      return true;
+  if (!(srcLayout.has_value() && dstLayout.has_value()))
+    return std::nullopt;
+  // comp describes the layout function to create dst from src.
+  LinearLayout comp = dstLayout->invertAndCompose(*srcLayout);
+  // We try to quotient by the largest subspace first
+  auto dims = SmallVector<StringRef>{"block", "warp", "lane", "register"};
+  for (auto dim : dims) {
+    auto quotient = comp.quotient(StringAttr::get(ctx, dim));
+    if (!quotient.has_value()) {
+      break;
     }
+    comp = *quotient;
   }
-  return false;
+  return comp;
+}
+
+bool cvtReordersRegisters(RankedTensorType srcTy, RankedTensorType dstTy) {
+  auto layout = minimalCvtLayout(srcTy, dstTy);
+  MLIRContext *ctx = srcTy.getContext();
+  if (!layout.has_value()) {
+    return false;
+  }
+  auto kRegister = StringAttr::get(ctx, "register");
+  auto outDims = llvm::to_vector(layout->getOutDimNames());
+  return outDims.empty() || ArrayRef(outDims) == ArrayRef({kRegister});
 }
 
 bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy) {
+  auto layout = minimalCvtLayout(srcTy, dstTy);
   MLIRContext *ctx = srcTy.getContext();
-  std::optional<LinearLayout> srcLayout =
-      toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
-  std::optional<LinearLayout> dstLayout =
-      toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
-  if (srcLayout.has_value() && dstLayout.has_value()) {
-    // comp describes the layout function for converting from src to dst.
-    LinearLayout comp = srcLayout->invertAndCompose(*dstLayout);
-    StringAttr kWarp = StringAttr::get(ctx, "warp");
-    StringAttr kBlock = StringAttr::get(ctx, "block");
-    if (comp.divideRight(LinearLayout::identity1D(comp.getInDimSize(kWarp),
-                                                  kWarp, kWarp) *
-                         LinearLayout::identity1D(comp.getInDimSize(kBlock),
-                                                  kBlock, kBlock))
-            .has_value()) {
-      return true;
-    }
+  if (!layout.has_value()) {
+    return false;
   }
-  return false;
+  auto kRegister = StringAttr::get(ctx, "register");
+  auto kLane = StringAttr::get(ctx, "lane");
+  return llvm::to_vector(layout->getOutDimNames()) ==
+         llvm::SmallVector<StringAttr, 2>{kRegister, kLane};
 }
 
 bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) {
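
To make the quotient cascade described in the new `minimalCvtLayout` comment concrete, here is a hedged, hand-built example (not from the commit, and it assumes `quotient` accepts layouts built this way): a conversion that is the identity on "block" and "warp" but mixes "register" and "lane" quotients down until it hits "lane", where the cascade stops.

#include "triton/Tools/LinearLayout.h"
#include <cassert>

// Hand-built sketch of the cascade: identity on "block" and "warp", but
// "register" and "lane" trade places, so quotienting stops at "lane".
static void cascadeExample(mlir::MLIRContext *ctx) {
  using mlir::StringAttr;
  using mlir::triton::LinearLayout;
  StringAttr kReg = StringAttr::get(ctx, "register");
  StringAttr kLane = StringAttr::get(ctx, "lane");
  StringAttr kWarp = StringAttr::get(ctx, "warp");
  StringAttr kBlock = StringAttr::get(ctx, "block");
  LinearLayout comp = LinearLayout::identity1D(2, kReg, kLane) *
                      LinearLayout::identity1D(2, kLane, kReg) *
                      LinearLayout::identity1D(4, kWarp, kWarp) *
                      LinearLayout::identity1D(1, kBlock, kBlock);
  comp = *comp.quotient({kBlock}); // trivial on "block": succeeds
  comp = *comp.quotient({kWarp});  // trivial on "warp": succeeds
  assert(!comp.quotient({kLane})); // registers and lanes mix: fails
}

Under this assumption, such a conversion would need at least warp shuffles (or shared memory), but never distributed shared memory.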

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 60 additions & 89 deletions
@@ -282,111 +282,79 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     const auto &shape = op.getType().getShape();
     auto srcTy = op.getSrc().getType();
     auto dstTy = op.getType();
-    std::optional<LinearLayout> srcLayout =
-        toLinearLayout(shape, srcTy.getEncoding());
-    std::optional<LinearLayout> dstLayout =
-        toLinearLayout(shape, dstTy.getEncoding());
-    if (!srcLayout.has_value() || !dstLayout.has_value()) {
-      return failure();
-    }
 
-    // There are four cases to handle.
-    //
-    // 1. Transfer between values in the same thread, in which case we simply
-    //    reorder the elements of adaptor.getSrc().
-    // 2. Transfer between values in the same warp, in which case we try to
-    //    move values using warp shuffles, though if the pattern is complicated
-    //    enough we may fall back to using shared memory (case 3).
-    // 3. Transfer between values in the same CTA, in which case we move values
-    //    through shared memory.
-    // 4. Transfer between values in different CTAs, in which case we move
-    //    values through distributed shared memory.
-    //
-    // We can tell which case we're in by examining `conversion`.
-    // For example, if the block -> block mapping is an identity layout: {1, 2,
-    // 4, ...}, then there's no movement between data in different CTAs, and we
-    // know we're not in case 4.
-    if (cvtReordersRegisters(srcTy, dstTy)) { // Case 1.
-      return transferWithinThread(op, *srcLayout, *dstLayout, adaptor,
-                                  rewriter);
+    auto conversion = minimalCvtLayout(srcTy, dstTy);
+    if (!conversion.has_value()) {
+      return rewriter.notifyMatchFailure(
+          op, "NYI. srcTy and/or dstTy don't implement LLs yet");
     }
 
-    if (cvtNeedsWarpShuffle(srcTy, dstTy)) { // Case 2.
-      return transferWithinLane(op, *srcLayout, *dstLayout, adaptor, rewriter);
+    assert(to_vector(conversion->getInDimNames()) ==
+           to_vector(conversion->getOutDimNames()));
+    auto dims = conversion->getInDimNames();
+    if (llvm::is_contained(dims, str_attr("block"))) {
+      // Case 1: Transfer between values in different CTAs.
+      // This requires moving values through distributed shared memory.
+      return rewriter.notifyMatchFailure(
+          op, "NYI: Transfer between different CTAs");
+    } else if (llvm::is_contained(dims, str_attr("warp"))) {
+      // Case 2: Transfer between values in the same CTA, in which case we move
+      // values through shared memory.
+      LinearLayout srcLayout =
+          *toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
+      LinearLayout dstLayout =
+          *toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
+      return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter);
+    } else if (llvm::is_contained(dims, str_attr("lane"))) {
+      // Case 3. Transfer between values in the same warp, in which case we try
+      // to move values using warp shuffles, though if the pattern is
+      // complicated enough we may fall back to using shared memory
+      // TODO(Keren): implement warp shuffle instead of using the general
+      // approach that uses shared memory
+      LinearLayout srcLayout =
+          *toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
+      LinearLayout dstLayout =
+          *toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
+      return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter);
+    } else if (llvm::is_contained(dims, str_attr("register"))) {
+      // Case 4. Transfer between values in the same thread, in which case we
+      // simply reorder the elements of adaptor.getSrc().
+      return transferWithinThread(op, *conversion, adaptor, rewriter);
+    } else {
+      // The two layouts are equivalent. We should probably remove these in
+      // RemoveLayoutConversion.
+      rewriter.replaceOp(op, adaptor.getSrc());
+      return success();
     }
-
-    return transferWithinBlockOrGroup(op, *srcLayout, *dstLayout, adaptor,
-                                      rewriter); // Case 3 and 4
   }
 
   LogicalResult
-  transferWithinThread(ConvertLayoutOp op, const LinearLayout &srcLayout,
-                       const LinearLayout &dstLayout, OpAdaptor adaptor,
+  transferWithinThread(ConvertLayoutOp op, const LinearLayout &conversion,
+                       OpAdaptor adaptor,
                        ConversionPatternRewriter &rewriter) const {
     MLIRContext *ctx = op.getContext();
     auto loc = op.getLoc();
     StringAttr kRegister = str_attr("register");
-    StringAttr kLane = str_attr("lane");
-    StringAttr kWarp = str_attr("warp");
-    StringAttr kBlock = str_attr("block");
-
-    // There are three possible cases:
-    //
-    // 1. `srcLayout` has the same number of registers as `dstLayout`.
-    // 2. `srcLayout` has fewer registers than `dstLayout`.
-    // 3. `srcLayout` has more registers than `dstLayout`.
-    //
-    // In the second case `srcLayout . dstLayout^-1` is not surjective
-    // because not all destination registers are covered.
-    // Since the goal is to cover all of the destination
-    // registers, we can instead use `dstLayout . srcLayout^-1`.
-    LinearLayout conversion = dstLayout.invertAndCompose(srcLayout);
-    auto dstToSrc = conversion.divideRight(
-        LinearLayout::identity1D(conversion.getInDimSize(kLane), kLane, kLane) *
-        LinearLayout::identity1D(conversion.getInDimSize(kWarp), kWarp, kWarp) *
-        LinearLayout::identity1D(conversion.getInDimSize(kBlock), kBlock,
-                                 kBlock));
-
     assert(!cvtNeedsSharedMemory(op.getSrc().getType(), op.getType()));
-    assert(ArrayRef(to_vector(dstToSrc->getInDimNames())) ==
-           ArrayRef{kRegister});
-    assert(ArrayRef(to_vector(dstToSrc->getOutDimNames())) ==
-           ArrayRef{kRegister});
 
     auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
     SmallVector<Value> outVals;
-    outVals.resize(dstToSrc->getInDimSize(kRegister));
-    for (int i = 0; i < dstToSrc->getInDimSize(kRegister); i++) {
-      auto srcIdx = dstToSrc->apply({{kRegister, i}});
-      outVals[i] = inVals[srcIdx.begin()->second];
+    outVals.resize(conversion.getInDimSize(kRegister));
+    for (int i = 0; i < conversion.getInDimSize(kRegister); i++) {
+      auto srcIdx = conversion.apply({{kRegister, i}}).begin()->second;
+      outVals[i] = inVals[srcIdx];
     }
     Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter,
                                   op.getType());
     rewriter.replaceOp(op, result);
     return success();
   }
 
-  LogicalResult transferWithinLane(ConvertLayoutOp op,
-                                   const LinearLayout &srcLayout,
-                                   const LinearLayout &dstLayout,
-                                   OpAdaptor adaptor,
-                                   ConversionPatternRewriter &rewriter) const {
-    // TODO(Keren): implement warp shuffle instead of using the general approach
-    // that uses shared memory
-    return transferWithinBlockOrGroup(op, srcLayout, dstLayout, adaptor,
-                                      rewriter);
-  }
-
-  LogicalResult
-  transferWithinBlockOrGroup(ConvertLayoutOp op, const LinearLayout &srcLayout,
-                             const LinearLayout &dstLayout, OpAdaptor adaptor,
-                             ConversionPatternRewriter &rewriter) const {
-    LinearLayout conversion = srcLayout.invertAndCompose(dstLayout);
-
-    // TODO(Keren): LLs support cross-CTA conversions, this function does not
-    if (isCrossCTAConversion(conversion))
-      return failure();
-
+  LogicalResult transferWithinBlock(ConvertLayoutOp op,
+                                    const LinearLayout &srcLayout,
+                                    const LinearLayout &dstLayout,
+                                    OpAdaptor adaptor,
+                                    ConversionPatternRewriter &rewriter) const {
     MLIRContext *ctx = op.getContext();
     auto loc = op.getLoc();
     auto srcTy = op.getSrc().getType();
@@ -461,11 +429,14 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
       }
     }
 
+    // Pretty sure this is the identity function ATM
+    // It'd be better to simply call `quotient({kBlock})` and
+    // remove kBlock from transferWithinBlockImpl
     auto srcLayoutWithinBlock = getLayoutWithinBlock(srcLayout);
     auto dstLayoutWithinBlock = getLayoutWithinBlock(dstLayout);
     SmallVector<Value> outVals =
-        transferWithinBlock(inVals, op, srcLayoutWithinBlock,
-                            dstLayoutWithinBlock, adaptor, rewriter);
+        transferWithinBlockImpl(inVals, op, srcLayoutWithinBlock,
+                                dstLayoutWithinBlock, adaptor, rewriter);
 
     // Unmunge output values
     for (const auto &it : llvm::enumerate(outVals)) {
@@ -499,10 +470,10 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
   }
 
   SmallVector<Value>
-  transferWithinBlock(ArrayRef<Value> inVals, ConvertLayoutOp op,
-                      const LinearLayout &srcLayout,
-                      const LinearLayout &dstLayout, OpAdaptor adaptor,
-                      ConversionPatternRewriter &rewriter) const {
+  transferWithinBlockImpl(ArrayRef<Value> inVals, ConvertLayoutOp op,
+                          const LinearLayout &srcLayout,
+                          const LinearLayout &dstLayout, OpAdaptor adaptor,
+                          ConversionPatternRewriter &rewriter) const {
     MLIRContext *ctx = op.getContext();
     auto loc = op.getLoc();
 
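For orientation (a sketch, not part of the commit): in the register-only case above, the quotiented `conversion` maps each destination register index to the source register it should read from. With plain integers standing in for the unpacked LLVM values, the core of `transferWithinThread` looks roughly like this.

#include "triton/Tools/LinearLayout.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

// Sketch of the register-reordering step, with ints in place of LLVM Values.
static llvm::SmallVector<int>
reorderRegisters(const mlir::triton::LinearLayout &conversion,
                 mlir::StringAttr kRegister, llvm::ArrayRef<int> inVals) {
  llvm::SmallVector<int> outVals(conversion.getInDimSize(kRegister));
  for (int i = 0; i < conversion.getInDimSize(kRegister); ++i) {
    // After quotienting, conversion has in-dims == out-dims == {"register"}.
    int srcIdx = conversion.apply({{kRegister, i}}).begin()->second;
    outVals[i] = inVals[srcIdx];
  }
  return outVals;
}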