Skip to content

Commit b7a6ffb

Browse files
Merge OpenAI commit dbc85fc (#5210)
This PR changes the Triton base from 1b27b93 to dbc85fc (Sep 23). Pass rate: 96.32%->96.23%
2 parents 10fac7c + 67a7052 commit b7a6ffb

File tree

45 files changed

+506
-229
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+506
-229
lines changed

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -272,8 +272,8 @@ llvm::SmallVector<unsigned>
272272
expandMatrixOrderWithBatch(llvm::ArrayRef<unsigned> o);
273273

274274
// Return true if the two layouts represent the exact same mapping.
275-
bool areLayoutsEquivalent(ArrayRef<int64_t> shape, DistributedEncodingTrait lhs,
276-
DistributedEncodingTrait rhs);
275+
bool areLayoutsEquivalent(ArrayRef<int64_t> shape, LayoutEncodingTrait lhs,
276+
LayoutEncodingTrait rhs);
277277

278278
// Return true if the innermost numElems are contiguous.
279279
bool isInnermostContiguous(MemDescType type, unsigned numElems);
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#ifndef TRITON_DIALECT_TRITONGPU_TRANSFORMS_LAYOUT_PROPAGATION_UTILITY_H_
2+
#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_LAYOUT_PROPAGATION_UTILITY_H_
3+
4+
#include "triton/Dialect/Triton/IR/Dialect.h"
5+
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
6+
#include "triton/Tools/LinearLayout.h"
7+
#include <optional>
8+
9+
namespace mlir::triton::gpu {
10+
11+
// Given the result |dstLayout|, infer the source layout that we should use for
12+
// global load if we propagate through op def chain of |defOp|. Returns
13+
// std::nullopt if fails to infer or cannot reach a global load.
14+
std::optional<std::pair<triton::LoadOp, LinearLayout>>
15+
inferSourceLoadLayout(const LinearLayout &dstLayout, Operation *defOp);
16+
std::optional<std::pair<triton::LoadOp, LinearLayout>>
17+
inferSourceLoadLayout(LinearEncodingAttr dstLayout, Operation *defOp);
18+
19+
} // namespace mlir::triton::gpu
20+
21+
#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_LAYOUT_PROPAGATION_UTILITY_H_

lib/Analysis/AxisInfo.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -988,6 +988,35 @@ class MaxMinOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
988988
}
989989
};
990990

991+
class TransOpAxisInfoVisitor final
992+
: public AxisInfoVisitorImpl<triton::TransOp> {
993+
public:
994+
using AxisInfoVisitorImpl<triton::TransOp>::AxisInfoVisitorImpl;
995+
996+
AxisInfo
997+
getAxisInfo(triton::TransOp op,
998+
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
999+
AxisInfo srcInfo = operands[0]->getValue();
1000+
auto order = op.getOrder();
1001+
auto rank = srcInfo.getRank();
1002+
1003+
// Apply the transpose permutation to all axis info properties
1004+
AxisInfo::DimVectorT contiguity;
1005+
AxisInfo::DimVectorT divisibility;
1006+
AxisInfo::DimVectorT constancy;
1007+
1008+
for (int d = 0; d < rank; ++d) {
1009+
int srcDim = order[d];
1010+
contiguity.push_back(srcInfo.getContiguity(srcDim));
1011+
divisibility.push_back(srcInfo.getDivisibility(srcDim));
1012+
constancy.push_back(srcInfo.getConstancy(srcDim));
1013+
}
1014+
1015+
return AxisInfo(contiguity, divisibility, constancy,
1016+
srcInfo.getConstantValue());
1017+
}
1018+
};
1019+
9911020
//===----------------------------------------------------------------------===//
9921021
// AxisInfoAnalysis
9931022
//===----------------------------------------------------------------------===//
@@ -1032,6 +1061,7 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver,
10321061
MaxMinOpAxisInfoVisitor<arith::MinSIOp>,
10331062
MaxMinOpAxisInfoVisitor<arith::MinUIOp>>();
10341063
visitors.append<LoadOpAxisInfoVisitor>();
1064+
visitors.append<TransOpAxisInfoVisitor>();
10351065

10361066
if (callback)
10371067
callback(visitors);

lib/Dialect/Triton/IR/Traits.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,41 @@
66
#include "triton/Dialect/Triton/IR/Dialect.h"
77
#include "triton/Dialect/Triton/IR/Types.h"
88
#include "triton/Dialect/Triton/IR/Utility.h"
9+
#include "triton/Dialect/TritonGPU/IR/Types.h"
910
#include "llvm/Support/ErrorHandling.h"
1011

1112
using namespace mlir;
13+
using namespace mlir::triton::gpu;
1214

