Skip to content

Commit 3ed479f

Browse files
apgoucherAdam P. Goucher
andauthored
This reverts commit 70359fa, which was causing some of our internal tests to fail. Co-authored-by: Adam P. Goucher <[email protected]>
1 parent 7db39a9 commit 3ed479f

File tree

11 files changed: 396 additions and 308 deletions.

include/triton/Dialect/Triton/IR/Dialect.h

Lines changed: 6 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -55,19 +55,13 @@ class DialectInferLayoutInterface
5555

5656
// Tries to compute the encoding for the result of a reshape operation that
5757
// makes the reshape a "nop", i.e. the same GPU threads contain the same
58-
// elements as before the reshape using legacy layouts. This is not always
59-
// possible (in which case we fallback to using LinearLayouts)
60-
// In the future we'll always use LinearLayouts
58+
// elements as before the reshape. Note that this is not always possible (in
59+
// which case you'd need to choose a different layout for the input to the
60+
// reshape).
6161
virtual LogicalResult
62-
inferReshapeOpEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
63-
ArrayRef<int64_t> dstShape, Attribute &dstEnc,
64-
std::optional<Location> loc) const = 0;
65-
66-
// Check if two layouts are structurally the same, even if their names are
67-
// different
68-
virtual LogicalResult verifyLayoutsAreEqual(ArrayRef<int64_t> shape,
69-
Attribute expected, Attribute got,
70-
Location loc) const = 0;
62+
inferReshapeOpNoReorderEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
63+
ArrayRef<int64_t> dstShape, Attribute &dstEnc,
64+
std::optional<Location> loc) const = 0;
7165

7266
virtual LogicalResult
7367
inferJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc,

lib/Analysis/Utility.cpp

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -219,14 +219,15 @@ bool ReduceOpHelper::isSupportedLayout() {
219219
}
220220

221221
auto srcLayout = getSrcLayout();
222-
if (isa<BlockedEncodingAttr, LinearEncodingAttr, SliceEncodingAttr>(
223-
srcLayout)) {
222+
if (isa<BlockedEncodingAttr>(srcLayout)) {
224223
return true;
225224
}
226-
227225
if (auto mmaLayout = dyn_cast<MmaEncodingTrait>(srcLayout)) {
228226
return mmaLayout.supportReduction();
229227
}
228+
if (auto sliceLayout = dyn_cast<SliceEncodingAttr>(srcLayout)) {
229+
return true;
230+
}
230231
return false;
231232
}
232233

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 16 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,6 @@
88
#include "triton/Dialect/Triton/IR/Dialect.h"
99
#include "triton/Dialect/Triton/IR/Types.h"
1010
#include "triton/Dialect/Triton/IR/Utility.h"
11-
#include "triton/Tools/LinearLayout.h"
1211
#include "llvm/Support/ErrorHandling.h"
1312

1413
namespace mlir {
@@ -702,21 +701,24 @@ LogicalResult ReshapeOp::verify() {
702701
"encodings, or (b) neither does.");
703702
}
704703

705-
if (!srcEnc || getAllowReorder()) {
706-
return success();
704+
if (srcEnc && !getAllowReorder()) {
705+
Attribute inferredDstEnc;
706+
if (cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
707+
->inferReshapeOpNoReorderEncoding(srcTy.getShape(), srcEnc,
708+
dstTy.getShape(), inferredDstEnc,
709+
getLoc())
710+
.failed()) {
711+
return emitError("This reshape is impossible without reordering, but "
712+
"reordering is not allowed. Try choosing a different "
713+
"encoding for the input tensor (or allow reordering).");
714+
}
715+
if (inferredDstEnc != dstEnc) {
716+
return emitError("Expected result encoding ")
717+
<< inferredDstEnc << " but was " << dstEnc;
718+
}
707719
}
708720

709-
// Check that we can infer the dst encoding from the src encoding
710-
// and that the inferred dst encoding is the same as the given dst encoding
711-
Attribute inferredDstEnc;
712-
auto result =
713-
cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
714-
->inferReshapeOpEncoding(srcTy.getShape(), srcEnc, dstTy.getShape(),
715-
inferredDstEnc, getLoc());
716-
assert(succeeded(result));
717-
return cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
718-
->verifyLayoutsAreEqual(dstTy.getShape(), inferredDstEnc, dstEnc,
719-
getLoc());
721+
return success();
720722
}
721723

