Commit 3d78176

Merge commit '635435fc2e56b2a30276302d75df87956b541848'
2 parents: d662e65 + 635435f

25 files changed: +859 additions, -619 deletions

include/triton/Conversion/TritonGPUToLLVM/Patterns.h

Lines changed: 0 additions & 4 deletions

@@ -13,10 +13,6 @@ namespace triton::gpu {
 /// |module| op because the codegen doesn't handle `blocked -> dot_op` directly.
 void decomposeBlockedToDotLayoutConversion(ModuleOp module);
 
-/// Replaces `splat -> shared` with `splat -> blocked -> shared` in the given
-/// |module| op.
-void decomposeSplatOpToSharedLayoutConversion(ModuleOp module);
-
 /// Replaces `mma/mfma -> dot_op` with `mma/mfma -> blocked -> dot_op` in the
 /// given |module| op, but bypass the decomposition if |shortcutFn| returns
 /// true.
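For context, a minimal sketch of how a lowering pipeline might invoke the decompositions declared in this header before applying its conversion patterns. `prepareModuleForLowering` and `isShortcut` are hypothetical names (not from this commit); the calls follow the declarations above and the definitions in DecomposeUnsupportedConversions.cpp below.

#include "triton/Conversion/TritonGPUToLLVM/Patterns.h"

using namespace mlir;

// Hypothetical sketch: run the decompositions up front so later conversion
// patterns only see the layout conversions the codegen handles directly.
// `isShortcut` is a placeholder ShortcutFn supplied by the target backend.
static void prepareModuleForLowering(ModuleOp module,
                                     triton::gpu::ShortcutFn isShortcut) {
  // mma/mfma -> dot_op becomes mma/mfma -> blocked -> dot_op, unless the
  // target reports a direct shortcut for that pair of layouts.
  triton::gpu::decomposeTensorCoreToDotLayoutConversion(module, isShortcut);
  // blocked -> dot_op conversions the codegen can't handle directly.
  triton::gpu::decomposeBlockedToDotLayoutConversion(module);
}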

include/triton/Dialect/Triton/IR/Dialect.h

Lines changed: 12 additions & 6 deletions

@@ -55,13 +55,19 @@ class DialectInferLayoutInterface
 
   // Tries to compute the encoding for the result of a reshape operation that
   // makes the reshape a "nop", i.e. the same GPU threads contain the same
-  // elements as before the reshape. Note that this is not always possible (in
-  // which case you'd need to choose a different layout for the input to the
-  // reshape).
+  // elements as before the reshape using legacy layouts. This is not always
+  // possible (in which case we fall back to using LinearLayouts).
+  // In the future we'll always use LinearLayouts.
   virtual LogicalResult
-  inferReshapeOpNoReorderEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
-                                  ArrayRef<int64_t> dstShape, Attribute &dstEnc,
-                                  std::optional<Location> loc) const = 0;
+  inferReshapeOpEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
+                         ArrayRef<int64_t> dstShape, Attribute &dstEnc,
+                         std::optional<Location> loc) const = 0;
+
+  // Check whether two layouts are structurally the same, even if their names
+  // differ.
+  virtual LogicalResult verifyLayoutsAreEqual(ArrayRef<int64_t> shape,
+                                              Attribute expected, Attribute got,
+                                              Location loc) const = 0;
 
   virtual LogicalResult
   inferJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc,
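A sketch of the intended call pattern for the two new hooks; the updated ReshapeOp::verify in lib/Dialect/Triton/IR/Ops.cpp (further down in this commit) does essentially this. `checkReshapeEncoding` is a hypothetical helper, not part of the interface, and the namespace qualification is an assumption.

#include "triton/Dialect/Triton/IR/Dialect.h"

using namespace mlir;

// Hypothetical helper: infer the encoding the reshape should produce, then
// accept any dst encoding that is structurally equal to it.
static LogicalResult checkReshapeEncoding(ArrayRef<int64_t> srcShape,
                                          Attribute srcEnc,
                                          ArrayRef<int64_t> dstShape,
                                          Attribute dstEnc, Location loc) {
  auto *iface =
      cast<triton::DialectInferLayoutInterface>(&srcEnc.getDialect());
  Attribute inferred;
  if (failed(iface->inferReshapeOpEncoding(srcShape, srcEnc, dstShape,
                                           inferred, loc)))
    return failure();
  // Structural comparison, not attribute identity (see Dialect.cpp below).
  return iface->verifyLayoutsAreEqual(dstShape, inferred, dstEnc, loc);
}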

lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp

Lines changed: 0 additions & 26 deletions

@@ -18,32 +18,6 @@ static void addAttrs(Operation *op, ArrayRef<mlir::NamedAttribute> attrs) {
 
 namespace mlir::triton::gpu {
 
-void decomposeSplatOpToSharedLayoutConversion(ModuleOp module) {
-  int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(module);
-  int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(module);
-  int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(module);
-  module.walk([&](triton::SplatOp splatOp) -> void {
-    auto dstType = cast<RankedTensorType>(splatOp.getType());
-    auto shared = dyn_cast_or_null<triton::gpu::SharedEncodingAttr>(
-        dstType.getEncoding());
-    if (shared) {
-      OpBuilder builder(splatOp);
-      SmallVector<unsigned, 4> sizePerThread(dstType.getRank(), 1);
-      auto newType = RankedTensorType::get(
-          dstType.getShape(), dstType.getElementType(),
-          triton::gpu::BlockedEncodingAttr::get(
-              module.getContext(), dstType.getShape(), sizePerThread,
-              getOrder(shared), numWarps, threadsPerWarp, numCTAs));
-      auto newSplat = builder.create<triton::SplatOp>(splatOp.getLoc(), newType,
-                                                      splatOp.getSrc());
-      auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
-          splatOp.getLoc(), dstType, newSplat.getResult());
-      splatOp.replaceAllUsesWith(newConvert.getResult());
-      splatOp.erase();
-    }
-  });
-}
-
 void decomposeTensorCoreToDotLayoutConversion(ModuleOp module,
                                               ShortcutFn shortcutFn) {
   int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(module);

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 14 additions & 16 deletions

@@ -8,6 +8,7 @@
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/Triton/IR/Types.h"
 #include "triton/Dialect/Triton/IR/Utility.h"
+#include "triton/Tools/LinearLayout.h"
 #include "llvm/Support/ErrorHandling.h"
 
 namespace mlir {
@@ -701,24 +702,21 @@ LogicalResult ReshapeOp::verify() {
                      "encodings, or (b) neither does.");
   }
 
-  if (srcEnc && !getAllowReorder()) {
-    Attribute inferredDstEnc;
-    if (cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
-            ->inferReshapeOpNoReorderEncoding(srcTy.getShape(), srcEnc,
-                                              dstTy.getShape(), inferredDstEnc,
-                                              getLoc())
-            .failed()) {
-      return emitError("This reshape is impossible without reordering, but "
-                       "reordering is not allowed. Try choosing a different "
-                       "encoding for the input tensor (or allow reordering).");
-    }
-    if (inferredDstEnc != dstEnc) {
-      return emitError("Expected result encoding ")
-             << inferredDstEnc << " but was " << dstEnc;
-    }
+  if (!srcEnc || getAllowReorder()) {
+    return success();
   }
 
-  return success();
+  // Check that we can infer the dst encoding from the src encoding, and that
+  // the inferred dst encoding is the same as the given dst encoding.
+  Attribute inferredDstEnc;
+  auto result =
+      cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
+          ->inferReshapeOpEncoding(srcTy.getShape(), srcEnc, dstTy.getShape(),
+                                   inferredDstEnc, getLoc());
+  assert(succeeded(result));
+  return cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
+      ->verifyLayoutsAreEqual(dstTy.getShape(), inferredDstEnc, dstEnc,
+                              getLoc());
 }
 
 //-- FpToFpOp --
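The switch from the old `inferredDstEnc != dstEnc` comparison to `verifyLayoutsAreEqual` matters because two distinct attributes can describe the same layout. A hedged illustration below, assuming (as `verifyLayoutsAreEqual` itself relies on) that converting an encoding to its linear layout and wrapping it in a `LinearEncodingAttr` round-trips unchanged; all names are illustrative, not from this commit.

#include <cassert>

#include "triton/Dialect/TritonGPU/IR/Dialect.h"

// Hypothetical example: a blocked encoding and the LinearEncodingAttr built
// from its linear layout are different attributes describing the same layout.
void structuralEqualityExample(mlir::MLIRContext *ctx,
                               llvm::ArrayRef<int64_t> shape,
                               mlir::Attribute blockedEnc) {
  namespace ttg = mlir::triton::gpu;
  auto ll = ttg::toLinearLayout(shape, blockedEnc);
  assert(ll && "encoding is assumed to support linear layout");
  auto linearEnc = ttg::LinearEncodingAttr::get(ctx, *ll);
  assert(linearEnc != blockedEnc); // different attributes...
  // ...but structurally equal layouts, so verifyLayoutsAreEqual accepts them.
  assert(ttg::toLinearLayout(shape, linearEnc) == ll);
}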

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 80 additions & 39 deletions

@@ -1630,11 +1630,12 @@ LinearEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
 
 SmallVector<unsigned>
 LinearEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape, Type) const {
-  // We can relax this assert by calling toLinearLayout rather than
-  // getLinearLayout
-  SmallVector<int32_t> shapeVec(shape.begin(), shape.end());
-  assert(shapeVec == llvm::to_vector(getLinearLayout().getOutDimSizes()));
-  auto ll = getLinearLayout();
+  // When broadcasting the layout, the shape changes; otherwise the shape is
+  // the same as the shape of the tensor. We can either have BroadcastOp with
+  // SameOperandsAndResultEncoding, or keep the invariant that the shape of
+  // the LL is that of the tensor. We choose the former for BC.
+  auto ll = *toLinearLayout(shape);
   return basesPerDim(ll, StringAttr::get(getContext(), "register"));
 }
 
@@ -2658,8 +2659,8 @@ struct TritonGPUInferLayoutInterface
   // contains elements [a,b,c,d] before the reshape, it contains those same
   // elements after the reshape, they're just "renamed".
   //
-  // A dst encoding that satisfies this property does not exist for all inputs.
-  // Here are some positive and negative examples.
+  // Using legacy layouts, a dst encoding that satisfies this property may not
+  // exist. Here are some positive and negative examples.
   //
   // - NOT OK: 4x4 order=[0,1] -> 16. Reshape merges elements so
   //   dim 1 is the fastest-changing in the dst, but the src has the opposite
@@ -2673,17 +2674,19 @@ struct TritonGPUInferLayoutInterface
   // - OK: 32x4 sizePerThread=[4,4] -> 128. dst with sizePerThread=[16] will
   //   contain the same elements as before.
   //
+  // With linear layouts, we can always find a dst encoding that satisfies
+  // this property. See inferReshapeOpEncoding.
+  //
   // Users of this function require that it is symmetrical: if
   // (srcShape,srcEnc,dstShape) => dstEnc, then (dstShape,dstEnc,srcShape) =>
   // srcEnc.
-  LogicalResult
-  inferReshapeOpNoReorderEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
-                                  ArrayRef<int64_t> dstShape, Attribute &dstEnc,
-                                  std::optional<Location> loc) const override {
+  LogicalResult inferReshapeOpLegacyEncoding(ArrayRef<int64_t> srcShape,
+                                             Attribute srcEnc,
+                                             ArrayRef<int64_t> dstShape,
+                                             Attribute &dstEnc) const {
    auto src = mlir::dyn_cast<BlockedEncodingAttr>(srcEnc);
    if (!src) {
-      return emitOptionalError(
-          loc, "Non-reordering reshape only supports BlockedEncoding");
+      return failure();
    }
 
    // Nop reshape; we can always infer an encoding.
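To make the "NOT OK: 4x4 order=[0,1] -> 16" case above concrete, here is a small standalone program (plain C++, no Triton dependencies) showing why no legacy dst encoding can keep each thread's elements unchanged in that case.

#include <cstdio>

// With order=[0,1], dim 0 is the fastest-changing in the src layout, so one
// thread might own elements (0,0) and (1,0), adjacent along dim 0. A
// row-major 4x4 -> 16 reshape numbers elements with dim 1 fastest, so those
// elements land at linear indices 0 and 4: no longer adjacent, hence no
// "nop" dst encoding exists.
int main() {
  int col = 0;
  for (int row = 0; row < 2; ++row) // the thread's elements (0,0) and (1,0)
    std::printf("element (%d,%d) -> linear index %d\n", row, col,
                row * 4 + col);
}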
@@ -2716,9 +2719,7 @@ struct TritonGPUInferLayoutInterface
     // to handle CTASplitNum.
     if (!all_of(src.getCTAsPerCGA(), [](int32_t x) { return x == 1; }) ||
         !all_of(src.getCTASplitNum(), [](int32_t x) { return x == 1; })) {
-      return emitOptionalError(
-          loc, "Non-reordering reshape does not currently support multi-CTA "
-               "layouts other than the default layout.");
+      return failure();
     }
 
     // Cowardly refuse to handle encodings where shape[dim] is not divisible by
@@ -2728,12 +2729,7 @@ struct TritonGPUInferLayoutInterface
     for (int dim = 0; dim < srcShape.size(); dim++) {
       if (srcShape[dim] >= subblock[dim] &&
           srcShape[dim] % subblock[dim] != 0) {
-        return emitOptionalError(loc,
-                                 "Can't do a non-reordering reshape because "
-                                 "the size of dimension ",
-                                 dim, " (", srcShape[dim], ")",
-                                 " is not divisible by ", name, "[", dim, "]",
-                                 " = ", subblock[dim]);
+        return failure();
       }
     }
     return success();
@@ -2758,11 +2754,7 @@ struct TritonGPUInferLayoutInterface
     // physical order, with `a` being the most major.
     for (const auto &[srcDims, dstDims] : decomp) {
       if (!isConsecutive(to_vector(reverse(gather(srcInvOrder, srcDims))))) {
-        return emitOptionalError(loc,
-                                 "Cannot do a non-reordering reshape given "
-                                 "this src encoding order. Dimensions [",
-                                 join(srcDims),
-                                 "] must be physically consecutive.");
+        return failure();
       }
     }
 
@@ -2809,11 +2801,7 @@ struct TritonGPUInferLayoutInterface
       // Check that more-minor dims all have 1 in shapeRemaining.
       for (int j = i + 1; j < srcDims.size(); j++) {
         if (shapeRemaining[j] != 1) {
-          return emitOptionalError(
-              loc,
-              "Invalid src encoding for non-reordering reshape. Must use "
-              "up sizePerThread / threadsPerWarp / warpsPerCTA for "
-              "more-minor dimensions before more major-dims can use them.");
+          return failure();
         }
       }
 
@@ -2828,13 +2816,7 @@ struct TritonGPUInferLayoutInterface
       // only if we're the most-major dimension of the chunk and in all
       // future chunks, only this most-major dim has a non-1 size.
       if (shapeRemaining[i] == 0 && i != 0) {
-        return emitOptionalError(
-            loc,
-            "Invalid src encoding for non-reordering reshape. Block "
-            "size in dimension ",
-            dim,
-            " is larger than the shape of that dimension, but this is only "
-            "allowed for the most-major dimension of a reshape chunk");
+        return failure();
       }
     }
     return success();
@@ -2924,6 +2906,65 @@ struct TritonGPUInferLayoutInterface
     return success();
   }
 
+  LogicalResult verifyLayoutsAreEqual(ArrayRef<int64_t> shape,
+                                      Attribute expected, Attribute got,
+                                      Location loc) const override {
+    if (expected == got) {
+      return success();
+    }
+    // Check whether the encodings are structurally the same.
+    auto expectedLL = triton::gpu::toLinearLayout(shape, expected);
+    auto gotLL = triton::gpu::toLinearLayout(shape, got);
+    if (expectedLL != gotLL) {
+      return emitError(loc, "Expected result encoding ")
+             << expected << " but was " << got;
+    }
+    return success();
+  }
+
+  LogicalResult
+  inferReshapeOpEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
+                         ArrayRef<int64_t> dstShape, Attribute &dstEnc,
+                         std::optional<Location> loc) const override {
+    auto result =
+        inferReshapeOpLegacyEncoding(srcShape, srcEnc, dstShape, dstEnc);
+    if (succeeded(result)) {
+      return result;
+    }
+
+    // If the legacy encoding failed, use LinearLayouts.
+    // Once LinearLayouts are more widely used, we can remove
+    // inferReshapeOpLegacyEncoding and simply use LLs.
+    auto *ctx = getContext();
+    auto src = triton::gpu::toLinearLayout(srcShape, srcEnc);
+    if (!src) {
+      return emitOptionalError(loc,
+                               "src encoding does not support linear layout");
+    }
+
+    if (product(srcShape) != product(dstShape)) {
+      return emitOptionalError(loc, "numel of dst shape does not match "
+                                    "numel of src shape");
+    }
+
+    auto newRank = dstShape.size();
+    SmallVector<std::pair<StringAttr, int32_t>> newOutDims;
+    for (auto [dim, size] :
+         llvm::zip(standardOutDimNames(ctx, newRank), dstShape)) {
+      newOutDims.emplace_back(dim, size);
+    }
+    auto srcOutDims = llvm::to_vector(src->getOutDimNames());
+    // reshapeOp assumes minor-to-major, so we need to transpose the out dims
+    // before the reshape.
+    std::reverse(srcOutDims.begin(), srcOutDims.end());
+    std::reverse(newOutDims.begin(), newOutDims.end());
+    auto dst = src->transposeOuts(srcOutDims)
+                   .reshapeOuts(newOutDims)
+                   .transposeOuts(standardOutDimNames(ctx, newRank));
+    dstEnc = LinearEncodingAttr::get(ctx, dst);
+    return success();
+  }
+
   LogicalResult
   inferJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc,
                       std::optional<Location> loc) const override {
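The `std::reverse` calls in `inferReshapeOpEncoding` deserve a worked example. Assuming, per the comment in the code, that the flatten used by `reshapeOuts` treats the first listed out-dim as the fastest-varying (minor-to-major), the standard dim list must be reversed on both sides to reproduce the row-major numbering `tt.reshape` uses. A standalone sketch (plain C++, no Triton dependencies):

#include <cstdio>

// Flattening a 2x6 tensor. If the flatten makes the FIRST listed dim
// fastest, the list [dim0, dim1] gives i + 2*j, which disagrees with
// tt.reshape's row-major i*6 + j; the reversed list [dim1, dim0] gives
// j + 6*i, which agrees. Hence the transpose to reversed order before
// reshapeOuts, and the transpose back afterwards.
int main() {
  int i = 1, j = 3;             // an arbitrary element of the 2x6 tensor
  int firstListed = i + 2 * j;  // [dim0, dim1], dim0 fastest: 7
  int reversedList = j + 6 * i; // [dim1, dim0], dim1 fastest: 9
  int rowMajor = i * 6 + j;     // tt.reshape's element numbering: 9
  std::printf("first=%d reversed=%d rowMajor=%d\n", firstListed, reversedList,
              rowMajor);
  return reversedList == rowMajor ? 0 : 1;
}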
