
Commit 1ddca06

Merge OpenAI Triton commit 0a8e3cc (#3565)
This PR changes the Triton base from 63cecbd to 0a8e3cc (Feb 24). Pass rate: 97.65% -> 89.74% (#3307). Please do not squash and merge this PR.
2 parents 98909c3 + 61625fc commit 1ddca06

36 files changed: +2837 −931 lines changed

bin/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
@@ -14,6 +14,7 @@ target_link_libraries(triton-opt PRIVATE
   # tests
   TritonTestAnalysis
   TritonTestDialectTritonGPU
+  TritonAMDGPUTestAnalysis
   # MLIR core
   MLIROptLib
   MLIRPass
@@ -33,6 +34,7 @@ target_link_libraries(triton-reduce PRIVATE
   # tests
   TritonTestAnalysis
   TritonTestDialectTritonGPU
+  TritonAMDGPUTestAnalysis
   # MLIR core
   MLIRReduceLib
   MLIRPass
@@ -51,6 +53,7 @@ target_link_libraries(triton-lsp PRIVATE
   # tests
   TritonTestAnalysis
   TritonTestDialectTritonGPU
+  TritonAMDGPUTestAnalysis
   # MLIR core
   MLIRLspServerLib
   MLIRPass
@@ -89,4 +92,5 @@ target_link_libraries(triton-tensor-layout PRIVATE
   ${dialect_libs}
   TritonTestAnalysis
   TritonTestDialectTritonGPU
+  TritonAMDGPUTestAnalysis
 )

bin/RegisterTritonDialects.h

Lines changed: 2 additions & 0 deletions
@@ -48,6 +48,7 @@ void registerTestAlignmentPass();
 void registerTestAllocationPass();
 void registerTestLivenessPass();
 void registerTestMembarPass();
+void registerTestTritonAMDGPURangeAnalysis();
 } // namespace test
 } // namespace mlir

@@ -62,6 +63,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
   mlir::test::registerTestAllocationPass();
   mlir::test::registerTestLivenessPass();
   mlir::test::registerTestMembarPass();
+  mlir::test::registerTestTritonAMDGPURangeAnalysis();
   mlir::triton::registerConvertTritonToTritonGPUPass();
   mlir::triton::intel::registerConvertTritonToTritonGPUWarpPass();
   mlir::triton::intel::registerTritonIntelRemoveMasks();

include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h

Lines changed: 1 addition & 1 deletion
@@ -101,7 +101,7 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
     if (!axisInfo)
       // axis info (e.g., constancy) not available
       return resultVals;
-    SmallVector<unsigned> contigPerThread = getContigPerThread(encoding);
+    SmallVector<unsigned> contigPerThread = getContigPerThread(rtType);
     if (rank != contigPerThread.size())
       return resultVals;

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 3 additions & 10 deletions
@@ -104,20 +104,13 @@ SmallVector<unsigned> getWarpsPerCTA(Attribute layout);

 SmallVector<unsigned> getSizePerThread(Attribute layout);

-// Returns the number of contiguous elements that each thread
-// has access to, on each dimension of the tensor. E.g.
-// for a blocked layout with sizePerThread = [1, 4], returns [1, 4],
-// regardless of the shape of the tensor.
-SmallVector<unsigned> getContigPerThread(Attribute layout);
-
-// Returns the number of non-replicated contiguous elements that each thread
-// has access to, on each dimension of the tensor. For a blocked layout
+// Returns the number of contiguous elements of the logical tensor that each
+// thread has access to, on each dimension of the tensor. For a blocked layout
 // with sizePerThread = [1, 4] and tensor shape = [128, 1], the elements
 // for thread 0 would be [A_{0, 0}, A_{0, 0}, A_{0, 0}, A_{0, 0}], returns [1,
 // 1]. Whereas for a tensor shape [128, 128], the elements for thread 0 would be
 // [A_{0, 0}, A_{0, 1}, A_{0, 2}, A_{0, 3}], returns [1, 4].
-SmallVector<unsigned> getUniqueContigPerThread(Attribute layout,
-                                               ArrayRef<int64_t> tensorShape);
+SmallVector<unsigned> getContigPerThread(RankedTensorType tensorType);

 // Returns the number of threads per warp that have access to non-replicated
 // elements of the tensor. E.g. for a blocked layout with sizePerThread = [1,
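With the new signature, callers pass the whole RankedTensorType rather than the bare layout attribute, so the tensor shape is available and replicated elements are not over-counted; the behaviour matches the old getUniqueContigPerThread. A minimal caller-side sketch follows; the value name `src` and the surrounding pattern are illustrative assumptions, not a call site from this commit:

// Hedged sketch of an updated call site; `src` is assumed to be an
// mlir::Value whose type is a ranked tensor with a TritonGPU layout.
auto rtType = mlir::cast<mlir::RankedTensorType>(src.getType());
// New API: shape-aware, replacing both getContigPerThread(Attribute) and
// getUniqueContigPerThread(Attribute, ArrayRef<int64_t>).
llvm::SmallVector<unsigned> contig =
    mlir::triton::gpu::getContigPerThread(rtType);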

include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 19 additions & 0 deletions
@@ -8,9 +8,14 @@

 #include "triton/Tools/LinearLayout.h"

+namespace mlir::triton {
+enum class ScaleDotElemType : uint32_t;
+} // namespace mlir::triton
+
 namespace mlir::triton::gpu {
 class SwizzledSharedEncodingAttr;
 class NVMMASharedEncodingAttr;
+class AMDMfmaEncodingAttr;

 // - BlockedEncodingAttrs have the following input dimensions.
 //
@@ -261,6 +266,20 @@ LinearLayout chooseLdMatrixLayout(Attribute enc, ArrayRef<int64_t> shape,
 // tensor from shared memory using the `ds_read_tr` instruction for AMD GPUs.
 LinearLayout chooseDsReadB64Tr16Layout(Attribute enc, ArrayRef<int64_t> shape,
                                        int32_t elemBitWidth);
+
+// Create LinearLayout for mxfp4 and mxfp8 operand in scaled mfma.
+// For mxfp4, we use dot layout directly. Mxfp8 is not covered by dot
+// layout, so we need to manually create linear layout for it.
+LinearLayout
+chooseScaledMfmaOperandLayout(AMDMfmaEncodingAttr mfmaEnc, int kWidth,
+                              int dotOperandIdx, ScaleDotElemType elemType,
+                              llvm::ArrayRef<int64_t> dotOperandShape);
+
+// Create LinearLayout for scale in scaled mfma.
+LinearLayout chooseScaledMfmaScaleLayout(
+    MLIRContext *ctx, int dotOperandIdx,
+    const std::vector<std::vector<int32_t>> &dotOperandWarpBasis,
+    ArrayRef<int64_t> dotOperandShape, unsigned mfmaMDim);
 } // namespace mlir::triton::gpu

 #endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H
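As a rough illustration of how the first new helper might be invoked for an mxfp8 A operand, here is a hedged sketch; `mfmaEnc` (an AMDMfmaEncodingAttr taken from the scaled-dot result layout), the kWidth of 8, the E4M3 element type, and the 128x256 operand shape are all assumptions for the example, not values taken from this commit:

// Hypothetical call: A operand (dotOperandIdx == 0) of a scaled mfma.
// mfmaEnc is assumed to have been obtained from the dot result encoding.
mlir::triton::LinearLayout operandLL =
    mlir::triton::gpu::chooseScaledMfmaOperandLayout(
        mfmaEnc, /*kWidth=*/8, /*dotOperandIdx=*/0,
        mlir::triton::ScaleDotElemType::E4M3, /*dotOperandShape=*/{128, 256});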

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 2 additions & 61 deletions
@@ -532,10 +532,6 @@ We call each individual tile "rep".
     InterfaceMethod<"Get the shape of the values per thread.",
                     "SmallVector<unsigned>",
                     "getSizePerThread">,
-
-    InterfaceMethod<"Gets the number of contiguous elements per thread.",
-                    "SmallVector<unsigned>",
-                    "getContigPerThread">,
     InterfaceMethod<"Convert to LinearLayout.",
                     "LinearLayout",
                     "toLinearLayout",
@@ -819,12 +815,7 @@ for
     }]>
   ];

-  let extraClassDeclaration = extraDistributedDeclaration # [{
-    SmallVector<unsigned> getContigPerThread() {
-      // Block encoding is dense stride layout. The elements per thread are contiguous.
-      return getSizePerThread();
-    };
-  }];
+  let extraClassDeclaration = extraDistributedDeclaration;

   let hasCustomAssemblyFormat = 1;
 }
@@ -972,17 +963,6 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,
     SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape, int kWidth, int opIdx) const;
     SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;
     SmallVector<unsigned> getThreadsPerWarpForOperand(int opIdx) const;
-
-    SmallVector<unsigned> getContigPerThread() {
-      auto rank = getWarpsPerCTA().size();
-      SmallVector<unsigned> contigPerThread(rank, 1);
-      if (getIsTransposed())
-        contigPerThread[rank - 1] = 4;
-      else
-        contigPerThread[rank - 2] = 4;
-      return contigPerThread;
-    };
-
   }];

   let genVerifyDecl = 1;
@@ -1100,16 +1080,6 @@ Row |
     SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;
     SmallVector<unsigned> getThreadsPerWarpForOperand(int opIdx) const;
     static SmallVector<unsigned> getMNKDimPerInstr();
-
-    SmallVector<unsigned> getContigPerThread() {
-      auto rank = getWarpsPerCTA().size();
-      assert(rank == 2 || rank == 3);
-      SmallVector<unsigned> contigPerThread(rank, 1);
-      if (getVersion() == 2) {
-        contigPerThread[rank - 2] = 8;
-      }
-      return contigPerThread;
-    };
   }];
 }

@@ -1219,15 +1189,6 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
     SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;
     SmallVector<unsigned> getThreadsPerWarpForOperand(int opIdx) const;
     SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
-
-    SmallVector<unsigned> getContigPerThread() {
-      assert(isAmpere() || isHopper());
-      auto rank = getWarpsPerCTA().size();
-      SmallVector<unsigned> contigPerThread(rank, 1);
-      contigPerThread[rank - 1] = 2;
-      return contigPerThread;
-    };
-
   }];

   let hasCustomAssemblyFormat = 1;
@@ -1274,13 +1235,6 @@ def SliceEncodingAttr : DistributedEncoding<"SliceEncoding", "slice_encoding"> {
   let extraClassDeclaration = extraDistributedDeclaration # [{
     template<class T>
     SmallVector<T> paddedShape(ArrayRef<T> shape) const;
-
-    SmallVector<unsigned> getContigPerThread() {
-      auto parentLayout = mlir::cast<DistributedEncodingTrait>(getParent());
-      auto parentContigPerThread = parentLayout.getContigPerThread();
-      parentContigPerThread.erase(parentContigPerThread.begin() + getDim());
-      return parentContigPerThread;
-    };
   }];

   let hasCustomAssemblyFormat = 1;
@@ -1348,20 +1302,7 @@ vecIdx (index of the element in the quad; this is always along the k-dim)

   let assemblyFormat = "`<` `{` struct(params) `}` `>`";
   let genVerifyDecl = 1;
-  let extraClassDeclaration = extraDistributedDeclaration # [{
-    SmallVector<unsigned> getContigPerThread() {
-      auto rank = getWarpsPerCTA().size();
-      assert(rank == 2 || rank == 3);
-      SmallVector<unsigned> contigPerThread(rank, 1);
-      auto kWidth = getKWidth();
-      assert(kWidth != 0 && "Do not support kWidth=0");
-      if (getOpIdx() == 0)
-        contigPerThread[rank - 1] = kWidth;
-      else
-        contigPerThread[rank - 2] = kWidth;
-      return contigPerThread;
-    };
-  }];
+  let extraClassDeclaration = extraDistributedDeclaration;
 }

 def TTG_SharedMemorySpace : AttrDef<TritonGPU_Dialect, "SharedMemorySpace"> {

lib/Analysis/Utility.cpp

Lines changed: 0 additions & 1 deletion
@@ -730,7 +730,6 @@ bool matchMFMAAndDotOperandShuffleCase(RankedTensorType srcTy,
   return dotOperandLayout.getParent() == mfmaLayout &&
          dotOperandLayout.getOpIdx() == 0 && mfmaLayout.getIsTransposed() &&
          dotOperandLayout.getKWidth() == 8 &&
-         getContigPerThread(mfmaLayout)[1] == 4 &&
          ((mfmaLayout.getMDim() == 16 && mfmaLayout.getNDim() == 16) ||
          (mfmaLayout.getMDim() == 32 && mfmaLayout.getNDim() == 32)) &&
          triton::type::isFloat8(srcTy.getElementType()) &&

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 4 additions & 12 deletions
@@ -116,19 +116,11 @@ SmallVector<unsigned> getSizePerThread(Attribute layout) {
   }
 }

-SmallVector<unsigned> getContigPerThread(Attribute layout) {
-  if (auto distributedLayout = dyn_cast<DistributedEncodingTrait>(layout)) {
-    return distributedLayout.getContigPerThread();
-  } else {
-    llvm::report_fatal_error("getContigPerThread not implemented");
-    return {};
-  }
-}
-
-SmallVector<unsigned> getUniqueContigPerThread(Attribute layout,
-                                               ArrayRef<int64_t> shape) {
+SmallVector<unsigned> getContigPerThread(RankedTensorType tensorType) {
+  auto layout = tensorType.getEncoding();
+  auto shape = tensorType.getShape();
   auto linearLayout = toLinearLayout(shape, layout);
-  auto llAttr = LinearEncodingAttr::get(layout.getContext(), linearLayout);
+  auto llAttr = LinearEncodingAttr::get(tensorType.getContext(), linearLayout);
   return llAttr.getContigPerThread();
 }
