Skip to content

Commit 24caa77

Browse files
Merge OpenAI Triton commit 0e94b6c (#4392)
This PR change the Triton base from 75fe113 to 0e94b6c (May 26). Pass rate: 97.23% -> 97.23%
2 parents 40fd289 + cabb9f8 commit 24caa77

File tree

76 files changed

+1020
-740
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

76 files changed

+1020
-740
lines changed

bin/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ target_link_libraries(triton-opt PRIVATE
1313
${triton_libs}
1414
# tests
1515
TritonTestAnalysis
16+
TritonTestDialect
1617
TritonAMDGPUTestAnalysis
1718
# MLIR core
1819
MLIROptLib
@@ -32,6 +33,7 @@ target_link_libraries(triton-reduce PRIVATE
3233
${triton_libs}
3334
# tests
3435
TritonTestAnalysis
36+
TritonTestDialect
3537
TritonAMDGPUTestAnalysis
3638
# MLIR core
3739
MLIRReduceLib
@@ -50,6 +52,7 @@ target_link_libraries(triton-lsp PRIVATE
5052
${triton_libs}
5153
# tests
5254
TritonTestAnalysis
55+
TritonTestDialect
5356
TritonAMDGPUTestAnalysis
5457
# MLIR core
5558
MLIRLspServerLib
@@ -88,5 +91,6 @@ target_link_libraries(triton-tensor-layout PRIVATE
8891
${conversion_libs}
8992
${dialect_libs}
9093
TritonTestAnalysis
94+
TritonTestDialect
9195
TritonAMDGPUTestAnalysis
9296
)

bin/RegisterTritonDialects.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ void registerTestLivenessPass();
5454
void registerTestMembarPass();
5555
void registerTestAMDGPUMembarPass();
5656
void registerTestTritonAMDGPURangeAnalysis();
57+
void registerTestLoopPeelingPass();
5758
} // namespace test
5859
} // namespace mlir
5960

@@ -68,6 +69,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
6869
mlir::test::registerTestAllocationPass();
6970
mlir::test::registerTestLivenessPass();
7071
mlir::test::registerTestMembarPass();
72+
mlir::test::registerTestLoopPeelingPass();
7173
mlir::test::registerTestAMDGPUMembarPass();
7274
mlir::test::registerTestTritonAMDGPURangeAnalysis();
7375
mlir::triton::registerConvertTritonToTritonGPUPass();
@@ -113,7 +115,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
113115
mlir::registerTritonAMDFoldTrueCmpI();
114116

115117
// NVWS passes
116-
mlir::registerNVWSTransformsPasses();
118+
mlir::triton::registerNVWSTransformsPasses();
117119

118120
// NVGPU transform passes
119121
mlir::registerNVHopperTransformsPasses();

include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,8 @@ class TargetInfoBase {
9494

9595
virtual bool supportVectorizedAtomics() const = 0;
9696

97-
virtual bool supportLdStMatrix() const = 0;
97+
virtual bool supportLdMatrix() const { return false; }
98+
virtual bool supportStMatrix() const { return false; }
9899

99100
// Annotate target specific information to local store operations during
100101
// lowering to LLVM.

include/triton/Dialect/Triton/IR/Utility.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,9 @@ template <typename T> auto seq(T start, T end, T step) {
177177
Value getPredMask(RewriterBase &rewriter, Type typeLike, Value currentMask,
178178
Value pred);
179179

180+
// Get the value of the induction variable at the end of the loop.
181+
Value getLastInductionValue(OpBuilder &b, scf::ForOp loop);
182+
180183
MakeTensorPtrOp getMakeTensorPtrOp(Value v);
181184

182185
} // namespace triton
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#ifndef TRITON_DIALECT_TRITON_TRANSFORMS_LOOP_PEELING_H_
2+
#define TRITON_DIALECT_TRITON_TRANSFORMS_LOOP_PEELING_H_
3+
4+
#include "mlir/Dialect/SCF/IR/SCF.h"
5+
6+
namespace mlir {
7+
namespace triton {
8+
9+
// Peel the single last iteration of the loop.
10+
void peelLoopEpilogue(
11+
scf::ForOp forOp,
12+
function_ref<Operation *(RewriterBase &, Operation *, bool)>
13+
processPeeledOp = nullptr);
14+
15+
} // namespace triton
16+
} // namespace mlir
17+
18+
#endif // TRITON_DIALECT_TRITON_TRANSFORMS_LOOP_PEELING_H_

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,24 @@ def TTG_PredicateStageOp: TTG_Op<"predicate_stage",
337337
let assemblyFormat = "$iv `,` $ub `,` $step `maxStage` $maxStage `stage` $stage attr-dict `:` type($iv) `->` type($result)";
338338
}
339339

340+
def TTG_MaskOp: TTG_Op<"mask",
341+
[SingleBlock]> {
342+
let summary = "mask op for pipelining";
343+
let arguments = (ins I1:$pred);
344+
let results = (outs Variadic<AnyType>:$result);
345+
let regions = (region SizedRegion<1>:$region);
346+
let builders = [
347+
OpBuilder<(ins "Value":$pred)>,
348+
];
349+
}
350+
351+
def TTG_MaskReturnOp: TTG_Op<"mask.return",
352+
[HasParent<"MaskOp">, Pure, Terminator, ReturnLike]> {
353+
let summary = "terminator for mask operator";
354+
let arguments = (ins Variadic<AnyType>:$result);
355+
let assemblyFormat = "$result attr-dict `:` type($result)";
356+
}
357+
340358
def TTG_Fp4ToFpOp : TTG_Op<"fp4_to_fp", [Pure]> {
341359
let summary = "Upcast fp4 (e2m1) to fp";
342360

include/triton/Target/LLVMIR/Passes.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55

66
namespace mlir {
77

8-
/// Create a pass to add DIScope
9-
std::unique_ptr<Pass> createLLVMDIScopePass();
8+
// Generate the pass class declarations.
9+
#define GEN_PASS_DECL
10+
#include "triton/Target/LLVMIR/Passes.h.inc"
1011

11-
/// Generate the code for registering conversion passes.
12+
// Generate the code for registering conversion passes.
1213
#define GEN_PASS_REGISTRATION
1314
#include "triton/Target/LLVMIR/Passes.h.inc"
1415

include/triton/Target/LLVMIR/Passes.td

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@ def LLVMDIScope: Pass<"enable-line-info", "mlir::ModuleOp"> {
88
let description = [{
99
This pass materializes line mapping information for LLVM IR dialect operations.
1010
}];
11-
12-
let constructor = "mlir::createLLVMDIScopePass()";
1311
}
1412

1513
#endif

include/triton/Tools/LayoutUtils.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,11 @@ std::optional<ColumnAction> regPermForDivideLeft(const LinearLayout &A,
125125
// such that action.apply(A) has the broadcasted registers removed
126126
ColumnAction actionRemoveBroadcastedRegs(const LinearLayout &layout);
127127

128+
// For a layout A with A.hasInDim(kReg), repeat the values so that they have
129+
// the same broadcasting as layout
130+
SmallVector<Value> broadcastAs(const SmallVector<Value> &values,
131+
const LinearLayout &layout);
132+
128133
// Compute the supremum of two lists.
129134
// Error out if the supremum does not exist (e.g. [a, b] and [b, a]).
130135
// If the supremum is not unique, we return the first list first

include/triton/Tools/LinearLayout.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -825,6 +825,9 @@ class ColumnAction {
825825
// [x[0], x[4], x[1], x[5], x[2], x[6], x[3], x[7]]
826826
SmallVector<Value> apply(ValueRange values) const;
827827

828+
// Inverse of the action
829+
ColumnAction inverse() const;
830+
828831
std::string toString() const;
829832
};
830833

0 commit comments

Comments
 (0)