722724
//-- FpToFpOp --

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 42 additions & 83 deletions
Original file line number | Diff line number | Diff line change
@@ -1441,7 +1441,7 @@ SmallVector<unsigned> basesPerDim(const LinearLayout::BasesT &namedBases,
14411441

14421442
SmallVector<unsigned> ret(rank, 1);
14431443
auto nonZero = [](auto val) { return val != 0; };
1444-
int nonZeroIdx = 0;
1444+
int nonZeroIdx = -1;
14451445
for (const auto &basis : bases) {
14461446
auto it = std::find_if(basis.begin(), basis.end(), nonZero);
14471447
// Bases can have one or zero non-zero elements
@@ -1453,6 +1453,7 @@ SmallVector<unsigned> basesPerDim(const LinearLayout::BasesT &namedBases,
14531453
} else if (!skipBroadcast) {
14541454
// If we've seen a non-zero basis, we double the size of the previous dim
14551455
// This is just needed to count the CTAsPerCGA
1456+
assert(nonZeroIdx != -1);
14561457
ret[nonZeroIdx] *= 2;
14571458
}
14581459
}
@@ -1597,14 +1598,12 @@ LinearEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
15971598

15981599
SmallVector<unsigned>
15991600
LinearEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape, Type) const {
1600-
// When broadcasting the layout the shape changes, otherwise the shape is
1601-
// the same as the shape of the tensor
1602-
// We can either have BroadcastOp with SameOperandsAndResultEncoding, or keep
1603-
// the invariant that the shape of the LL is that of the tensor
1604-
// We choose the former for BC
1605-
auto ll = *toLinearLayout(shape);
1606-
return basesPerDim(ll, StringAttr::get(getContext(), "register"),
1607-
/*skipBroadcast=*/false);
1601+
// We can relax this assert by calling toLinearLayout rather than
1602+
// getLinearLayout
1603+
SmallVector<int32_t> shapeVec(shape.begin(), shape.end());
1604+
assert(shapeVec == llvm::to_vector(getLinearLayout().getOutDimSizes()));
1605+
auto ll = getLinearLayout();
1606+
return basesPerDim(ll, StringAttr::get(getContext(), "register"));
16081607
}
16091608

