diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h index 536cbf9018e89..e486bb627474d 100644 --- a/mlir/include/mlir/IR/Block.h +++ b/mlir/include/mlir/IR/Block.h @@ -16,6 +16,8 @@ #include "mlir/IR/BlockSupport.h" #include "mlir/IR/Visitors.h" +#include "llvm/ADT/SmallPtrSet.h" + namespace llvm { class BitVector; class raw_ostream; @@ -264,6 +266,19 @@ class alignas(8) Block : public IRObjectWithUseList, succ_iterator succ_end() { return getSuccessors().end(); } SuccessorRange getSuccessors() { return SuccessorRange(this); } + /// Return "true" if there is a path from this block to the given block + /// (according to the successors relationship). Both blocks must be in the + /// same region. Paths that contain a block from `except` do not count. + /// This function returns "false" if `other` is in `except`. + /// + /// Note: This function performs a block graph traversal and its complexity + /// linear in the number of blocks in the parent region. + /// + /// Note: Reachability is a necessary but insufficient condition for + /// dominance. Do not use this function in places where you need to check for + /// dominance. + bool isReachable(Block *other, SmallPtrSet &&except = {}); + //===--------------------------------------------------------------------===// // Walkers //===--------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp index 829e954d53b25..d1e6acef324fb 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp @@ -273,25 +273,6 @@ static bool happensBefore(Operation *a, Operation *b, return false; } -static bool isReachable(Block *from, Block *to, ArrayRef except) { - DenseSet visited; - SmallVector worklist; - for (Block *succ : from->getSuccessors()) - worklist.push_back(succ); - while (!worklist.empty()) { - Block *next = worklist.pop_back_val(); - if (llvm::is_contained(except, next)) - continue; - if (next == to) - return true; - if (!visited.insert(next).second) - continue; - for (Block *succ : next->getSuccessors()) - worklist.push_back(succ); - } - return false; -} - /// Return `true` if op dominance can be used to rule out a read-after-write /// conflicts based on the ordering of ops. Returns `false` if op dominance /// cannot be used to due region-based loops. @@ -427,8 +408,8 @@ static bool canUseOpDominanceDueToBlocks(OpOperand *uRead, OpOperand *uWrite, Block *writeBlock = uWrite->getOwner()->getBlock(); for (Value def : definitions) { Block *defBlock = def.getParentBlock(); - if (isReachable(readBlock, writeBlock, {defBlock}) && - isReachable(writeBlock, readBlock, {defBlock})) + if (readBlock->isReachable(writeBlock, {defBlock}) && + writeBlock->isReachable(readBlock, {defBlock})) return false; } diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp index 3a30382114c8d..bd5f06a3b46d4 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp @@ -73,20 +73,7 @@ bool TransferOptimization::isReachable(Operation *start, Operation *dest) { // Simple case where the start op dominate the destination. if (dominators.dominates(start, dest)) return true; - Block *startBlock = start->getBlock(); - Block *destBlock = dest->getBlock(); - SmallVector worklist(startBlock->succ_begin(), - startBlock->succ_end()); - SmallPtrSet visited; - while (!worklist.empty()) { - Block *bb = worklist.pop_back_val(); - if (!visited.insert(bb).second) - continue; - if (dominators.dominates(bb, destBlock)) - return true; - worklist.append(bb->succ_begin(), bb->succ_end()); - } - return false; + return start->getBlock()->isReachable(dest->getBlock()); } /// For transfer_write to overwrite fully another transfer_write must: diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp index 65099f8ff15a6..4b1568219fb37 100644 --- a/mlir/lib/IR/Block.cpp +++ b/mlir/lib/IR/Block.cpp @@ -7,9 +7,12 @@ //===----------------------------------------------------------------------===// #include "mlir/IR/Block.h" + #include "mlir/IR/Builders.h" #include "mlir/IR/Operation.h" #include "llvm/ADT/BitVector.h" +#include "llvm/ADT/SmallPtrSet.h" + using namespace mlir; //===----------------------------------------------------------------------===// @@ -331,7 +334,7 @@ unsigned PredecessorIterator::getSuccessorIndex() const { } //===----------------------------------------------------------------------===// -// SuccessorRange +// Successors //===----------------------------------------------------------------------===// SuccessorRange::SuccessorRange() : SuccessorRange(nullptr, 0) {} @@ -349,6 +352,26 @@ SuccessorRange::SuccessorRange(Operation *term) : SuccessorRange() { base = term->getBlockOperands().data(); } +bool Block::isReachable(Block *other, SmallPtrSet &&except) { + assert(getParent() == other->getParent() && "expected same region"); + if (except.contains(other)) { + // Fast path: If `other` is in the `except` set, there can be no path from + // "this" to `other` (that does not pass through an excluded block). + return false; + } + SmallVector worklist(succ_begin(), succ_end()); + while (!worklist.empty()) { + Block *next = worklist.pop_back_val(); + if (next == other) + return true; + // Note: `except` keeps track of already visited blocks. + if (!except.insert(next).second) + continue; + worklist.append(next->succ_begin(), next->succ_end()); + } + return false; +} + //===----------------------------------------------------------------------===// // BlockRange //===----------------------------------------------------------------------===//