Skip to content

Commit 82b4715

Browse files
[NFC][MatmulLoopPipeline] Use existing utility (#4084)
Signed-off-by: Whitney Tsang <[email protected]>
1 parent aea1468 commit 82b4715

File tree

3 files changed

+9
-18
lines changed

3 files changed

+9
-18
lines changed

include/triton/Dialect/Triton/Transforms/Utility.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@
55

66
using namespace mlir;
77

8+
namespace mlir::triton {
9+
810
Value getPredMask(RewriterBase &rewriter, Type typeLike, Value currentMask,
911
Value pred);
1012

13+
} // namespace mlir::triton
14+
1115
#endif // TRITON_TRANSFORMS_UTILITY_H

lib/Dialect/Triton/Transforms/Utility.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ using namespace mlir;
55
namespace tt = mlir::triton;
66

77
// Combine the current mask with the given predicate.
8-
Value getPredMask(RewriterBase &rewriter, Type typeLike, Value currentMask,
9-
Value pred) {
8+
Value tt::getPredMask(RewriterBase &rewriter, Type typeLike, Value currentMask,
9+
Value pred) {
1010
Type maskType = tt::getI1SameShape(typeLike);
1111
Location loc = pred.getLoc();
1212
Value mask = pred;

third_party/intel/lib/TritonIntelGPUTransforms/Pipeliner/MatmulLoopPipeline.cpp

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "mlir/Interfaces/SideEffectInterfaces.h"
77
#include "triton/Analysis/AxisInfo.h"
88
#include "triton/Dialect/Triton/IR/Dialect.h"
9+
#include "triton/Dialect/Triton/Transforms/Utility.h"
910
#include "llvm/ADT/TypeSwitch.h"
1011
#include "llvm/Support/Casting.h"
1112
#include "llvm/Support/Debug.h"
@@ -150,21 +151,6 @@ static void collectOpsToPipeline(scf::ForOp forOp,
150151
}
151152
}
152153

153-
/// Return a new mask of type of shape \p typeLike, and value combining the
154-
/// current mask \p currentMask with the given predicate \p pred.
155-
static Value computeNewMask(RewriterBase &rewriter, Type typeLike,
156-
Value currentMask, Value pred) {
157-
Location loc = pred.getLoc();
158-
Value mask = pred;
159-
Type maskType = tt::getI1SameShape(tt::getPointeeType(typeLike));
160-
161-
if (isa<RankedTensorType>(maskType))
162-
mask = rewriter.create<tt::SplatOp>(loc, maskType, pred);
163-
164-
return currentMask ? rewriter.create<arith::AndIOp>(loc, mask, currentMask)
165-
: mask;
166-
}
167-
168154
/// Function to mask operations during scheduling.
169155
static Operation *predicateOp(RewriterBase &rewriter, Operation *op,
170156
Value pred) {
@@ -176,7 +162,8 @@ static Operation *predicateOp(RewriterBase &rewriter, Operation *op,
176162
.Case<tt::LoadOp, ttgi::PrefetchOp>([&](auto op) {
177163
rewriter.setInsertionPoint(op);
178164
Value mask =
179-
computeNewMask(rewriter, op.getPtr().getType(), op.getMask(), pred);
165+
tt::getPredMask(rewriter, tt::getPointeeType(op.getPtr().getType()),
166+
op.getMask(), pred);
180167
op.getMaskMutable().assign(mask);
181168
return op;
182169
});

0 commit comments

Comments
 (0)