Skip to content

Commit b5fe971

Browse files
Merge OpenAI Triton commit 9f88c7f (#4425)
This PR changes the Triton base from d25fc5f to 9f88c7f (May 29). Pass rate: 97.23%. Please do not squash and merge this PR.
2 parents f2314b9 + 0d8ae42 commit b5fe971

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+1774
-1025
lines changed

.github/workflows/build-macos.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ jobs:
1313
strategy:
1414
matrix:
1515
runner: ${{ fromJson(inputs.matrix) }}
16-
timeout-minutes: 40
16+
timeout-minutes: 60
1717
env:
1818
RUNNER_TYPE: ${{ matrix.runner[0] }}
1919
name: Build MacOS

bin/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ target_link_libraries(triton-opt PRIVATE
1313
${triton_libs}
1414
# tests
1515
TritonTestAnalysis
16+
TritonTestDialect
1617
TritonAMDGPUTestAnalysis
1718
# MLIR core
1819
MLIROptLib
@@ -32,6 +33,7 @@ target_link_libraries(triton-reduce PRIVATE
3233
${triton_libs}
3334
# tests
3435
TritonTestAnalysis
36+
TritonTestDialect
3537
TritonAMDGPUTestAnalysis
3638
# MLIR core
3739
MLIRReduceLib
@@ -50,6 +52,7 @@ target_link_libraries(triton-lsp PRIVATE
5052
${triton_libs}
5153
# tests
5254
TritonTestAnalysis
55+
TritonTestDialect
5356
TritonAMDGPUTestAnalysis
5457
# MLIR core
5558
MLIRLspServerLib
@@ -88,5 +91,6 @@ target_link_libraries(triton-tensor-layout PRIVATE
8891
${conversion_libs}
8992
${dialect_libs}
9093
TritonTestAnalysis
94+
TritonTestDialect
9195
TritonAMDGPUTestAnalysis
9296
)

bin/RegisterTritonDialects.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ void registerTestLivenessPass();
5454
void registerTestMembarPass();
5555
void registerTestAMDGPUMembarPass();
5656
void registerTestTritonAMDGPURangeAnalysis();
57+
void registerTestLoopPeelingPass();
5758
} // namespace test
5859
} // namespace mlir
5960

@@ -68,6 +69,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
6869
mlir::test::registerTestAllocationPass();
6970
mlir::test::registerTestLivenessPass();
7071
mlir::test::registerTestMembarPass();
72+
mlir::test::registerTestLoopPeelingPass();
7173
mlir::test::registerTestAMDGPUMembarPass();
7274
mlir::test::registerTestTritonAMDGPURangeAnalysis();
7375
mlir::triton::registerConvertTritonToTritonGPUPass();

include/triton/Dialect/Triton/IR/Utility.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,9 @@ template <typename T> auto seq(T start, T end, T step) {
177177
Value getPredMask(RewriterBase &rewriter, Type typeLike, Value currentMask,
178178
Value pred);
179179

180+
// Get the value taken by the induction variable on the last iteration of the loop.
181+
Value getLastInductionValue(OpBuilder &b, scf::ForOp loop);
182+
180183
MakeTensorPtrOp getMakeTensorPtrOp(Value v);
181184

182185
} // namespace triton
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#ifndef TRITON_DIALECT_TRITON_TRANSFORMS_LOOP_PEELING_H_
2+
#define TRITON_DIALECT_TRITON_TRANSFORMS_LOOP_PEELING_H_
3+
4+
#include "mlir/Dialect/SCF/IR/SCF.h"
5+
6+
namespace mlir {
7+
namespace triton {
8+
9+
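// Peel the single last iteration of the loop into an scf.if placed after it. If provided, `processPeeledOp` is called on each op of the remaining loop body (bool = false) and of the peeled iteration (bool = true) and may return a replacement op.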
10+
void peelLoopEpilogue(
11+
scf::ForOp forOp,
12+
function_ref<Operation *(RewriterBase &, Operation *, bool)>
13+
processPeeledOp = nullptr);
14+
15+
} // namespace triton
16+
} // namespace mlir
17+
18+
#endif // TRITON_DIALECT_TRITON_TRANSFORMS_LOOP_PEELING_H_

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,11 @@ def NVMMASharedEncodingAttr :
437437
} else {
438438
swizzlingByteWidth = 0;
439439
}
440-
if (shapePerCTA.size() < 2 || shapePerCTA[order[1]] < 8) {
440+
int flattenOutterDim = 1;
441+
for (int i = 1; i < shapePerCTA.size(); i++) {
442+
flattenOutterDim *= shapePerCTA[order[i]];
443+
}
444+
if (shapePerCTA.size() < 2 || flattenOutterDim < 8) {
441445
swizzlingByteWidth = 0;
442446
}
443447
bool transposed = order[0] == 0;

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,24 @@ def TTG_PredicateStageOp: TTG_Op<"predicate_stage",
337337
let assemblyFormat = "$iv `,` $ub `,` $step `maxStage` $maxStage `stage` $stage attr-dict `:` type($iv) `->` type($result)";
338338
}
339339

340+
def TTG_MaskOp: TTG_Op<"mask",
341+
[SingleBlock]> {
342+
let summary = "mask op for pipelining";
343+
let arguments = (ins I1:$pred);
344+
let results = (outs Variadic<AnyType>:$result);
345+
let regions = (region SizedRegion<1>:$region);
346+
let builders = [
347+
OpBuilder<(ins "Value":$pred)>,
348+
];
349+
}
350+
351+
def TTG_MaskReturnOp: TTG_Op<"mask.return",
352+
[HasParent<"MaskOp">, Pure, Terminator, ReturnLike]> {
353+
let summary = "terminator for mask operator";
354+
let arguments = (ins Variadic<AnyType>:$result);
355+
let assemblyFormat = "$result attr-dict `:` type($result)";
356+
}
357+
340358
def TTG_Fp4ToFpOp : TTG_Op<"fp4_to_fp", [Pure]> {
341359
let summary = "Upcast fp4 (e2m1) to fp";
342360

@@ -450,7 +468,9 @@ def TTG_WarpSpecializeOp : TTG_Op<"warp_specialize", [
450468
let builders = [
451469
OpBuilder<(ins "TypeRange":$resultTypes,
452470
"ArrayRef<int32_t>":$partitionNumWarps,
453-
"unsigned":$numPartitionRegions)>
471+
"unsigned":$numPartitionRegions)>,
472+
OpBuilder<(ins "TypeRange":$resultTypes, "ValueRange":$explicitCaptures,
473+
"ArrayRef<int32_t>":$partitionNumWarps)>,
454474
];
455475

456476
let hasVerifier = 1;

lib/Dialect/Triton/IR/Utility.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "triton/Dialect/Triton/IR/Utility.h"
22
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
3+
#include "mlir/Dialect/SCF/IR/SCF.h"
34
#include "triton/Dialect/Triton/IR/Dialect.h"
45

56
using namespace mlir;
@@ -90,3 +91,15 @@ tt::MakeTensorPtrOp tt::getMakeTensorPtrOp(Value v) {
9091
}
9192
llvm_unreachable("Unable to getMakeTensorPtr()");
9293
}
94+
95+
// Emit, at the current insertion point of `b`, the arithmetic that computes
// the value the induction variable of `loop` takes on its last iteration:
//
//   ((ub - lb - 1) / step) * step + lb
//
// The division is a signed integer division, so the quotient is the index of
// the last iteration when lb < ub and step > 0. The result is only meaningful
// for a loop that executes at least once; callers are expected to guard the
// use of the value accordingly.
Value tt::getLastInductionValue(OpBuilder &b, scf::ForOp loop) {
  Location loc = loop.getLoc();
  Value lowerBound = loop.getLowerBound();
  Value step = loop.getStep();

  // ub - lb
  Value diff = b.create<arith::SubIOp>(loc, loop.getUpperBound(), lowerBound);
  // Build the constant 1 with the same type as the loop bounds. A hard-coded
  // i32 constant would make the subtraction below ill-typed for loops whose
  // bounds are i64 or index.
  Value one =
      b.create<arith::ConstantOp>(loc, b.getIntegerAttr(diff.getType(), 1));
  // ub - lb - 1
  diff = b.create<arith::SubIOp>(loc, diff, one);
  // Index of the last iteration, then its offset from the lower bound.
  Value lastIterIdx = b.create<arith::DivSIOp>(loc, diff, step);
  Value lastOffset = b.create<arith::MulIOp>(loc, lastIterIdx, step);
  return b.create<arith::AddIOp>(loc, lastOffset, lowerBound);
}

lib/Dialect/Triton/Transforms/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ add_triton_library(TritonTransforms
66
Combine.cpp
77
LoopAwareCSE.cpp
88
LoopInvariantCodeMotion.cpp
9+
LoopPeeling.cpp
910
LoopUnroll.cpp
1011
ReorderBroadcast.cpp
1112
RewriteTensorPointer.cpp
@@ -20,5 +21,7 @@ add_triton_library(TritonTransforms
2021
LINK_LIBS PUBLIC
2122
MLIRPass
2223
MLIRTransformUtils
24+
MLIRTransforms
25+
MLIRSCFToControlFlow
2326
TritonIR
2427
)
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
#include "triton/Dialect/Triton/Transforms/LoopPeeling.h"
2+
#include "mlir/Dialect/SCF/IR/SCF.h"
3+
#include "mlir/Pass/Pass.h"
4+
#include "triton/Dialect/Triton/IR/Utility.h"
5+
6+
using namespace mlir;
7+
8+
namespace mlir {
9+
namespace triton {
10+
11+
// Peel the last iteration of `forOp` out of the loop.
//
// The upper bound of the loop is reduced by one step, and a copy of the loop
// body is cloned into the "then" region of an scf.if created right after the
// loop. In the copy, the region iter args are mapped to the loop results and
// the induction variable is mapped to the last induction value. The scf.if is
// guarded by `lowerBound < upperBound`; its "else" region yields the loop
// results unchanged. Every use of the loop results that is not nested inside
// the new scf.if is redirected to the scf.if results.
//
// If `processPeeledOp` is non-null, it is invoked on every non-terminator op
// of the loop body (isEpilogue = false) and then on every non-terminator op
// of the peeled copy (isEpilogue = true). When the callback returns a non-null
// op different from the one it was given, all uses of the original op are
// replaced with the returned op; the original op is not erased here.
void peelLoopEpilogue(
    scf::ForOp forOp,
    function_ref<Operation *(RewriterBase &, Operation *, bool)>
        processPeeledOp) {
  // The rewriter starts with its insertion point right before the loop.
  IRRewriter rewriter(forOp);
  Location loc = forOp.getLoc();

  // Fetch loop bounds and step. The shortened upper bound is emitted before
  // the loop so it can be used as a loop operand.
  Value lowerBound = forOp.getLowerBound();
  Value upperBound = forOp.getUpperBound();
  Value step = forOp.getStep();
  Value newUpperBound = rewriter.create<arith::SubIOp>(loc, upperBound, step);

  // Everything below is emitted after the loop.
  rewriter.setInsertionPointAfter(forOp);
  Value lastIV = getLastInductionValue(rewriter, forOp);

  // The peeled iteration runs only if the original loop had an iteration.
  auto cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                             lowerBound, upperBound);

  // Create an if op to execute the peeled iteration. Block arguments that are
  // already present in the mapping are not re-created by cloneInto.
  IRMapping map;
  map.map(forOp.getRegionIterArgs(), forOp.getResults());
  map.map(forOp.getInductionVar(), lastIV);
  auto ifOp = rewriter.create<scf::IfOp>(loc, forOp.getResultTypes(), cond,
                                         /*hasElse=*/true);
  // Replace the auto-generated "then" block with a clone of the loop body.
  ifOp.getThenRegion().front().erase();
  forOp.getBodyRegion().cloneInto(&ifOp.getThenRegion(), map);
  // When the peeled iteration is skipped, forward the loop results.
  rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
  rewriter.create<scf::YieldOp>(loc, forOp.getResults());

  // Users outside the if must now see the results of the peeled iteration;
  // the if itself keeps consuming the loop results.
  forOp->replaceUsesWithIf(ifOp, [&](OpOperand &operand) {
    return !ifOp->isAncestor(operand.getOwner());
  });

  forOp.getUpperBoundMutable().assign(newUpperBound);

  if (processPeeledOp) {
    // Early-inc ranges tolerate the callback removing the current op.
    for (auto &op :
         llvm::make_early_inc_range(forOp.getBody()->without_terminator())) {
      Operation *newOp = processPeeledOp(rewriter, &op, /*isEpilogue=*/false);
      if (newOp && newOp != &op) {
        op.replaceAllUsesWith(newOp);
      }
    }
    for (auto &op : llvm::make_early_inc_range(
             ifOp.getThenRegion().front().without_terminator())) {
      Operation *newOp = processPeeledOp(rewriter, &op, /*isEpilogue=*/true);
      if (newOp && newOp != &op) {
        op.replaceAllUsesWith(newOp);
      }
    }
  }
}
66+
67+
} // namespace triton
68+
} // namespace mlir

0 commit comments

Comments
 (0)