66#include " mlir/Interfaces/SideEffectInterfaces.h"
77#include " triton/Analysis/AxisInfo.h"
88#include " triton/Dialect/Triton/IR/Dialect.h"
9+ #include " triton/Dialect/Triton/Transforms/Utility.h"
910#include " llvm/ADT/TypeSwitch.h"
1011#include " llvm/Support/Casting.h"
1112#include " llvm/Support/Debug.h"
@@ -150,21 +151,6 @@ static void collectOpsToPipeline(scf::ForOp forOp,
150151 }
151152}
152153
153- // / Return a new mask of type of shape \p typeLike, and value combining the
154- // / current mask \p currentMask with the given predicate \p pred.
155- static Value computeNewMask (RewriterBase &rewriter, Type typeLike,
156- Value currentMask, Value pred) {
157- Location loc = pred.getLoc ();
158- Value mask = pred;
159- Type maskType = tt::getI1SameShape (tt::getPointeeType (typeLike));
160-
161- if (isa<RankedTensorType>(maskType))
162- mask = rewriter.create <tt::SplatOp>(loc, maskType, pred);
163-
164- return currentMask ? rewriter.create <arith::AndIOp>(loc, mask, currentMask)
165- : mask;
166- }
167-
168154// / Function to mask operations during scheduling.
169155static Operation *predicateOp (RewriterBase &rewriter, Operation *op,
170156 Value pred) {
@@ -176,7 +162,8 @@ static Operation *predicateOp(RewriterBase &rewriter, Operation *op,
176162 .Case <tt::LoadOp, ttgi::PrefetchOp>([&](auto op) {
177163 rewriter.setInsertionPoint (op);
178164 Value mask =
179- computeNewMask (rewriter, op.getPtr ().getType (), op.getMask (), pred);
165+ tt::getPredMask (rewriter, tt::getPointeeType (op.getPtr ().getType ()),
166+ op.getMask (), pred);
180167 op.getMaskMutable ().assign (mask);
181168 return op;
182169 });
0 commit comments