|
2 | 2 | #include "mlir/Analysis/TopologicalSortUtils.h" |
3 | 3 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
4 | 4 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| 5 | +#include "mlir/Dialect/UB/IR/UBOps.h" |
5 | 6 | #include "mlir/IR/ImplicitLocOpBuilder.h" |
6 | 7 | #include "mlir/IR/TypeUtilities.h" |
7 | 8 | #include "mlir/Interfaces/SideEffectInterfaces.h" |
8 | 9 | #include "mlir/Support/LLVM.h" |
| 10 | +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
9 | 11 | #include "triton/Analysis/AxisInfo.h" |
10 | 12 | #include "triton/Dialect/Triton/IR/Utility.h" |
11 | 13 | #include "triton/Dialect/TritonGPU/IR/Dialect.h" |
@@ -279,6 +281,69 @@ Operation *mlir::triton::predicateOp(RewriterBase &rewriter, Operation *op, |
279 | 281 | return op; |
280 | 282 | } |
281 | 283 |
|
| 284 | +Operation *mlir::triton::wrapInMaskOp(RewriterBase &rewriter, Operation *op, |
| 285 | + Value pred) { |
| 286 | + auto mask = |
| 287 | + rewriter.create<ttg::MaskOp>(op->getLoc(), op->getResultTypes(), pred); |
| 288 | + rewriter.createBlock(&mask->getRegion(0)); |
| 289 | + rewriter.setInsertionPointToStart(&mask->getRegion(0).front()); |
| 290 | + auto newOp = rewriter.clone(*op); |
| 291 | + rewriter.create<ttg::MaskReturnOp>(op->getLoc(), newOp->getResults()); |
| 292 | + op->replaceAllUsesWith(mask->getResults()); |
| 293 | + rewriter.eraseOp(op); |
| 294 | + return mask; |
| 295 | +} |
| 296 | + |
| 297 | +void mlir::triton::resolveMaskOp(ModuleOp moduleOp, |
| 298 | + DenseSet<ttg::MaskOp> &peeledMaskOps) { |
| 299 | + IRRewriter rewriter(moduleOp); |
| 300 | + |
| 301 | + // Canonicalize the IR to simplify the arithmetic ops defining the mask |
| 302 | + auto arithDialect = |
| 303 | + moduleOp.getContext()->getLoadedDialect<arith::ArithDialect>(); |
| 304 | + RewritePatternSet patterns(moduleOp.getContext()); |
| 305 | + arithDialect->getCanonicalizationPatterns(patterns); |
| 306 | + if (mlir::applyPatternsGreedily(moduleOp, std::move(patterns)).failed()) |
| 307 | + return llvm::report_fatal_error("Failed to canonicalize the IR"); |
| 308 | + |
| 309 | + // Prune all the statically dead mask ops in the epilogue. This is a |
| 310 | + // hack, ideally we should do it for all the mask ops, but it is incorrect if |
| 311 | + // we have speculatively executed async cp operations that will store to shmem |
| 312 | + // even if the mask is false. |
| 313 | + for (auto maskOp : peeledMaskOps) { |
| 314 | + rewriter.setInsertionPoint(maskOp); |
| 315 | + while (&maskOp.getBody()->front() != maskOp.getBody()->getTerminator()) { |
| 316 | + Operation *op = &maskOp.getBody()->front(); |
| 317 | + if (isConstantIntValue(maskOp.getPred(), 0)) { |
| 318 | + if (op->getNumResults() > 0) { |
| 319 | + SmallVector<Value> results; |
| 320 | + for (auto result : op->getResults()) { |
| 321 | + auto poisonOp = rewriter.create<mlir::ub::PoisonOp>( |
| 322 | + op->getLoc(), result.getType()); |
| 323 | + results.push_back(poisonOp); |
| 324 | + } |
| 325 | + op->replaceAllUsesWith(results); |
| 326 | + } |
| 327 | + op->erase(); |
| 328 | + } |
| 329 | + } |
| 330 | + } |
| 331 | + |
| 332 | + SmallVector<ttg::MaskOp> maskOps; |
| 333 | + moduleOp->walk([&](ttg::MaskOp maskOp) { maskOps.push_back(maskOp); }); |
| 334 | + for (auto maskOp : maskOps) { |
| 335 | + rewriter.setInsertionPoint(maskOp); |
| 336 | + while (&maskOp.getBody()->front() != maskOp.getBody()->getTerminator()) { |
| 337 | + Operation *op = &maskOp.getBody()->front(); |
| 338 | + rewriter.moveOpBefore(op, maskOp); |
| 339 | + op = triton::predicateOp(rewriter, op, maskOp.getPred()); |
| 340 | + } |
| 341 | + maskOp->replaceAllUsesWith( |
| 342 | + maskOp.getBody()->getTerminator()->getOperands()); |
| 343 | + maskOp->erase(); |
| 344 | + } |
| 345 | +} |
| 346 | + |
282 | 347 | // Return true if the given ForOp has the attribute |
283 | 348 | // `tt.disallow_acc_multi_buffer` set to true. |
284 | 349 | bool mlir::triton::getDisallowAccMultiBuffer(scf::ForOp forOp) { |
|
0 commit comments