diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp index 5891e2fa0067e..6fefe4487ef59 100644 --- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp @@ -343,61 +343,48 @@ static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst, return newMemRef; } -/// Walking from node 'srcId' to node 'dstId' (exclusive of 'srcId' and -/// 'dstId'), if there is any non-affine operation accessing 'memref', return -/// true. Otherwise, return false. -static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId, - Value memref, - MemRefDependenceGraph *mdg) { - auto *srcNode = mdg->getNode(srcId); - auto *dstNode = mdg->getNode(dstId); - Value::user_range users = memref.getUsers(); - // For each MemRefDependenceGraph's node that is between 'srcNode' and - // 'dstNode' (exclusive of 'srcNodes' and 'dstNode'), check whether any - // non-affine operation in the node accesses the 'memref'. - for (auto &idAndNode : mdg->nodes) { - Operation *op = idAndNode.second.op; - // Take care of operations between 'srcNode' and 'dstNode'. - if (srcNode->op->isBeforeInBlock(op) && op->isBeforeInBlock(dstNode->op)) { - // Walk inside the operation to find any use of the memref. - // Interrupt the walk if found. - auto walkResult = op->walk([&](Operation *user) { - // Skip affine ops. - if (isa(*user)) - return WalkResult::advance(); - // Find a non-affine op that uses the memref. - if (llvm::is_contained(users, user)) - return WalkResult::interrupt(); - return WalkResult::advance(); - }); - if (walkResult.wasInterrupted()) - return true; - } - } - return false; +/// Returns true if there are any non-affine uses of `memref` in any of +/// the operations between `start` and `end` (both exclusive). Any other +/// than affine read/write are treated as non-affine uses of `memref`. +static bool hasNonAffineUsersOnPath(Operation *start, Operation *end, + Value memref) { + assert(start->getBlock() == end->getBlock()); + assert(start->isBeforeInBlock(end) && "start expected to be before end"); + Block *block = start->getBlock(); + // Check if there is a non-affine memref user in any op between `start` and + // `end`. + return llvm::any_of(memref.getUsers(), [&](Operation *user) { + if (isa(user)) + return false; + Operation *ancestor = block->findAncestorOpInBlock(*user); + return ancestor && start->isBeforeInBlock(ancestor) && + ancestor->isBeforeInBlock(end); + }); } -/// Check whether a memref value in node 'srcId' has a non-affine that -/// is between node 'srcId' and node 'dstId' (exclusive of 'srcNode' and -/// 'dstNode'). -static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId, - MemRefDependenceGraph *mdg) { - // Collect memref values in node 'srcId'. - auto *srcNode = mdg->getNode(srcId); +/// Check whether a memref value used in any operation of 'src' has a +/// non-affine operation that is between `src` and `end` (exclusive of `src` +/// and `end`) where `src` and `end` are expected to be in the same Block. +/// Any other than affine read/write are treated as non-affine uses of memref. +static bool hasNonAffineUsersOnPath(Operation *src, Operation *end) { + assert(src->getBlock() == end->getBlock() && "same block expected"); + + // Trivial case. `src` and `end` are exclusive. + if (src == end || end->isBeforeInBlock(src)) + return false; + + // Collect relevant memref values. llvm::SmallDenseSet memRefValues; - srcNode->op->walk([&](Operation *op) { - // Skip affine ops. - if (isa(op)) - return WalkResult::advance(); + src->walk([&](Operation *op) { for (Value v : op->getOperands()) // Collect memref values only. if (isa(v.getType())) memRefValues.insert(v); return WalkResult::advance(); }); - // Looking for users between node 'srcId' and node 'dstId'. + // Look for non-affine users between `src` and `end`. return llvm::any_of(memRefValues, [&](Value memref) { - return hasNonAffineUsersOnThePath(srcId, dstId, memref, mdg); + return hasNonAffineUsersOnPath(src, end, memref); }); } @@ -884,7 +871,7 @@ struct GreedyFusion { // escaping memrefs so we can limit this check to only scenarios with // escaping memrefs. if (!srcEscapingMemRefs.empty() && - hasNonAffineUsersOnThePath(srcId, dstId, mdg)) { + hasNonAffineUsersOnPath(srcNode->op, dstNode->op)) { LLVM_DEBUG(llvm::dbgs() << "Can't fuse: non-affine users in between the loops\n"); continue; @@ -1247,8 +1234,8 @@ struct GreedyFusion { // Skip if a memref value in one node is used by a non-affine memref // access that lies between 'dstNode' and 'sibNode'. - if (hasNonAffineUsersOnThePath(dstNode->id, sibNode->id, mdg) || - hasNonAffineUsersOnThePath(sibNode->id, dstNode->id, mdg)) + if (hasNonAffineUsersOnPath(dstNode->op, sibNode->op) || + hasNonAffineUsersOnPath(sibNode->op, dstNode->op)) return false; return true; };