Commit 137bc62
[LAYOUTS] Implement generic layout propagation through ReshapeOp (#5389)
This PR also:

- Enables backward rematerialisation and hoisting for LLs
- Adds a fold reshape(cvt) -> reshape when the layouts are structurally the same
- Removes an assert that was disallowing the use of LLs across broadcast. When this happens, the LL will not have the same shape as the tensor. We do this to match the legacy behaviour and avoid the proliferation of new layouts
- Removes the layout-specific tests from before; instead, we create functional tests that test the axioms for the reshape function. We see that all the legacy layouts pass these tests
- Temporarily tested that the legacy path and the new path agree in CI in triton-lang/triton@e93638b
1 parent 80e2abd commit 137bc62
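
The thread running through these changes is a single notion of layout equality: two encodings are "structurally the same" when they lower to equal linear layouts for the tensor's shape. A minimal sketch of that check, assuming the triton::gpu::toLinearLayout helper used in the diffs below (it returns an empty optional for encodings it cannot convert yet); the helper name sameLayout is ours, not the commit's:

#include "triton/Dialect/TritonGPU/IR/Dialect.h"

// Sketch: structural layout equality, as used by verifyLayoutsAreEqual and
// by the reshape(convert) -> reshape fold further down this page.
static bool sameLayout(llvm::ArrayRef<int64_t> shape, mlir::Attribute a,
                       mlir::Attribute b) {
  // Identical attributes trivially describe the same layout.
  if (a == b)
    return true;
  // Otherwise compare the linear layouts the encodings lower to.
  auto aLL = mlir::triton::gpu::toLinearLayout(shape, a);
  auto bLL = mlir::triton::gpu::toLinearLayout(shape, b);
  // Conservatively report "different" if either lowering is unsupported.
  return aLL && bLL && *aLL == *bLL;
}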

File tree

7 files changed: +184 -389 lines changed

include/triton/Dialect/Triton/IR/Dialect.h

Lines changed: 12 additions & 6 deletions
@@ -55,13 +55,19 @@ class DialectInferLayoutInterface
 
   // Tries to compute the encoding for the result of a reshape operation that
   // makes the reshape a "nop", i.e. the same GPU threads contain the same
-  // elements as before the reshape. Note that this is not always possible (in
-  // which case you'd need to choose a different layout for the input to the
-  // reshape).
+  // elements as before the reshape using legacy layouts. This is not always
+  // possible (in which case we fallback to using LinearLayouts)
+  // In the future we'll always use LinearLayouts
   virtual LogicalResult
-  inferReshapeOpNoReorderEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
-                                  ArrayRef<int64_t> dstShape, Attribute &dstEnc,
-                                  std::optional<Location> loc) const = 0;
+  inferReshapeOpEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
+                         ArrayRef<int64_t> dstShape, Attribute &dstEnc,
+                         std::optional<Location> loc) const = 0;
+
+  // Check if two layouts are structurally the same, even if their names are
+  // different
+  virtual LogicalResult verifyLayoutsAreEqual(ArrayRef<int64_t> shape,
+                                              Attribute expected, Attribute got,
+                                              Location loc) const = 0;
 
   virtual LogicalResult
   inferJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc,

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 14 additions & 16 deletions
@@ -8,6 +8,7 @@
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/Triton/IR/Types.h"
 #include "triton/Dialect/Triton/IR/Utility.h"
+#include "triton/Tools/LinearLayout.h"
 #include "llvm/Support/ErrorHandling.h"
 
 namespace mlir {
@@ -701,24 +702,21 @@ LogicalResult ReshapeOp::verify() {
                      "encodings, or (b) neither does.");
   }
 
-  if (srcEnc && !getAllowReorder()) {
-    Attribute inferredDstEnc;
-    if (cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
-            ->inferReshapeOpNoReorderEncoding(srcTy.getShape(), srcEnc,
-                                              dstTy.getShape(), inferredDstEnc,
-                                              getLoc())
-            .failed()) {
-      return emitError("This reshape is impossible without reordering, but "
-                       "reordering is not allowed. Try choosing a different "
-                       "encoding for the input tensor (or allow reordering).");
-    }
-    if (inferredDstEnc != dstEnc) {
-      return emitError("Expected result encoding ")
-             << inferredDstEnc << " but was " << dstEnc;
-    }
+  if (!srcEnc || getAllowReorder()) {
+    return success();
   }
 
-  return success();
+  // Check that we can infer the dst encoding from the src encoding
+  // and that the inferred dst encoding is the same as the given dst encoding
+  Attribute inferredDstEnc;
+  auto result =
+      cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
+          ->inferReshapeOpEncoding(srcTy.getShape(), srcEnc, dstTy.getShape(),
+                                   inferredDstEnc, getLoc());
+  assert(succeeded(result));
+  return cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
+      ->verifyLayoutsAreEqual(dstTy.getShape(), inferredDstEnc, dstEnc,
+                              getLoc());
 }
 
 //-- FpToFpOp --

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 80 additions & 39 deletions
@@ -1598,11 +1598,12 @@ LinearEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
 
 SmallVector<unsigned>
 LinearEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape, Type) const {
-  // We can relax this assert by calling toLinearLayout rather than
-  // getLinearLayout
-  SmallVector<int32_t> shapeVec(shape.begin(), shape.end());
-  assert(shapeVec == llvm::to_vector(getLinearLayout().getOutDimSizes()));
-  auto ll = getLinearLayout();
+  // When broadcasting the layout the shape changes, otherwise the shape is
+  // the same as the shape of the tensor
+  // We can either have BroadcastOp with SameOperandsAndResultEncoding, or keep
+  // the invariant that the shape of the LL is that of the tensor
+  // We choose the former for BC
+  auto ll = *toLinearLayout(shape);
   return basesPerDim(ll, StringAttr::get(getContext(), "register"));
 }
 
@@ -2623,8 +2624,8 @@ struct TritonGPUInferLayoutInterface
   // contains elements [a,b,c,d] before the reshape, it contains those same
   // elements after the reshape, they're just "renamed".
   //
-  // A dst encoding that satisfies this property does not exist for all inputs.
-  // Here are some positive and negative examples.
+  // Using legacy layouts, a dst encoding that satisfies this property may not
+  // exist. Here are some positive and negative examples.
   //
   //   - NOT OK: 4x4 order=[0,1] -> 16. Reshape merges elements so
   //     dim 1 is the fastest-changing in the dst, but the src has the opposite
@@ -2638,17 +2639,19 @@ struct TritonGPUInferLayoutInterface
   //   - OK: 32x4 sizePerThread=[4,4] -> 128. dst with sizePerThread=[16] will
   //     contain the same elements as before.
   //
+  // With linear layouts, we can always find a dst encoding that satisfies
+  // this property. See inferReshapeOpEncoding.
+  //
   // Users of this function require that it is symmetrical: if
   // (srcShape,srcEnc,dstShape) => dstEnc, then (dstShape,dstEnc,srcShape) =>
   // srcEnc.
-  LogicalResult
-  inferReshapeOpNoReorderEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
-                                  ArrayRef<int64_t> dstShape, Attribute &dstEnc,
-                                  std::optional<Location> loc) const override {
+  LogicalResult inferReshapeOpLegacyEncoding(ArrayRef<int64_t> srcShape,
+                                             Attribute srcEnc,
+                                             ArrayRef<int64_t> dstShape,
+                                             Attribute &dstEnc) const {
     auto src = mlir::dyn_cast<BlockedEncodingAttr>(srcEnc);
     if (!src) {
-      return emitOptionalError(
-          loc, "Non-reordering reshape only supports BlockedEncoding");
+      return failure();
     }
 
     // Nop reshape; we can always infer an encoding.
@@ -2681,9 +2684,7 @@ struct TritonGPUInferLayoutInterface
     // to handle CTASplitNum.
     if (!all_of(src.getCTAsPerCGA(), [](int32_t x) { return x == 1; }) ||
         !all_of(src.getCTASplitNum(), [](int32_t x) { return x == 1; })) {
-      return emitOptionalError(
-          loc, "Non-reordering reshape does not currently support multi-CTA "
-               "layouts other than the default layout.");
+      return failure();
     }
 
     // Cowardly refuse to handle encodings where shape[dim] is not divisible by
@@ -2693,12 +2694,7 @@ struct TritonGPUInferLayoutInterface
     for (int dim = 0; dim < srcShape.size(); dim++) {
       if (srcShape[dim] >= subblock[dim] &&
          srcShape[dim] % subblock[dim] != 0) {
-        return emitOptionalError(loc,
-                                 "Can't do a non-reordering reshape because "
-                                 "the size of dimension ",
-                                 dim, " (", srcShape[dim], ")",
-                                 " is not divisible by ", name, "[", dim, "]",
-                                 " = ", subblock[dim]);
+        return failure();
      }
    }
    return success();
@@ -2723,11 +2719,7 @@ struct TritonGPUInferLayoutInterface
     // physical order, with `a` being the most major.
     for (const auto &[srcDims, dstDims] : decomp) {
       if (!isConsecutive(to_vector(reverse(gather(srcInvOrder, srcDims))))) {
-        return emitOptionalError(loc,
-                                 "Cannot do a non-reordering reshape given "
-                                 "this src encoding order. Dimensions [",
-                                 join(srcDims),
-                                 "] must be physically consecutive.");
+        return failure();
       }
     }
 
@@ -2774,11 +2766,7 @@ struct TritonGPUInferLayoutInterface
         // Check that more-minor dims all have 1 in shapeRemaining.
         for (int j = i + 1; j < srcDims.size(); j++) {
           if (shapeRemaining[j] != 1) {
-            return emitOptionalError(
-                loc,
-                "Invalid src encoding for non-reordering reshape. Must use "
-                "up sizePerThread / threadsPerWarp / warpsPerCTA for "
-                "more-minor dimensions before more major-dims can use them.");
+            return failure();
           }
         }
 
@@ -2793,13 +2781,7 @@ struct TritonGPUInferLayoutInterface
         // only if we're the most-major dimension of the chunk and in all
         // future chunks, only this most-major dim has a non-1 size.
         if (shapeRemaining[i] == 0 && i != 0) {
-          return emitOptionalError(
-              loc,
-              "Invalid src encoding for non-reordering reshape. Block "
-              "size in dimension ",
-              dim,
-              " is larger than the shape that dimension, but this is only "
-              "allowed for the most-major dimension of a reshape chunk");
+          return failure();
         }
       }
       return success();
@@ -2889,6 +2871,65 @@ struct TritonGPUInferLayoutInterface
     return success();
   }
 
+  LogicalResult verifyLayoutsAreEqual(ArrayRef<int64_t> shape,
+                                      Attribute expected, Attribute got,
+                                      Location loc) const override {
+    if (expected == got) {
+      return success();
+    }
+    // Check whether the encodings are structurally the same.
+    auto expectedLL = triton::gpu::toLinearLayout(shape, expected);
+    auto gotLL = triton::gpu::toLinearLayout(shape, got);
+    if (expectedLL != gotLL) {
+      return emitError(loc, "Expected result encoding ")
+             << expected << " but was " << got;
+    }
+    return success();
+  }
+
+  LogicalResult
+  inferReshapeOpEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
+                         ArrayRef<int64_t> dstShape, Attribute &dstEnc,
+                         std::optional<Location> loc) const override {
+    auto result =
+        inferReshapeOpLegacyEncoding(srcShape, srcEnc, dstShape, dstEnc);
+    if (succeeded(result)) {
+      return result;
+    }
+
+    // If the legacy encoding failed use LinearLayouts.
+    // Once LinearLayouts are more widely used, we can remove
+    // inferReshapeOpLegacyEncoding and simply use LLs.
+    auto *ctx = getContext();
+    auto src = triton::gpu::toLinearLayout(srcShape, srcEnc);
+    if (!src) {
+      return emitOptionalError(loc,
+                               "src encoding does not support linear layout");
+    }
+
+    if (product(srcShape) != product(dstShape)) {
+      return emitOptionalError(loc, "numel of dst shape does not match "
+                                    "numel of src shape");
+    }
+
+    auto newRank = dstShape.size();
+    SmallVector<std::pair<StringAttr, int32_t>> newOutDims;
+    for (auto [dim, size] :
+         llvm::zip(standardOutDimNames(ctx, newRank), dstShape)) {
+      newOutDims.emplace_back(dim, size);
+    }
+    auto srcOutDims = llvm::to_vector(src->getOutDimNames());
+    // reshapeOp assumes minor-to-major, so we need to transpose the out dims
+    // before the reshape
+    std::reverse(srcOutDims.begin(), srcOutDims.end());
+    std::reverse(newOutDims.begin(), newOutDims.end());
+    auto dst = src->transposeOuts(srcOutDims)
+                   .reshapeOuts(newOutDims)
+                   .transposeOuts(standardOutDimNames(ctx, newRank));
+    dstEnc = LinearEncodingAttr::get(ctx, dst);
+    return success();
+  }
+
   LogicalResult
   inferJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc,
                       std::optional<Location> loc) const override {
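
The heart of the new fallback is the transposeOuts / reshapeOuts / transposeOuts chain at the end of inferReshapeOpEncoding. Below is a sketch restating just that step with the reasoning spelled out in comments. It assumes the LinearLayout API from triton/Tools/LinearLayout.h as used in the hunk above, and that standardOutDimNames (which produces the StringAttrs dim0, dim1, ...) is visible here; the free-function framing is ours, not the commit's:

#include "triton/Tools/LinearLayout.h"

// Sketch: rebuild a layout across a reshape. `src` maps hardware resources
// (register, lane, warp, ...) to tensor coordinates, so relabelling its
// output dims for the new shape leaves the element-to-thread mapping
// untouched -- exactly the "reshape is a nop" property described above.
mlir::triton::LinearLayout
reshapeLayout(mlir::MLIRContext *ctx, const mlir::triton::LinearLayout &src,
              llvm::ArrayRef<int64_t> dstShape) {
  unsigned rank = dstShape.size();
  auto outDimNames = standardOutDimNames(ctx, rank);
  llvm::SmallVector<std::pair<mlir::StringAttr, int32_t>> newOutDims;
  for (auto [dim, size] : llvm::zip(outDimNames, dstShape))
    newOutDims.emplace_back(dim, size);
  // tt.reshape flattens with the last (most-minor) dim fastest, while the
  // layout lists out dims as dim0, dim1, ... Reversing both dim lists lines
  // the two orders up before reshapeOuts merges/splits the dimensions...
  auto srcOutDims = llvm::to_vector(src.getOutDimNames());
  std::reverse(srcOutDims.begin(), srcOutDims.end());
  std::reverse(newOutDims.begin(), newOutDims.end());
  // ...and the final transpose restores the standard dim0, dim1, ... order.
  return src.transposeOuts(srcOutDims)
      .reshapeOuts(newOutDims)
      .transposeOuts(outDimNames);
}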

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 10 additions & 0 deletions
@@ -42,6 +42,16 @@ struct CanonicalizeConvertFromReshape
     auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
     if (!convert)
       return failure();
+    // If the layouts are structurally the same, the convert is trivial
+    auto srcType = convert.getSrc().getType();
+    auto dstType = convert.getType();
+    auto srcLL = toLinearLayout(srcType.getShape(), srcType.getEncoding());
+    auto dstLL = toLinearLayout(dstType.getShape(), dstType.getEncoding());
+    if (srcLL && dstLL && *srcLL == *dstLL) {
+      rewriter.replaceOpWithNewOp<triton::ReshapeOp>(
+          op, op.getType(), convert.getSrc(), op.getAllowReorder());
+      return mlir::success();
+    }
     if (isExpensiveView(convert.getSrc().getType(), op.getType()))
       return failure();
     if (!op.getAllowReorder() || op.getEfficientLayout())
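
Note the placement of the new early exit: it fires before the isExpensiveView and allow_reorder bail-outs, so a structurally trivial convert is folded away even when reordering is disallowed. The guard is exactly the sameLayout check sketched at the top of this page, applied to the convert's operand and result types; foldableConvert is our name for it, not the commit's:

// Sketch: reshape(convert(x)) folds to reshape(x) precisely when the convert
// does not change the linear layout, i.e. it is a no-op at the register /
// lane / warp level and therefore free to delete.
static bool foldableConvert(mlir::triton::gpu::ConvertLayoutOp cvt) {
  auto srcType = cvt.getSrc().getType();
  auto dstType = cvt.getType();
  return sameLayout(srcType.getShape(), srcType.getEncoding(),
                    dstType.getEncoding());
}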

lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp

Lines changed: 2 additions & 7 deletions
@@ -1025,9 +1025,7 @@ void LayoutRematerialization::backwardRematerialization(
   // we don't handle conversions to DotOperandEncodingAttr
   // this is a heuristic to accommodate fused attention
   RankedTensorType targetType = convertOp.getType();
-  // We stop the rematerialization of linear layouts as we have to be a bit more
-  // careful with the heuristics for both correctness and perf
-  if (isa<DotOperandEncodingAttr, LinearEncodingAttr>(targetType.getEncoding()))
+  if (isa<DotOperandEncodingAttr>(targetType.getEncoding()))
     return;
   Value oldV = convertOp.getSrc();
   LDBG("check backward remat with source " << oldV << " encoding "
@@ -1069,11 +1067,8 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
     ConvertLayoutOp convertOp) {
   // we don't handle conversions to DotOperandEncodingAttr
   // this is a heuristics to accommodate fused attention
-  // We stop the rematerialization of linear layouts as we have to be a bit more
-  // careful with the heuristics for both correctness and perf
   RankedTensorType targetType = convertOp.getType();
-  if (mlir::isa<DotOperandEncodingAttr, LinearEncodingAttr>(
-          targetType.getEncoding()))
+  if (isa<DotOperandEncodingAttr>(targetType.getEncoding()))
     return;
 
   auto isExtOrBroadcastOp = [](Operation *op) {

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 7 additions & 8 deletions
@@ -407,14 +407,13 @@ static Attribute inferReshapeOpDstEncoding(ArrayRef<int64_t> srcShape,
     return {};
 
   Attribute dstEnc;
-  if (succeeded(
-          srcEnc.getDialect()
-              .getRegisteredInterface<triton::DialectInferLayoutInterface>()
-              ->inferReshapeOpNoReorderEncoding(
-                  srcShape, srcEnc, dstShape, dstEnc, /*loc=*/std::nullopt))) {
-    return dstEnc;
-  }
-  return {};
+  auto result =
+      srcEnc.getDialect()
+          .getRegisteredInterface<triton::DialectInferLayoutInterface>()
+          ->inferReshapeOpEncoding(srcShape, srcEnc, dstShape, dstEnc,
+                                   /*loc=*/std::nullopt);
+  assert(succeeded(result));
+  return dstEnc;
 }
 
 static Attribute inferDstEncoding(triton::ReshapeOp op, Attribute encoding) {
