Skip to content

Commit aea1468

Browse files
authored
Fix flaky build (#4082)
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent cb6997a commit aea1468

File tree

7 files changed

+35
-19
lines changed

7 files changed

+35
-19
lines changed
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#ifndef TRITON_TRANSFORMS_UTILITY_H
2+
#define TRITON_TRANSFORMS_UTILITY_H
3+
4+
#include "triton/Dialect/Triton/IR/Dialect.h"
5+
6+
using namespace mlir;
7+
8+
Value getPredMask(RewriterBase &rewriter, Type typeLike, Value currentMask,
9+
Value pred);
10+
11+
#endif // TRITON_TRANSFORMS_UTILITY_H

include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,6 @@ Value sinkValueRedefinition(RewriterBase &rewriter, Value in, Value out,
6464
bool loopHasDistGreaterThanOne(scf::ForOp forOp);
6565
bool isOuterLoop(scf::ForOp forOp);
6666

67-
Value getPredMask(RewriterBase &rewriter, Type typeLike, Value currentMask,
68-
Value pred);
69-
7067
/// Function to mask operations during scheduling.
7168
Operation *predicateOp(RewriterBase &rewriter, Operation *op, Value pred);
7269

lib/Dialect/Triton/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ add_triton_library(TritonTransforms
88
LoopUnroll.cpp
99
ReorderBroadcast.cpp
1010
RewriteTensorPointer.cpp
11+
Utility.cpp
1112

1213
DEPENDS
1314
TritonTransformsIncGen

lib/Dialect/Triton/Transforms/LoopInvariantCodeMotion.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
22
#include "triton/Dialect/Triton/IR/Dialect.h"
33
#include "triton/Dialect/Triton/Transforms/Passes.h"
4-
#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h"
4+
#include "triton/Dialect/Triton/Transforms/Utility.h"
55
#include "llvm/Support/Debug.h"
66

77
#define GEN_PASS_CLASSES
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
#include "triton/Dialect/Triton/Transforms/Utility.h"
2+
#include "triton/Dialect/Triton/IR/Dialect.h"
3+
4+
using namespace mlir;
5+
namespace tt = mlir::triton;
6+
7+
// Combine the current mask with the given predicate.
8+
Value getPredMask(RewriterBase &rewriter, Type typeLike, Value currentMask,
9+
Value pred) {
10+
Type maskType = tt::getI1SameShape(typeLike);
11+
Location loc = pred.getLoc();
12+
Value mask = pred;
13+
if (isa<RankedTensorType>(maskType)) {
14+
mask = rewriter.create<tt::SplatOp>(loc, maskType, pred);
15+
}
16+
if (currentMask) {
17+
mask = rewriter.create<arith::AndIOp>(loc, mask, currentMask);
18+
}
19+
return mask;
20+
}

lib/Dialect/TritonGPU/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ add_triton_library(TritonGPUTransforms
4444
MLIRTransforms
4545
MLIRTransformUtils
4646
TritonAnalysis
47+
TritonTransforms
4748
TritonIR
4849
TritonGPUIR
4950
TritonNvidiaGPUIR

lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include "mlir/IR/TypeUtilities.h"
88
#include "mlir/Interfaces/SideEffectInterfaces.h"
99
#include "mlir/Support/LLVM.h"
10+
#include "triton/Dialect/Triton/Transforms/Utility.h"
1011
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
1112
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
1213
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
@@ -154,21 +155,6 @@ bool mlir::triton::isOuterLoop(scf::ForOp forOp) {
154155
});
155156
}
156157

157-
// Combine the current mask with the given predicate.
158-
Value mlir::triton::getPredMask(RewriterBase &rewriter, Type typeLike,
159-
Value currentMask, Value pred) {
160-
Type maskType = tt::getI1SameShape(typeLike);
161-
Location loc = pred.getLoc();
162-
Value mask = pred;
163-
if (isa<RankedTensorType>(maskType)) {
164-
mask = rewriter.create<tt::SplatOp>(loc, maskType, pred);
165-
}
166-
if (currentMask) {
167-
mask = rewriter.create<arith::AndIOp>(loc, mask, currentMask);
168-
}
169-
return mask;
170-
}
171-
172158
// Function to mask operations during scheduling.
173159
Operation *mlir::triton::predicateOp(RewriterBase &rewriter, Operation *op,
174160
Value pred) {

0 commit comments

Comments
 (0)