1315
LogicalResult OpTrait::impl::verifyEquivalentType(Type typeA, Type typeB) {
16+
auto memdescA = dyn_cast<MemDescType>(typeA);
17+
auto memdescB = dyn_cast<MemDescType>(typeB);
18+
if (memdescA || memdescB) {
19+
if (!memdescA || !memdescB)
20+
return failure();
21+
if (memdescA.getShape() != memdescB.getShape())
22+
return failure();
23+
if (memdescA.getAllocShape() != memdescB.getAllocShape())
24+
return failure();
25+
if (memdescA.getElementType() != memdescB.getElementType())
26+
return failure();
27+
if (memdescA.getMemorySpace() != memdescB.getMemorySpace())
28+
return failure();
29+
if (memdescA.getMutableMemory() != memdescB.getMutableMemory())
30+
return failure();
31+
32+
Attribute encodingA = memdescA.getEncoding();
33+
Attribute encodingB = memdescB.getEncoding();
34+
if (encodingA == encodingB)
35+
return success();
36+
if (static_cast<bool>(encodingA) != static_cast<bool>(encodingB))
37+
return failure();
38+
39+
auto layoutInterface =
40+
cast<triton::DialectInferLayoutInterface>(&encodingA.getDialect());
41+
return layoutInterface->verifyLayoutsAreEqual(memdescA.getShape(),
42+
encodingA, encodingB, {});
43+
}
1444
auto tensorTypeA = dyn_cast<RankedTensorType>(typeA);
1545
auto tensorTypeB = dyn_cast<RankedTensorType>(typeB);
1646
if (!(bool(tensorTypeA) && bool(tensorTypeB)))

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3054,8 +3054,8 @@ struct TritonGPUInferLayoutInterface
30543054
return failure();
30553055

30563056
// Check whether the encodings are structurally the same.
3057-
if (!areLayoutsEquivalent(shape, cast<DistributedEncodingTrait>(expected),
3058-
cast<DistributedEncodingTrait>(got))) {
3057+
if (!areLayoutsEquivalent(shape, cast<LayoutEncodingTrait>(expected),
3058+
cast<LayoutEncodingTrait>(got))) {
30593059
return emitOptionalError(loc, "Expected result encoding ", expected,
30603060
" but was ", got);
30613061
}
@@ -3109,8 +3109,8 @@ struct TritonGPUInferLayoutInterface
31093109
Attribute splitEnc;
31103110
auto result = inferSplitOpEncoding(parent, splitEnc, joinedShape, loc);
31113111
if (succeeded(result) &&
3112-
areLayoutsEquivalent(shape, cast<DistributedEncodingTrait>(splitEnc),
3113-
cast<DistributedEncodingTrait>(srcEnc))) {
3112+
areLayoutsEquivalent(shape, cast<LayoutEncodingTrait>(splitEnc),
3113+
cast<LayoutEncodingTrait>(srcEnc))) {
31143114
dstEnc = parent;
31153115
return success();
31163116
}
@@ -3807,8 +3807,8 @@ int triton::gpu::lookupNumCTAs(OpBuilder &rewriter) {
38073807
}
38083808

38093809
bool triton::gpu::areLayoutsEquivalent(ArrayRef<int64_t> shape,
3810-
DistributedEncodingTrait lhs,
3811-
DistributedEncodingTrait rhs) {
3810+
LayoutEncodingTrait lhs,
3811+
LayoutEncodingTrait rhs) {
38123812
auto lhsLL = triton::gpu::toLinearLayout(shape, lhs);
38133813
auto rhsLL = triton::gpu::toLinearLayout(shape, rhs);
38143814
return lhsLL == rhsLL;

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -515,16 +515,7 @@ LogicalResult MemDescReshapeOp::verify() {
515515
if (failed(inferReturnTypes(getContext(), getLoc(), srcType,
516516
dstType.getShape(), expectedTy)))
517517
return failure();
518-
// Check that the alloc shape separately to give a cleaner error, given that
519-
// it's the most likely source of the error.
520-
if (expectedTy.getAllocShape() != dstType.getAllocShape()) {
521-
return emitError(
522-
"The result alloc shape does not match the expected alloc shape.");
523-
}
524-
if (expectedTy != dstType) {
525-
return emitError("source and destination layout are incompatible.");
526-
}
527-
return success();
518+
return OpTrait::impl::verifyEquivalentType(expectedTy, dstType);
528519
}
529520

530521
static LogicalResult inferMemDescReshapeOpEncoding(ArrayRef<int64_t> srcShape,

lib/Dialect/TritonGPU/Transforms/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ add_triton_library(TritonGPUTransforms
2727
ReorderInstructions.cpp
2828
CoalesceAsyncCopy.cpp
2929
Utility.cpp
30+
LayoutPropagationUtility.cpp
3031
WarpSpecialization/AutomaticWarpSpecialization.cpp
3132
WarpSpecialization/LoadMMASpecialization.cpp
3233
WarpSpecialization/Partition.cpp
@@ -35,6 +36,7 @@ add_triton_library(TritonGPUTransforms
3536
WarpSpecialization/PartitionLoops.cpp
3637
WarpSpecialization/PartitionScheduling.cpp
3738
WarpSpecialization/RewritePartitionDependencies.cpp
39+
3840
DEPENDS
3941
TritonGPUTransformsIncGen
4042

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
#include "triton/Dialect/TritonGPU/Transforms/LayoutPropagationUtility.h"
2+
#include "triton/Dialect/Triton/IR/Dialect.h"
3+
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
4+
#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
5+
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
6+
#include <optional>
7+
#include <utility>
8+
9+
namespace mlir::triton::gpu {
10+
11+
std::optional<std::pair<triton::LoadOp, LinearLayout>>
12+
inferSourceLoadLayout(const LinearLayout &dstLayout, Operation *defOp) {
13+
if (!defOp)
14+
return std::nullopt;
15+
return inferSourceLoadLayout(
16+
LinearEncodingAttr::get(defOp->getContext(), dstLayout), defOp);
17+
}
18+
19+
std::optional<std::pair<triton::LoadOp, LinearLayout>>
20+
inferSourceLoadLayout(LinearEncodingAttr dstLayout, Operation *defOp) {
21+
Attribute curLayout = dstLayout;
22+
Operation *curOp = defOp;
23+
while (curOp) {
24+
if (isa<triton::LoadOp>(curOp))
25+
break; // Found the load op; we are done here.
26+
27+
if (auto cvtOp = dyn_cast<ConvertLayoutOp>(curOp)) {
28+
// For convert op we keep the current layout to push through further.
29+
curOp = cvtOp.getSrc().getDefiningOp();
30+
} else {
31+
if (curOp->getNumOperands() != 1)
32+
break;
33+
curLayout = inferSrcEncoding(curOp, curLayout);
34+
curOp = curOp->getOperand(0).getDefiningOp();
35+
}
36+
}
37+
auto loadOp = dyn_cast_or_null<triton::LoadOp>(curOp);
38+
if (!loadOp)
39+
return std::nullopt;
40+
auto loadType = dyn_cast<RankedTensorType>(loadOp.getType());
41+
if (!loadType)
42+
return std::nullopt;
43+
44+
return std::make_pair(
45+
loadOp,
46+
toLinearLayout(loadType.getShape(), cast<LinearEncodingAttr>(curLayout)));
47+
}
48+
49+
} // namespace mlir::triton::gpu

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,13 @@
66
#include "mlir/Dialect/SCF/IR/SCF.h"
77
#include "mlir/IR/Dominance.h"
88
#include "mlir/IR/IRMapping.h"
9-
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
109
#include "triton/Analysis/AxisInfo.h"
1110
#include "triton/Dialect/Triton/IR/Dialect.h"
1211
#include "triton/Dialect/Triton/IR/Utility.h"
1312
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
1413
#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
1514
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
1615
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
17-
#include "llvm/ADT/SetOperations.h"
1816
#include "llvm/Support/Debug.h"
1917

2018
#define DEBUG_TYPE "ttg-utility"

lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -178,8 +178,8 @@ bool isDistributedLayoutSplitMTmemLoadStore(RankedTensorType tensorType,
178178
if (!layout)
179179
return false;
180180
return areLayoutsEquivalent(
181-
tensorType.getShape(), cast<DistributedEncodingTrait>(layout),
182-
cast<DistributedEncodingTrait>(tensorType.getEncoding()));
181+
tensorType.getShape(), cast<LayoutEncodingTrait>(layout),
182+
cast<LayoutEncodingTrait>(tensorType.getEncoding()));
183183
}
184184

185185
SmallVector<DistributedEncodingTrait>
@@ -226,9 +226,10 @@ bool isDistributedLayoutTMemCompatible(Operation *op,
226226
gpu::MemDescType memType) {
227227
SmallVector<DistributedEncodingTrait> layouts =
228228
getTmemCompatibleLayouts(op, tensorType, memType);
229-
auto enc = cast<DistributedEncodingTrait>(tensorType.getEncoding());
229+
auto enc = cast<LayoutEncodingTrait>(tensorType.getEncoding());
230230
return llvm::any_of(layouts, [&](DistributedEncodingTrait layout) {
231-
return areLayoutsEquivalent(tensorType.getShape(), layout, enc);
231+
return areLayoutsEquivalent(tensorType.getShape(),
232+
cast<LayoutEncodingTrait>(layout), enc);
232233
});
233234
}
234235

0 commit comments

Comments
 (0)