
Commit c412ccc

Merge OpenAI Triton commit 31baa6d (#5048)
This PR changes the Triton base from d412906 to 31baa6d (Sep 1). Pass rate: 98.74%

Signed-off-by: Anatoly Myachev <[email protected]>
2 parents 7ec882d + fe4639a commit c412ccc

75 files changed: +2064 additions, −822 deletions


.github/workflows/build-macos.yml

Lines changed: 7 additions & 0 deletions
@@ -106,6 +106,13 @@ jobs:
           source ~/.venv/bin/activate
           echo "PATH is '$PATH'"
           ccache --zero-stats
+          export PATH="/opt/homebrew/opt/llvm@19/bin:$PATH"
+          export CC="/opt/homebrew/opt/llvm@19/bin/clang"
+          export CXX="/opt/homebrew/opt/llvm@19/bin/clang++"
+          export CXXFLAGS="-stdlib=libc++"
+          export LDFLAGS="-L/opt/homebrew/opt/llvm@19/lib"
+          which clang++
+          clang++ --version
           make dev-install
       - name: CCache Stats
         run: ccache --print-stats

include/triton/Dialect/Triton/IR/TritonOpInterfaces.td

Lines changed: 7 additions & 1 deletion
@@ -49,7 +49,12 @@ def DotOpInterface : OpInterface<"DotOpInterface"> {
         /*retType=*/"::mlir::Value",
         /*methodName=*/"getB",
         /*args=*/(ins)>,
-    InterfaceMethod<
+    InterfaceMethod<
+        /*desc=*/"Get the output tensor",
+        /*retType=*/"::mlir::Value",
+        /*methodName=*/"getD",
+        /*args=*/(ins)>,
+    InterfaceMethod<
         /*desc=*/"Verify the dimensions of the A and B DotOp operands.",
         /*retType=*/"bool",
         /*methodName=*/"verifyDims",
@@ -64,6 +69,7 @@ def DotOpInterface : OpInterface<"DotOpInterface"> {
     auto aTy = cast<ShapedType>($_op.getA().getType());
     auto bTy = cast<ShapedType>($_op.getB().getType());
     auto cTy = cast<ShapedType>($_op->getOperand(2).getType());
+    auto dTy = cast<ShapedType>($_op.getD().getType());
     auto aShape = aTy.getShape();
     auto bShape = bTy.getShape();
     auto cShape = cTy.getShape();
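Note: a minimal usage sketch of the new interface method; the surrounding variable names are illustrative and not part of this commit.

// Hypothetical consumer: any op implementing DotOpInterface now also exposes
// its output tensor through getD(), alongside getA()/getB().
if (auto dotIface = llvm::dyn_cast<mlir::triton::DotOpInterface>(op)) {
  mlir::Value d = dotIface.getD();
  auto dTy = llvm::cast<mlir::ShapedType>(d.getType());
  // The default verifyDims body above can now inspect dTy as well as cTy.
}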

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 4 additions & 2 deletions
@@ -1051,15 +1051,17 @@ def TT_MakeTensorDescOp : TT_Op<"make_tensor_descriptor", [
   let arguments = (ins
     TT_Ptr:$base,
     Variadic<I32>:$shape,
-    Variadic<I64>:$strides
+    Variadic<I64>:$strides,
+    DefaultValuedAttr<TT_PaddingOptionAttr, "::mlir::triton::PaddingOption::PAD_ZERO">:$padding
   );

   let results = (outs TT_TensorDescType:$result);

   let assemblyFormat = "$base `,` `[` $shape `]` `,` `[` $strides `]` attr-dict `:` type($base) `,` type($result)";

   let builders = [
-    OpBuilder<(ins "Value":$base, "ValueRange":$shape, "ValueRange":$strides, "ArrayRef<int32_t>":$blockShape, "bool":$isSignedInteger)>
+    OpBuilder<(ins "Value":$base, "ValueRange":$shape, "ValueRange":$strides, "ArrayRef<int32_t>":$blockShape, "bool":$isSignedInteger,
+               "triton::PaddingOption":$padding)>
   ];

   let extraClassDeclaration = [{
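Note: a small illustrative call of the extended builder; basePtr, shapeVals, and strideVals are hypothetical values built earlier, not identifiers from this commit.

// Sketch only: creating tt.make_tensor_descriptor with the new padding argument.
SmallVector<int32_t> blockShape{64, 64};
auto desc = builder.create<triton::MakeTensorDescOp>(
    loc, basePtr, shapeVals, strideVals, blockShape,
    /*isSignedInteger=*/false, triton::PaddingOption::PAD_NAN);
// The attribute itself defaults to PAD_ZERO (see DefaultValuedAttr above), so
// descriptors built without an explicit padding keep the previous behaviour.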

include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 8 additions & 0 deletions
@@ -19,6 +19,7 @@ class AMDRotatingSharedEncodingAttr;
 class AMDMfmaEncodingAttr;
 class TensorOrMemDesc;
 class MemDescType;
+class CTALayoutAttr;

 // - BlockedEncodingAttrs have the following input dimensions.
 //
@@ -126,6 +127,13 @@ LinearLayout chooseScaledMfmaScaleLayout(MLIRContext *ctx, int dotOperandIdx,
                                          ArrayRef<unsigned> tilesPerWarp,
                                          ArrayRef<unsigned> warpsPerCTA);

+LinearLayout getSM120DotScaledScaleLayout(MLIRContext *ctx, int dotOperandIdx,
+                                          ArrayRef<int64_t> dotOperandShape,
+                                          ArrayRef<unsigned> tilesPerWarp,
+                                          ArrayRef<unsigned> warpsPerCTA,
+                                          unsigned instrM, unsigned instrN,
+                                          CTALayoutAttr ctaLayoutAttr);
+
 // Create LinearLayout for nvidia mma tile.
 LinearLayout nvidiaMmaTile(MLIRContext *ctx, ArrayRef<unsigned> tileShape,
                            unsigned kWidth, ArrayRef<unsigned> order,
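Note: a purely hypothetical call site matching the declared signature; all argument values below are made up for illustration, the real callers live in the NVIDIA backend lowering.

// Sketch: scale layout for operand A (dotOperandIdx == 0) of an SM120 scaled
// dot, assuming a 16x8 instruction tile and 4x1 warps per CTA.
LinearLayout scaleLayout = getSM120DotScaledScaleLayout(
    ctx, /*dotOperandIdx=*/0, /*dotOperandShape=*/{128, 64},
    /*tilesPerWarp=*/{1, 1}, /*warpsPerCTA=*/{4, 1},
    /*instrM=*/16, /*instrN=*/8, ctaLayoutAttr);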

lib/Analysis/Utility.cpp

Lines changed: 7 additions & 2 deletions
@@ -649,6 +649,8 @@ bool supportMMA(triton::DotOp op, int version) {
   if (version == 5) {
     if (triton::tools::getBoolEnv("DISABLE_MMA_V5"))
       return false;
+    RankedTensorType typeA = op.getA().getType();
+    int k = typeA.getShape().back();
     auto retType = op.getType();
     auto retShapePerCTA = getShapePerCTA(retType);
     auto rank = retShapePerCTA.size();
@@ -662,8 +664,11 @@ bool supportMMA(triton::DotOp op, int version) {
       // Currently only support numWarps 4 or 8 for TMEM load and store.
       return false;
     }
+    // If k size is smaller than the native mma size, we cannot use MMA.
+    if (k < 256 / aElemTy.getIntOrFloatBitWidth())
+      return false;
     if (!(retShapePerCTA[rank - 2] % 64 == 0 &&
-          retShapePerCTA[rank - 1] % 8 == 0))
+          retShapePerCTA[rank - 1] % 16 == 0))
       return false;
     return true;
   }
@@ -683,7 +688,7 @@ bool supportMMA(triton::DotOp op, int version) {
     if (rank == 3)
       return false;
     if (!(numWarps % 4 == 0 && retShapePerCTA[rank - 2] % 64 == 0 &&
-          retShapePerCTA[rank - 1] % 8 == 0 &&
+          retShapePerCTA[rank - 1] % 16 == 0 &&
          (llvm::isa<Float8E5M2Type, Float8E4M3FNType>(aElemTy) ||
           aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() ||
           aElemTy.isF32()))) {
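Note: the new K guard reads as 256 / elementBitWidth; the worked values below follow directly from that expression.

// Minimum K accepted by the v5 MMA path per operand element type:
//   fp32      (32-bit): k >= 256 / 32 = 8
//   fp16/bf16 (16-bit): k >= 256 / 16 = 16
//   fp8        (8-bit): k >= 256 / 8  = 32
// e.g. a dot with fp8 operands and K == 16 now makes supportMMA return false
// and the lowering falls back to a non-MMA path.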

lib/Conversion/TritonGPUToLLVM/HistogramOpToLLVM.cpp

Lines changed: 18 additions & 20 deletions
@@ -25,7 +25,6 @@ static SmallVector<Value> computeWarpLevelHistogram(
   int numBits = llvm::Log2_64(numBins);
   int numBitsLaneId = llvm::Log2_64(numThreadPerWarp);
   unsigned numElementsPerThreads = getTotalElemsPerThread(srcType);
-  unsigned numThreadWithUniqueData = getThreadsPerWarp(srcType)[0];
   // The histogram is distributed across threads, each thread owns `numBins /
   // numThreadPerWarp` bins.
   SmallVector<Value> warpLevelHistogram(numBins / numThreadPerWarp, zero);
@@ -43,10 +42,6 @@ static SmallVector<Value> computeWarpLevelHistogram(
       numThreadPerWarp == 32 ? 0xFFFFFFFF : 0xFFFFFFFFFFFFFFFF;
   Value fullMask = b.int_val(numThreadPerWarp, fullMaskValue);
   Value mask = fullMask;
-  // If not all threads have unique data, mask out the redundant ones.
-  if (numThreadWithUniqueData < numThreadPerWarp) {
-    mask = b.int_val(numThreadPerWarp, (1ULL << numThreadWithUniqueData) - 1);
-  }
   for (int i = 0; i < numBitsLaneId; i++) {
     Value updateMask =
         b.select(b.icmp_ne(b.and_(threadId, b.i32_val(1 << i)), zero),
@@ -96,8 +91,6 @@ static SmallVector<Value> computeCrossWarpHistogram(
     Value threadId, int numWarps) {
   auto b = TritonLLVMOpBuilder(loc, rewriter);
   SmallVector<Value> histogramValues;
-  unsigned numWarpsWithUniqueData = mlir::triton::gpu::getWarpsPerCTA(
-      srcType.getEncoding(), srcType.getShape())[0];
   Value laneId = b.and_(threadId, b.i32_val(numThreadPerWarp - 1));
   // Initialize the shared memory with zeros.
   int64_t numElementPerThread =
@@ -112,19 +105,6 @@ static SmallVector<Value> computeCrossWarpHistogram(
   }
   b.barrier();
   Block *afterAtomics = nullptr;
-  // If some warps have replicated data we need to skip those warps when
-  // accumulating.
-  if (numWarpsWithUniqueData < numWarps) {
-    Block *currentBlock = rewriter.getInsertionBlock();
-    afterAtomics =
-        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
-    Block *atomicBlock = rewriter.createBlock(afterAtomics);
-    rewriter.setInsertionPointToEnd(currentBlock);
-    Value cond = b.icmp_ult(
-        threadId, b.i32_val(numWarpsWithUniqueData * numThreadPerWarp));
-    rewriter.create<LLVM::CondBrOp>(loc, cond, atomicBlock, afterAtomics);
-    rewriter.setInsertionPointToStart(atomicBlock);
-  }
   // Apply atomic add to update the histogram in shared memory.
   for (int i = 0; i < warpLevelHistogram.size(); ++i) {
     Value warpLevelHistogramValue = warpLevelHistogram[i];
@@ -209,6 +189,24 @@ struct HistogramOpConversion
         loc, rewriter, srcType, baseSharedMemPtr, warpLevelHistogram, numBins,
         numThreadsPerWarp, innerDimIndices, threadId, numWarps);

+    // Depending on the layout, some threads may have duplicate data. We can
+    // account for this by calculating a "replication factor" and dividing the
+    // results by it to avoid overcounting.
+    auto replicationFactor = numWarps * numThreadsPerWarp;
+    auto threadsPerWarp = getThreadsPerWarp(srcType);
+    auto warpsPerCTA =
+        getWarpsPerCTA(srcType.getEncoding(), srcType.getShape());
+    replicationFactor /= std::accumulate(
+        threadsPerWarp.begin(), threadsPerWarp.end(), 1, std::multiplies<>());
+    replicationFactor /= std::accumulate(warpsPerCTA.begin(), warpsPerCTA.end(),
+                                         1, std::multiplies<>());
+
+    auto b = TritonLLVMOpBuilder(loc, rewriter);
+    for (auto i = 0; i < histogramValue.size(); ++i) {
+      histogramValue[i] =
+          b.sdiv(histogramValue[i], b.i32_val(replicationFactor));
+    }
+
     Value results = packLLElements(loc, typeConverter, histogramValue, rewriter,
                                    op.getType());
     rewriter.replaceOp(op, results);
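Note: to see how the divisor behaves, a small worked example with illustrative numbers (not taken from this commit).

// Suppose the kernel runs 4 warps of 32 threads, but the source layout only
// places unique elements on 2 warps and 16 lanes per warp.
int numWarps = 4, numThreadsPerWarp = 32;
int uniqueLanes = 16, uniqueWarps = 2; // products over threadsPerWarp / warpsPerCTA
int replicationFactor =
    (numWarps * numThreadsPerWarp) / (uniqueLanes * uniqueWarps); // 128 / 32 = 4
// Every bin was atomically accumulated 4 times, so each histogram value is
// divided by 4 (the b.sdiv above) instead of branching around redundant warps.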

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 4 additions & 3 deletions
@@ -1019,8 +1019,8 @@ OpFoldResult AdvanceOp::fold(FoldAdaptor adaptor) {
 //-- MakeTensorDescOp --
 void MakeTensorDescOp::build(OpBuilder &builder, OperationState &state,
                              Value base, ValueRange shape, ValueRange strides,
-                             ArrayRef<int32_t> blockShape,
-                             bool isSignedInteger) {
+                             ArrayRef<int32_t> blockShape, bool isSignedInteger,
+                             triton::PaddingOption padding) {
   auto ptrTy = dyn_cast<triton::PointerType>(base.getType());
   if (!ptrTy) {
     llvm::report_fatal_error("Expected pointer type");
@@ -1030,7 +1030,8 @@ void MakeTensorDescOp::build(OpBuilder &builder, OperationState &state,
   auto blockTy = RankedTensorType::get(blockShape64, elemTy);
   auto descTy =
       TensorDescType::get(builder.getContext(), blockTy, isSignedInteger);
-  return build(builder, state, descTy, base, shape, strides);
+  auto paddingAttr = PaddingOptionAttr::get(builder.getContext(), padding);
+  return build(builder, state, descTy, base, shape, strides, paddingAttr);
 }

 // The following ops, including `call`, `func`, and `return` are copied and

lib/Dialect/Triton/Transforms/RewriteTensorDescriptorToPointer.cpp

Lines changed: 42 additions & 20 deletions
@@ -59,18 +59,21 @@ struct Descriptor {
   Value base;
   ValueRange shape;
   ValueRange strides;
+  Value paddingOption;
 };

 Descriptor unpackDescriptor(TensorDescType type, ValueRange pack) {
   int rank = type.getBlockType().getRank();
-  assert(pack.size() == 1 + 2 * static_cast<size_t>(rank) &&
+  assert(pack.size() == 1 + 2 * static_cast<size_t>(rank) + 1 &&
          "Expected tensor descriptors to consist of a pointer, "
-         "followed by 'rank' shape values and 'rank' stride values.");
+         "followed by 'rank' shape values and 'rank' stride values, "
+         "followed by a padding option value.");

   Descriptor res;
   res.base = pack[0];
   res.shape = pack.slice(1, rank);
   res.strides = pack.slice(1 + rank, rank);
+  res.paddingOption = pack[1 + 2 * rank];
   return res;
 }

@@ -211,16 +214,30 @@ Value generateMask(OpBuilder &builder, const Location &loc,
 }

 Value generateOther(OpBuilder &builder, Location loc, Type scalarTy,
-                    ArrayRef<int64_t> blockShape) {
+                    ArrayRef<int64_t> blockShape,
+                    Value paddingOption = nullptr) {
   auto blockTy = RankedTensorType::get(blockShape, scalarTy);
-  auto attr = builder.getZeroAttr(blockTy);
-  return builder.create<arith::ConstantOp>(loc, attr);
+  if (paddingOption && mlir::isa<FloatType>(scalarTy)) {
+    auto floatTy = mlir::cast<FloatType>(scalarTy);
+    auto nan = llvm::APFloat::getNaN(floatTy.getFloatSemantics());
+    auto nanValue = builder.create<arith::ConstantOp>(
+        loc,
+        SplatElementsAttr::get(blockTy, builder.getFloatAttr(floatTy, nan)));
+    auto zeroValue = builder.create<arith::ConstantOp>(
+        loc, SplatElementsAttr::get(blockTy, builder.getZeroAttr(floatTy)));
+    return builder.create<mlir::arith::SelectOp>(loc, paddingOption, nanValue,
+                                                 zeroValue);
+  } else {
+    auto attr = builder.getZeroAttr(blockTy);
+    return builder.create<arith::ConstantOp>(loc, attr);
+  }
 }

-Value generateOther(OpBuilder &builder, Location loc, TensorDescType descTy) {
+Value generateOther(OpBuilder &builder, Location loc, TensorDescType descTy,
+                    Value paddingOption = nullptr) {
   auto blockTy = descTy.getSignlessBlockType();
   return generateOther(builder, loc, blockTy.getElementType(),
-                       blockTy.getShape());
+                       blockTy.getShape(), paddingOption);
 }

 SmallVector<mlir::Value> castToI64(OpBuilder &builder,
@@ -237,12 +254,17 @@ struct RewriteMakeTensorDesc : OpConversionPattern<triton::MakeTensorDescOp> {
   llvm::LogicalResult
   matchAndRewrite(triton::MakeTensorDescOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    SmallVector<mlir::Value> ptrShapeStrides;
-    llvm::append_values(ptrShapeStrides, adaptor.getBase());
-    llvm::append_range(ptrShapeStrides,
+    SmallVector<mlir::Value> ptrShapeStridesPaddingOption;
+    llvm::append_values(ptrShapeStridesPaddingOption, adaptor.getBase());
+    llvm::append_range(ptrShapeStridesPaddingOption,
                        castToI64(rewriter, adaptor.getShape()));
-    llvm::append_range(ptrShapeStrides, adaptor.getStrides());
-    rewriter.replaceOpWithMultiple(op, {ptrShapeStrides});
+    llvm::append_range(ptrShapeStridesPaddingOption, adaptor.getStrides());
+    auto paddingOption = rewriter.create<mlir::arith::ConstantOp>(
+        op.getLoc(), rewriter.getI1Type(),
+        rewriter.getBoolAttr(adaptor.getPadding() ==
+                             triton::PaddingOption::PAD_NAN));
+    llvm::append_values(ptrShapeStridesPaddingOption, paddingOption);
+    rewriter.replaceOpWithMultiple(op, {ptrShapeStridesPaddingOption});
     return mlir::success();
   }
 };
@@ -258,12 +280,11 @@ struct RewriteLoadPattern : OpConversionPattern<triton::DescriptorLoadOp> {
     auto descTy = op.getDesc().getType();
     auto desc = unpackDescriptor(descTy, adaptor.getDesc());
     auto offsets = castToI64(rewriter, op.getIndices());
-
+    auto other = generateOther(rewriter, loc, descTy, desc.paddingOption);
     auto newLoad = rewriter.replaceOpWithNewOp<triton::LoadOp>(
         op, generatePtr(rewriter, loc, blockShape, desc, offsets),
-        generateMask(rewriter, loc, blockShape, desc, offsets),
-        generateOther(rewriter, loc, descTy), triton::CacheModifier::NONE,
-        triton::EvictionPolicy::NORMAL, false);
+        generateMask(rewriter, loc, blockShape, desc, offsets), other,
+        triton::CacheModifier::NONE, triton::EvictionPolicy::NORMAL, false);
     newLoad->setAttrs(filterSegmentSizes(op->getAttrs()));

     return llvm::success();
@@ -327,7 +348,7 @@ struct RewriteGatherPattern : OpConversionPattern<triton::DescriptorGatherOp> {
         rewriter, loc, blockShape, desc, op.getXOffsets(), op.getYOffset());
     auto other = generateOther(rewriter, loc,
                                descTy.getSignlessBlockType().getElementType(),
-                               blockShape);
+                               blockShape, desc.paddingOption);
     auto newLoad = rewriter.replaceOpWithNewOp<triton::LoadOp>(
         op, ptr, mask, other, triton::CacheModifier::NONE,
         triton::EvictionPolicy::NORMAL, false);
@@ -471,13 +492,14 @@ class TritonRewriteTensorDescriptorToPointerPass
     converter.addConversion([](mlir::triton::TensorDescType t,
                                llvm::SmallVectorImpl<mlir::Type> &out) {
       // We convert a tensor descriptor into an pointer, and a shape and stride
-      // for each dimension, i.e., we create 1+2*rank values. Note that tensor
-      // descriptors may be signed/unsigned integers whereas pointers should
-      // always be signless.
+      // for each dimension, and padding option. i.e., we create 1+2*rank+1
+      // values. Note that tensor descriptors may be signed/unsigned integers
+      // whereas pointers should always be signless.
       auto tensorType = t.getSignlessBlockType();
       out.push_back(triton::getPointerType(tensorType.getElementType()));
       out.insert(out.end(), 2 * tensorType.getRank(),
                  mlir::IntegerType::get(t.getContext(), 64));
+      out.push_back(mlir::IntegerType::get(t.getContext(), 1));
       return mlir::success();
     });
lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 2 additions & 8 deletions
@@ -1339,12 +1339,6 @@ AMDWmmaEncodingAttr::verify(function_ref<mlir::InFlightDiagnostic()> emitError,
   if (version != 1 && version != 2) {
     return emitError() << "WMMA version must be in the [1, 2] range";
   }
-  // Transposed layout is needed for bypassing LDS between multiple dots.
-  // Version 1 tt.dot results and tt.dot operand layouts are different,
-  // therefore we test and support transposed only for version 2.
-  if (version != 2 && isTransposed) {
-    return emitError() << "Transposed WMMA is supported only for version 2";
-  }
   return success();
 }

@@ -2125,10 +2119,10 @@ LogicalResult DotOperandEncodingAttr::verify(
   }

   if (auto parentAttr = mlir::dyn_cast<AMDWmmaEncodingAttr>(parent)) {
-    if (kWidth != 16 && parentAttr.getVersion() == 1 ||
+    if (kWidth != 8 && kWidth != 16 && parentAttr.getVersion() == 1 ||
         kWidth != 4 && kWidth != 8 && kWidth != 16 &&
             parentAttr.getVersion() == 2)
-      return emitError() << "ttg.dot_op kWidth parameter must be 16 for "
+      return emitError() << "ttg.dot_op kWidth parameter must be 8/16 for "
                             "gfx11 and 4/8/16 for gfx12 (including packed "
                             "cases for `scaled_dot`)";
     return success();
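Note: a short summary of what the relaxed verifier accepts, restating the diff above.

// gfx11 (WMMA v1): kWidth must be 8 or 16 (previously only 16).
// gfx12 (WMMA v2): kWidth must be 4, 8, or 16 (unchanged).
// The "transposed WMMA only for version 2" check has been dropped from
// AMDWmmaEncodingAttr::verify, so transposed v1 layouts no longer error out.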