16101609
// Start of Selection
@@ -2674,8 +2673,8 @@ struct TritonGPUInferLayoutInterface
26742673
// contains elements [a,b,c,d] before the reshape, it contains those same
26752674
// elements after the reshape, they're just "renamed".
26762675
//
2677-
// Using legacy layouts, a dst encoding that satisfies this property may not
2678-
// exist. Here are some positive and negative examples.
2676+
// A dst encoding that satisfies this property does not exist for all inputs.
2677+
// Here are some positive and negative examples.
26792678
//
26802679
// - NOT OK: 4x4 order=[0,1] -> 16. Reshape merges elements so
26812680
// dim 1 is the fastest-changing in the dst, but the src has the opposite
@@ -2689,19 +2688,17 @@ struct TritonGPUInferLayoutInterface
26892688
// - OK: 32x4 sizePerThread=[4,4] -> 128. dst with sizePerThread=[16] will
26902689
// contain the same elements as before.
26912690
//
2692-
// With linear layouts, we can always find a dst encoding that satisfies
2693-
// this property. See inferReshapeOpEncoding.
2694-
//
26952691
// Users of this function require that it is symmetrical: if
26962692
// (srcShape,srcEnc,dstShape) => dstEnc, then (dstShape,dstEnc,srcShape) =>
26972693
// srcEnc.
2698-
LogicalResult inferReshapeOpLegacyEncoding(ArrayRef<int64_t> srcShape,
2699-
Attribute srcEnc,
2700-
ArrayRef<int64_t> dstShape,
2701-
Attribute &dstEnc) const {
2694+
LogicalResult
2695+
inferReshapeOpNoReorderEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
2696+
ArrayRef<int64_t> dstShape, Attribute &dstEnc,
2697+
std::optional<Location> loc) const override {
27022698
auto src = mlir::dyn_cast<BlockedEncodingAttr>(srcEnc);
27032699
if (!src) {
2704-
return failure();
2700+
return emitOptionalError(
2701+
loc, "Non-reordering reshape only supports BlockedEncoding");
27052702
}
27062703

27072704
// Nop reshape; we can always infer an encoding.
@@ -2734,7 +2731,9 @@ struct TritonGPUInferLayoutInterface
27342731
// to handle CTASplitNum.
27352732
if (!all_of(src.getCTAsPerCGA(), [](int32_t x) { return x == 1; }) ||
27362733
!all_of(src.getCTASplitNum(), [](int32_t x) { return x == 1; })) {
2737-
return failure();
2734+
return emitOptionalError(
2735+
loc, "Non-reordering reshape does not currently support multi-CTA "
2736+
"layouts other than the default layout.");
27382737
}
27392738

27402739
// Cowardly refuse to handle encodings where shape[dim] is not divisible by
@@ -2744,7 +2743,12 @@ struct TritonGPUInferLayoutInterface
27442743
for (int dim = 0; dim < srcShape.size(); dim++) {
27452744
if (srcShape[dim] >= subblock[dim] &&
27462745
srcShape[dim] % subblock[dim] != 0) {
2747-
return failure();
2746+
return emitOptionalError(loc,
2747+
"Can't do a non-reordering reshape because "
2748+
"the size of dimension ",
2749+
dim, " (", srcShape[dim], ")",
2750+
" is not divisible by ", name, "[", dim, "]",
2751+
" = ", subblock[dim]);
27482752
}
27492753
}
27502754
return success();
@@ -2769,7 +2773,11 @@ struct TritonGPUInferLayoutInterface
27692773
// physical order, with `a` being the most major.
27702774
for (const auto &[srcDims, dstDims] : decomp) {
27712775
if (!isConsecutive(to_vector(reverse(gather(srcInvOrder, srcDims))))) {
2772-
return failure();
2776+
return emitOptionalError(loc,
2777+
"Cannot do a non-reordering reshape given "
2778+
"this src encoding order. Dimensions [",
2779+
join(srcDims),
2780+
"] must be physically consecutive.");
27732781
}
27742782
}
27752783

@@ -2816,7 +2824,11 @@ struct TritonGPUInferLayoutInterface
28162824
// Check that more-minor dims all have 1 in shapeRemaining.
28172825
for (int j = i + 1; j < srcDims.size(); j++) {
28182826
if (shapeRemaining[j] != 1) {
2819-
return failure();
2827+
return emitOptionalError(
2828+
loc,
2829+
"Invalid src encoding for non-reordering reshape. Must use "
2830+
"up sizePerThread / threadsPerWarp / warpsPerCTA for "
2831+
"more-minor dimensions before more major-dims can use them.");
28202832
}
28212833
}
28222834

@@ -2831,7 +2843,13 @@ struct TritonGPUInferLayoutInterface
28312843
// only if we're the most-major dimension of the chunk and in all
28322844
// future chunks, only this most-major dim has a non-1 size.
28332845
if (shapeRemaining[i] == 0 && i != 0) {
2834-
return failure();
2846+
return emitOptionalError(
2847+
loc,
2848+
"Invalid src encoding for non-reordering reshape. Block "
2849+
"size in dimension ",
2850+
dim,
2851+
" is larger than the shape that dimension, but this is only "
2852+
"allowed for the most-major dimension of a reshape chunk");
28352853
}
28362854
}
28372855
return success();
@@ -2921,65 +2939,6 @@ struct TritonGPUInferLayoutInterface
29212939
return success();
29222940
}
29232941

