
Commit 763d9a5

Merge OpenAI Triton commit e71689d (#4590)
This PR changes the Triton base from 09d5113 to e71689d (Jun 24). Pass rate: 97.12%

2 parents: 2b612ce + 1d605e1 · commit: 763d9a5

File tree: 58 files changed (+653 / -285 lines)


README.md

Lines changed: 13 additions & 0 deletions
````diff
@@ -55,6 +55,17 @@ downloads a prebuilt LLVM, but you can also build LLVM from source and use that.
 LLVM does not have a stable API, so the Triton build will not work at an
 arbitrary LLVM version.
 
+For convenience, use the following command to build LLVM and install Triton with the custom LLVM:
+
+```shell
+make dev-install-llvm
+```
+
+<details>
+<summary>
+Alternatively, follow these steps to build LLVM from source manually.
+</summary>
+
 1. Find the version of LLVM that Triton builds against. Check
    `cmake/llvm-hash.txt` to see the current version. For example, if it says:
        49af6502c6dcb4a7f7520178bd14df396f78240c
@@ -86,6 +97,8 @@ arbitrary LLVM version.
    LLVM_SYSPATH=$LLVM_BUILD_DIR \
    pip install -e .
 
+</details>
+
 # Tips for building
 
 - Set `TRITON_BUILD_WITH_CLANG_LLD=true` as an environment variable to use clang
````
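Judging from the README context alone, the new `make dev-install-llvm` target presumably wraps the manual flow kept in the `<details>` block: building the LLVM revision pinned in `cmake/llvm-hash.txt` and then installing Triton with `LLVM_SYSPATH` pointing at that build. The Makefile itself is not part of this diff, so this is an inference, not a description of the target's contents.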

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 0 additions & 2 deletions
```diff
@@ -296,8 +296,6 @@ class TritonLLVMIRRewriter : public IRRewriter, public TritonLLVMOpBuilder {
 // Types
 #define ptr_ty(...) LLVM::LLVMPointerType::get(__VA_ARGS__)
 #define int_ty(width) rewriter.getIntegerType(width)
-#define i64_ty rewriter.getIntegerType(64)
-#define i32_ty rewriter.getIntegerType(32)
 #define i16_ty rewriter.getIntegerType(16)
 #define i32_ty rewriter.getIntegerType(32)
 #define i64_ty rewriter.getIntegerType(64)
```
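The two deleted lines are pure de-duplication: as the surrounding context shows, `i32_ty` and `i64_ty` were each defined twice in this macro block, and the surviving `i16_ty`/`i32_ty`/`i64_ty` trio keeps every definition unique.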

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 6 additions & 7 deletions
```diff
@@ -1002,11 +1002,11 @@ An encoding for tensors that have been produced by MFMA matrix core instructions
 available on AMD Instinct GPUs of CDNA architectures.
 
 It is characterized by the following parameters:
-- `versionMajor` and `versionMinor` indicates the GPU architecture:
-  - 1.0: gfx908, i.e. CDNA1
-  - 2.0: gfx90a: i.e. CDNA2
-  - 3.0: gfx942: CDNA3
-  - 4.0: gfx950: CDNA4
+- `version` indicates the GPU architecture:
+  - 1: gfx908: CDNA1
+  - 2: gfx90a: CDNA2
+  - 3: gfx942: CDNA3
+  - 4: gfx950: CDNA4
 - `warpsPerCTA` indicates the warp layout in the block.
 - `MDim` and `NDim` indicate the dimension of the output of the mfma instruction.
 - `isTransposed` indicates the result tensor is transposed so that it can be converted to dotOperand layout
@@ -1096,8 +1096,7 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,
 
 let parameters = (
   ins
-  "unsigned": $versionMajor,
-  "unsigned": $versionMinor,
+  "unsigned": $version,
   ArrayRefParameter<"unsigned">:$warpsPerCTA,
   "unsigned":$MDim,
   "unsigned":$NDim,
```

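The MFMA version collapses into a single integer; the minor version carried no information, since the verifier (updated in `lib/Dialect/TritonGPU/IR/Dialect.cpp` below) previously required `versionMinor == 0`. In textual IR the attribute therefore shrinks from roughly `<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [...], ...}>` to `<{version = 3, warpsPerCTA = [...], ...}>`, matching the updated `print` and `parse` methods.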
include/triton/Dialect/TritonGPU/Transforms/Schedule.h

Lines changed: 1 addition & 1 deletion
```diff
@@ -85,7 +85,7 @@ class CoarseSchedule {
   using Cluster = ClusterList::iterator;
   using ClusterHash = size_t;
 
-  DenseMap<Operation *, std::pair<int, Cluster>> opToStageAndCluster;
+  llvm::MapVector<Operation *, std::pair<int, Cluster>> opToStageAndCluster;
 
   void setNumStages(int numStages) { this->numStages = numStages; }
   int getNumStages() const { return numStages; }
```
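The container swap is subtle but load-bearing: `DenseMap` iterates in a hash-dependent order that can differ between runs, while `llvm::MapVector` iterates in insertion order, which keeps any pass that walks `opToStageAndCluster` deterministic. A minimal sketch of the difference, assuming deterministic iteration is the motivation (illustrative code, not part of the commit):

```cpp
#include "llvm/ADT/MapVector.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  // MapVector stores a vector of entries next to the map, so iteration
  // follows insertion order. A DenseMap populated with the same keys may
  // iterate in any order, and that order can change across runs or builds.
  llvm::MapVector<int, const char *> schedule;
  schedule.insert({30, "epilogue"});
  schedule.insert({10, "prologue"});
  schedule.insert({20, "body"});
  for (const auto &[key, label] : schedule) // always 30, 10, 20
    llvm::outs() << key << " -> " << label << "\n";
  return 0;
}
```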

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 16 additions & 8 deletions
```diff
@@ -154,6 +154,10 @@ applyLinearLayout(Location loc, RewriterBase &rewriter,
   auto b = TritonLLVMOpBuilder(loc, rewriter);
   assert(layout.getNumInDims() == indices.size());
   assert(llvm::equal(layout.getInDimNames(), llvm::make_first_range(indices)));
+  // Trivial layout
+  if (layout.getNumOutDims() == 0) {
+    return {};
+  }
 
   // This function can emit a lot of MLIR code, which ultimately makes
   // compilation slow. (We think this shouldn't be the case -- it's not *that*
@@ -167,25 +171,29 @@ applyLinearLayout(Location loc, RewriterBase &rewriter,
   SmallVector<std::pair<StringAttr, int32_t>> constantIns;
   SmallVector<std::pair<StringAttr, Value>> nonConstantIns;
   for (auto [inDimName, idx] : indices) {
-    if (auto constant = idx.getDefiningOp<LLVM::ConstantOp>()) {
-      constantIns.push_back(
-          {inDimName, cast<IntegerAttr>(constant.getValue()).getInt()});
+    APInt constant;
+    if (matchPattern(idx, m_ConstantInt(&constant))) {
+      constantIns.push_back({inDimName, constant.getSExtValue()});
     } else {
       constantIns.push_back({inDimName, 0});
      nonConstantIns.push_back({inDimName, idx});
     }
   }
-  SmallVector<int32_t> constantComponent =
-      llvm::to_vector(llvm::make_second_range(layout.apply(constantIns)));
 
+  // Compute constant part of the output and wrap it as values
   Value zero = b.i32_val(0);
   SmallVector<std::pair<StringAttr, Value>> outIndices;
-  for (auto [i, outDimName] : llvm::enumerate(layout.getOutDimNames())) {
-    if (constantComponent[i] == 0)
+  for (auto [outDimName, constant] : layout.apply(constantIns)) {
+    if (constant == 0)
       outIndices.push_back({outDimName, zero});
     else
-      outIndices.push_back({outDimName, b.i32_val(constantComponent[i])});
+      outIndices.push_back({outDimName, b.i32_val(constant)});
+  }
+
+  if (nonConstantIns.size() == 0) {
+    return outIndices;
   }
+
   // Happy path: Only one output.
   if (outIndices.size() == 1) {
     SmallVector<StringAttr> inDimNames;
```
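Besides switching constant detection from a hard-coded `LLVM::ConstantOp` cast to the more general `matchPattern`/`m_ConstantInt`, this adds two early exits: a layout with no output dimensions returns nothing, and an all-constant input returns the folded `outIndices` without emitting any IR. The fold is sound because a linear layout is linear over XOR, so the constant and variable components can be applied separately and combined. A toy model of that property (a sketch with a made-up `ToyLinearLayout`, not the Triton class):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Toy stand-in for a single-input, single-output linear layout: the output
// is the XOR of the basis vectors selected by the set bits of the input.
struct ToyLinearLayout {
  std::vector<uint32_t> bases; // bases[i] = output contribution of input bit i
  uint32_t apply(uint32_t in) const {
    uint32_t out = 0;
    for (size_t i = 0; i < bases.size(); ++i)
      if (in & (1u << i))
        out ^= bases[i];
    return out;
  }
};

int main() {
  ToyLinearLayout layout{{0b0001, 0b0100, 0b0110}};
  // XOR-linearity: apply(a ^ b) == apply(a) ^ apply(b). With the constant
  // and variable parts of an index living in disjoint bit positions, the
  // constant part folds to an immediate at compile time and only the
  // variable part needs emitted IR.
  for (uint32_t a = 0; a < 8; ++a)
    for (uint32_t b = 0; b < 8; ++b)
      assert(layout.apply(a ^ b) == (layout.apply(a) ^ layout.apply(b)));
  return 0;
}
```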

lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp

Lines changed: 7 additions & 3 deletions
```diff
@@ -504,9 +504,13 @@ struct MemDescSubviewOpConversion
       // The order gives us the honest-to-goodness layout rank
       auto srcAllocShape =
           srcTy.getAllocShape().take_back(getOrder(srcTy).size());
-      auto llInv = toLinearLayout(srcAllocShape, srcTy.getEncoding()).invert();
-      offset =
-          applyLinearLayout(loc, rewriter, llInv, logicalOffsets)[0].second;
+      auto ll = toLinearLayout(srcAllocShape, srcTy.getEncoding());
+      // Checked in the verifier.
+      assert(ll.getInDimSize(str_attr("block")) == 1);
+      auto kOffset = str_attr("offset");
+      ll = ll.reshapeIns({{kOffset, ll.getTotalInDimSize()}});
+      offset = applyLinearLayout(loc, rewriter, ll.invert(), logicalOffsets)[0]
+                   .second;
     }
 
     auto base = smemObj.getBase();
```
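Instead of inverting the layout as-is, the lowering now first collapses all input dimensions into the single `offset` dimension with `reshapeIns`, so the inverted layout maps logical tensor offsets directly to one linear shared-memory offset. The `assert` on the `block` input dimension is backed by the new check in `MemDescSubviewOp::verify()` (see `lib/Dialect/TritonGPU/IR/Ops.cpp` below).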

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 35 additions & 29 deletions
```diff
@@ -23,7 +23,6 @@
 #include "triton/Tools/LayoutUtils.h"
 #include "triton/Tools/LinearLayout.h"
 #include "triton/Tools/StrUtil.h"
-#include "triton/Tools/Sys/GetEnv.hpp"
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/MathExtras.h"
@@ -428,6 +427,15 @@ getDefaultBlockedEncoding(MLIRContext *context, ArrayRef<int64_t> shape,
   return encoding;
 }
 
+bool isSplitCompatible(MLIRContext *ctx, const LinearLayout &ll) {
+  auto lastDim = ll.getNumOutDims() - 1;
+  auto kReg = StringAttr::get(ctx, "register");
+  auto kLastDim = StringAttr::get(ctx, "dim" + std::to_string(lastDim));
+  auto sublayout =
+      ll.sublayout({kReg}, {kLastDim}).removeZeroBasesAlongDim(kReg);
+  return sublayout == LinearLayout::identity1D(2, kReg, kLastDim);
+}
+
 LogicalResult tryJoinOnAxis(MLIRContext *ctx, const LinearLayout &inLl,
                             LinearLayout &outLl, bool fwdInference, int axis,
                             std::optional<Location> loc) {
```
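Reading the new helper: it restricts the layout to the register-to-last-dimension sublayout, drops register bases that contribute nothing to it, and accepts only when the remainder is exactly `identity1D(2, ...)`, i.e. when a single register basis toggles the size-2 last dimension. That appears to be precisely the shape a `join` along the last axis produces, which the new `inferDefaultJoinOpEncoding` fast path below relies on.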
```diff
@@ -1331,8 +1339,7 @@ Attribute AMDMfmaEncodingAttr::parse(AsmParser &parser, Type type) {
   if (parser.parseGreater().failed())
     return {};
 
-  unsigned versionMajor = 0;
-  unsigned versionMinor = 0;
+  unsigned version = 0;
   SmallVector<unsigned> warpsPerCTA;
   SmallVector<unsigned> instrShape;
   bool isTransposed;
@@ -1341,12 +1348,8 @@ Attribute AMDMfmaEncodingAttr::parse(AsmParser &parser, Type type) {
   std::optional<SmallVector<unsigned>> CTAOrder;
 
   for (const NamedAttribute &attr : dict) {
-    if (attr.getName() == "versionMajor") {
-      if (parseUInt(parser, attr, versionMajor, "versionMajor").failed())
-        return {};
-    }
-    if (attr.getName() == "versionMinor") {
-      if (parseUInt(parser, attr, versionMinor, "versionMinor").failed())
+    if (attr.getName() == "version") {
+      if (parseUInt(parser, attr, version, "verison").failed())
         return {};
     }
     if (attr.getName() == "warpsPerCTA") {
@@ -1385,14 +1388,13 @@ Attribute AMDMfmaEncodingAttr::parse(AsmParser &parser, Type type) {
     return {};
 
   return parser.getChecked<AMDMfmaEncodingAttr>(
-      parser.getContext(), versionMajor, versionMinor, warpsPerCTA,
-      instrShape[0], instrShape[1], isTransposed, *CTALayout);
+      parser.getContext(), version, warpsPerCTA, instrShape[0], instrShape[1],
+      isTransposed, *CTALayout);
 }
 
 void AMDMfmaEncodingAttr::print(AsmPrinter &printer) const {
   printer << "<{"
-          << "versionMajor = " << getVersionMajor()   //
-          << ", versionMinor = " << getVersionMinor() //
+          << "version = " << getVersion() //
           << ", warpsPerCTA = [" << getWarpsPerCTA() << "]" //
           << ", instrShape = [" << ArrayRef{getMDim(), getNDim()} << "]" //
           << ", isTransposed = " << getIsTransposed();
@@ -1401,17 +1403,12 @@ void AMDMfmaEncodingAttr::print(AsmPrinter &printer) const {
   printer << "}>";
 }
 
-LogicalResult
-AMDMfmaEncodingAttr::verify(function_ref<mlir::InFlightDiagnostic()> emitError,
-                            unsigned versionMajor, unsigned versionMinor,
-                            llvm::ArrayRef<unsigned int> warpsPerCTA,
-                            unsigned mDim, unsigned nDim, bool isTransposed,
-                            mlir::triton::gpu::CTALayoutAttr) {
-  if (!(versionMajor >= 0 && versionMajor <= 4)) {
-    return emitError() << "major version must be in the [0, 4] range";
-  }
-  if (versionMinor != 0) {
-    return emitError() << "minor version must be 0";
+LogicalResult AMDMfmaEncodingAttr::verify(
+    function_ref<mlir::InFlightDiagnostic()> emitError, unsigned version,
+    llvm::ArrayRef<unsigned int> warpsPerCTA, unsigned mDim, unsigned nDim,
+    bool isTransposed, mlir::triton::gpu::CTALayoutAttr) {
+  if (!(version >= 0 && version <= 4)) {
+    return emitError() << "version must be in the [0, 4] range";
   }
   if (!((mDim == 32 && nDim == 32) || (mDim == 16 && nDim == 16))) {
     return emitError()
@@ -1965,7 +1962,7 @@ SwizzledSharedEncodingAttr AMDMfmaEncodingAttr::composeSharedLayoutForOperand(
   bool isKContig = sharedOrder[0] == kDimIndex;
   // GFX950 supports LDS transpose load instructions, so we need swizzling even
   // when K dimension is not the contiguous dimension.
-  bool isGFX950 = getVersionMajor() == 4;
+  bool isGFX950 = getVersion() == 4;
   bool swizzleNonKContig =
       isGFX950 && (elemBitWidth == 8 || elemBitWidth == 16);
 
@@ -2654,7 +2651,19 @@ struct TritonGPUInferLayoutInterface
   inferDefaultJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc,
                              ArrayRef<int64_t> shape,
                              std::optional<Location> loc) const override {
-    if (auto enc = mlir::dyn_cast<BlockedEncodingAttr>(srcEnc)) {
+    auto ctx = getContext();
+    if (auto enc = mlir::dyn_cast<SliceEncodingAttr>(srcEnc);
+        enc && enc.getDim() == shape.size()) {
+      SmallVector<int64_t> joinedShape(shape);
+      joinedShape.push_back(2);
+      auto parent = enc.getParent();
+      auto parentLL = toLinearLayout(joinedShape, parent);
+
+      if (isSplitCompatible(ctx, parentLL)) {
+        dstEnc = parent;
+        return success();
+      }
+    } else if (auto enc = mlir::dyn_cast<BlockedEncodingAttr>(srcEnc)) {
       // JoinOp takes two tensors of shape AxBxC and generates a tensor of shape
       // AxBxCx2. The encoding is the same as the input, but with 2 elems per
       // thread in the new dimension. The new dimension is the fastest running
@@ -2679,8 +2688,6 @@ struct TritonGPUInferLayoutInterface
       return success();
     }
 
-    auto ctx = getContext();
-
     // Append dim to shape
     auto ll = toLinearLayout(shape, srcEnc);
     SmallVector<int64_t> dstShape(shape.begin(), shape.end());
@@ -2757,7 +2764,6 @@ struct TritonGPUInferLayoutInterface
     if (!result.succeeded()) {
      return failure();
     }
-
     // Remove last dim from newLl (which should be 1)
     SmallVector<int64_t> dstShape(shape.begin(), shape.end());
     dstShape.pop_back();
```
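Net effect in this file: `version` replaces the `versionMajor`/`versionMinor` pair in the parser, printer, verifier, and the GFX950 query, and `inferDefaultJoinOpEncoding` gains a fast path: when the source encoding is a `SliceEncodingAttr` sliced along what becomes the join axis and its parent layout is split-compatible, the join simply restores the parent encoding instead of building a new linear layout.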

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 1 addition & 1 deletion
```diff
@@ -1529,7 +1529,7 @@ chooseMfmaLikeStoreLayout(RankedTensorType valType) {
 
   Type elemType = valType.getElementType();
   if (!(valType.getRank() == 2 && (elemType.isF16() || elemType.isBF16()) &&
-        mfmaLayout.getVersionMajor() == 4 && mfmaLayout.getIsTransposed() &&
+        mfmaLayout.getVersion() == 4 && mfmaLayout.getIsTransposed() &&
        (isMfma32 || validForMfma16)))
    return {};
 
```

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 8 additions & 2 deletions
```diff
@@ -723,8 +723,14 @@ LogicalResult MemDescSubviewOp::verify() {
   auto ctx = getContext();
   // The order gives us the honest-to-goodness layout rank
   auto srcAllocShape = srcTy.getAllocShape().take_back(getOrder(srcTy).size());
-  auto llInv =
-      triton::gpu::toLinearLayout(srcAllocShape, srcTy.getEncoding()).invert();
+  auto ll = triton::gpu::toLinearLayout(srcAllocShape, srcTy.getEncoding());
+  // NYI: We don't support non-trivial block dimension for now.
+  auto kBlock = mlir::StringAttr::get(getContext(), "block");
+  if (ll.getInDimSize(kBlock) != 1) {
+    return emitError("non-trivial block dimension not supported");
+  }
+
+  auto llInv = ll.invert();
   auto kDim = mlir::StringAttr::get(ctx, "dim" + llvm::Twine(dim));
   llvm::SmallVector<std::pair<mlir::StringAttr, int32_t>> namedOffsets;
   for (auto d : standardOutDimNames(ctx, srcTy.getRank())) {
```
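This verifier check pairs with the `assert(ll.getInDimSize(str_attr("block")) == 1)` added to the subview lowering in `lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp` above: the verifier rejects multi-block memory descriptors with a proper diagnostic, so the lowering's assumption is already established before code generation runs.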

lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp

Lines changed: 6 additions & 5 deletions
```diff
@@ -126,8 +126,8 @@ tt::CoarseSchedule::splitClusterBefore(Operation *op, scf::ForOp forOp) {
 bool tt::CoarseSchedule::isOpBefore(Operation *a, Operation *b) const {
   assert(opToStageAndCluster.count(a) && opToStageAndCluster.count(b) &&
          "Operations must be in the schedule");
-  auto [aStage, aCluster] = opToStageAndCluster.at(a);
-  auto [bStage, bCluster] = opToStageAndCluster.at(b);
+  auto [aStage, aCluster] = opToStageAndCluster.lookup(a);
+  auto [bStage, bCluster] = opToStageAndCluster.lookup(b);
   if (aStage != bStage) {
     return aStage < bStage;
   }
@@ -141,14 +141,15 @@ bool tt::CoarseSchedule::isOpInEarlierCluster(Operation *a,
                                               Operation *b) const {
   assert(opToStageAndCluster.count(a) && opToStageAndCluster.count(b) &&
          "Operations must be in the schedule");
-  return clusters.isBefore(opToStageAndCluster.at(a).second,
-                           opToStageAndCluster.at(b).second);
+  return clusters.isBefore(opToStageAndCluster.lookup(a).second,
+                           opToStageAndCluster.lookup(b).second);
 }
 
 bool tt::CoarseSchedule::isOpInSameCluster(Operation *a, Operation *b) const {
   assert(opToStageAndCluster.count(a) && opToStageAndCluster.count(b) &&
          "Operations must be in the schedule");
-  return opToStageAndCluster.at(a).second == opToStageAndCluster.at(b).second;
+  return opToStageAndCluster.lookup(a).second ==
+         opToStageAndCluster.lookup(b).second;
 }
 
 SmallVector<std::tuple<Operation *, int, tt::CoarseSchedule::Cluster>>
```
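The move from `.at(...)` to `.lookup(...)` follows from the `Schedule.h` change above: `llvm::MapVector`, unlike `DenseMap`, does not provide `at()`, and its `lookup()` returns the mapped value by copy (default-constructed when the key is absent). The preceding `count()` asserts make the absent case unreachable here, so behavior is unchanged.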
