Skip to content

Commit 8a9e7b9

Browse files
Merge commit 'f05cdc4cec480f31e79f89e2a8a26c4b51614ac2'
2 parents a36ec66 + f05cdc4 commit 8a9e7b9

File tree

5 files changed

+26
-2
lines changed

5 files changed

+26
-2
lines changed

include/triton/Dialect/Triton/IR/Utility.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,10 @@ template <typename T> auto seq(T start, T end, T step) {
173173
[=](T i) { return start + i * step; });
174174
}
175175

176+
// Combine the current mask with the given predicate.
177+
Value getPredMask(RewriterBase &rewriter, Type typeLike, Value currentMask,
178+
Value pred);
179+
176180
} // namespace triton
177181
} // namespace mlir
178182

lib/Dialect/Triton/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ add_triton_library(TritonIR
88
Traits.cpp
99
Types.cpp
1010
OpInterfaces.cpp
11+
Utility.cpp
1112

1213
DEPENDS
1314
TritonTableGen

lib/Dialect/Triton/IR/Utility.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#include "triton/Dialect/Triton/IR/Utility.h"
2+
#include "triton/Dialect/Triton/IR/Dialect.h"
3+
4+
using namespace mlir;
5+
namespace tt = mlir::triton;
6+
7+
Value tt::getPredMask(RewriterBase &rewriter, Type typeLike, Value currentMask,
8+
Value pred) {
9+
Type maskType = tt::getI1SameShape(typeLike);
10+
Location loc = pred.getLoc();
11+
Value mask = pred;
12+
if (isa<RankedTensorType>(maskType)) {
13+
mask = rewriter.create<tt::SplatOp>(loc, maskType, pred);
14+
}
15+
if (currentMask) {
16+
mask = rewriter.create<arith::AndIOp>(loc, mask, currentMask);
17+
}
18+
return mask;
19+
}

lib/Dialect/Triton/Transforms/LoopInvariantCodeMotion.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
22
#include "triton/Dialect/Triton/IR/Dialect.h"
3+
#include "triton/Dialect/Triton/IR/Utility.h"
34
#include "triton/Dialect/Triton/Transforms/Passes.h"
4-
#include "triton/Dialect/Triton/Transforms/Utility.h"
55
#include "llvm/Support/Debug.h"
66

77
#define GEN_PASS_CLASSES

lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
#include "mlir/IR/TypeUtilities.h"
88
#include "mlir/Interfaces/SideEffectInterfaces.h"
99
#include "mlir/Support/LLVM.h"
10-
#include "triton/Dialect/Triton/Transforms/Utility.h"
10+
#include "triton/Dialect/Triton/IR/Utility.h"
1111
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
1212
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
1313
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"

0 commit comments

Comments
 (0)