2924-
LogicalResult verifyLayoutsAreEqual(ArrayRef<int64_t> shape,
2925-
Attribute expected, Attribute got,
2926-
Location loc) const override {
2927-
if (expected == got) {
2928-
return success();
2929-
}
2930-
// Check whether the encodings are structurally the same.
2931-
auto expectedLL = triton::gpu::toLinearLayout(shape, expected);
2932-
auto gotLL = triton::gpu::toLinearLayout(shape, got);
2933-
if (expectedLL != gotLL) {
2934-
return emitError(loc, "Expected result encoding ")
2935-
<< expected << " but was " << got;
2936-
}
2937-
return success();
2938-
}
2939-
2940-
LogicalResult
2941-
inferReshapeOpEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
2942-
ArrayRef<int64_t> dstShape, Attribute &dstEnc,
2943-
std::optional<Location> loc) const override {
2944-
auto result =
2945-
inferReshapeOpLegacyEncoding(srcShape, srcEnc, dstShape, dstEnc);
2946-
if (succeeded(result)) {
2947-
return result;
2948-
}
2949-
2950-
// If the legacy encoding failed use LinearLayouts.
2951-
// Once LinearLayouts are more widely used, we can remove
2952-
// inferReshapeOpLegacyEncoding and simply use LLs.
2953-
auto *ctx = getContext();
2954-
auto src = triton::gpu::toLinearLayout(srcShape, srcEnc);
2955-
if (!src) {
2956-
return emitOptionalError(loc,
2957-
"src encoding does not support linear layout");
2958-
}
2959-
2960-
if (product(srcShape) != product(dstShape)) {
2961-
return emitOptionalError(loc, "numel of dst shape does not match "
2962-
"numel of src shape");
2963-
}
2964-
2965-
auto newRank = dstShape.size();
2966-
SmallVector<std::pair<StringAttr, int32_t>> newOutDims;
2967-
for (auto [dim, size] :
2968-
llvm::zip(standardOutDimNames(ctx, newRank), dstShape)) {
2969-
newOutDims.emplace_back(dim, size);
2970-
}
2971-
auto srcOutDims = llvm::to_vector(src->getOutDimNames());
2972-
// reshapeOp assumes minor-to-major, so we need to transpose the out dims
2973-
// before the reshape
2974-
std::reverse(srcOutDims.begin(), srcOutDims.end());
2975-
std::reverse(newOutDims.begin(), newOutDims.end());
2976-
auto dst = src->transposeOuts(srcOutDims)
2977-
.reshapeOuts(newOutDims)
2978-
.transposeOuts(standardOutDimNames(ctx, newRank));
2979-
dstEnc = LinearEncodingAttr::get(ctx, dst);
2980-
return success();
2981-
}
2982-
29832942
LogicalResult
29842943
inferJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc,
29852944
std::optional<Location> loc) const override {

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 0 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -42,16 +42,6 @@ struct CanonicalizeConvertFromReshape
4242
auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
4343
if (!convert)
4444
return failure();
45-
// If the layouts are structurally the same, the convert is trivial
46-
auto srcType = convert.getSrc().getType();
47-
auto dstType = convert.getType();
48-
auto srcLL = toLinearLayout(srcType.getShape(), srcType.getEncoding());
49-
auto dstLL = toLinearLayout(dstType.getShape(), dstType.getEncoding());
50-
if (srcLL && dstLL && *srcLL == *dstLL) {
51-
rewriter.replaceOpWithNewOp<triton::ReshapeOp>(
52-
op, op.getType(), convert.getSrc(), op.getAllowReorder());
53-
return mlir::success();
54-
}
5545
if (isExpensiveView(convert.getSrc().getType(), op.getType()))
5646
return failure();
5747
if (!op.getAllowReorder() || op.getEfficientLayout())

lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1025,7 +1025,9 @@ void LayoutRematerialization::backwardRematerialization(
10251025
// we don't handle conversions to DotOperandEncodingAttr
10261026
// this is a heuristic to accommodate fused attention
10271027
RankedTensorType targetType = convertOp.getType();
1028-
if (isa<DotOperandEncodingAttr>(targetType.getEncoding()))
1028+
// We stop the rematerialization of linear layouts as we have to be a bit more
1029+
// careful with the heuristics for both correctness and perf
1030+
if (isa<DotOperandEncodingAttr, LinearEncodingAttr>(targetType.getEncoding()))
10291031
return;
10301032
Value oldV = convertOp.getSrc();
10311033
LDBG("check backward remat with source " << oldV << " encoding "
@@ -1067,8 +1069,11 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
10671069
ConvertLayoutOp convertOp) {
10681070
// we don't handle conversions to DotOperandEncodingAttr
10691071
// this is a heuristic to accommodate fused attention
1072+
// We stop the rematerialization of linear layouts as we have to be a bit more
1073+
// careful with the heuristics for both correctness and perf
10701074
RankedTensorType targetType = convertOp.getType();
1071-
if (isa<DotOperandEncodingAttr>(targetType.getEncoding()))
1075+
if (mlir::isa<DotOperandEncodingAttr, LinearEncodingAttr>(
1076+
targetType.getEncoding()))
10721077
return;
10731078

10741079
auto isExtOrBroadcastOp = [](Operation *op) {

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 8 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -407,13 +407,14 @@ static Attribute inferReshapeOpDstEncoding(ArrayRef<int64_t> srcShape,
407407
return {};
408408

409409
Attribute dstEnc;
410-
auto result =
411-
srcEnc.getDialect()
412-
.getRegisteredInterface<triton::DialectInferLayoutInterface>()
413-
->inferReshapeOpEncoding(srcShape, srcEnc, dstShape, dstEnc,
414-
/*loc=*/std::nullopt);
415-
assert(succeeded(result));
416-
return dstEnc;
410+
if (succeeded(
411+
srcEnc.getDialect()
412+
.getRegisteredInterface<triton::DialectInferLayoutInterface>()
413+
->inferReshapeOpNoReorderEncoding(
414+
srcShape, srcEnc, dstShape, dstEnc, /*loc=*/std::nullopt))) {
415+
return dstEnc;
416+
}
417+
return {};
417418
}
418419

419420
static Attribute inferDstEncoding(triton::ReshapeOp op, Attribute encoding) {

0 commit comments

Comments
 (0)