diff --git a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h index 3dea99cd6b3e5..1c88e7f8a0a8c 100644 --- a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h +++ b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h @@ -39,6 +39,9 @@ struct LoopNestStateCollector { SmallVector forOps; SmallVector loadOpInsts; SmallVector storeOpInsts; + // Collection of possible memory side-effecting ops, other than the ops + // already accounted for like load/store/alloc/free. + SmallVector memoryEffectOpInsts; bool hasNonAffineRegionOp = false; // Collects load and store operations, and whether or not a region holding op @@ -65,7 +68,8 @@ struct MemRefDependenceGraph { SmallVector loads; // List of store op insts. SmallVector stores; - + // List of memory effect op insts other than the ones already accounted for. + SmallVector memEffects; Node(unsigned id, Operation *op) : id(id), op(op) {} // Returns the load op count for 'memref'. diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp index 194ee9115e3d7..831c04c5ce2d8 100644 --- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp @@ -50,6 +50,27 @@ void LoopNestStateCollector::collect(Operation *opToWalk) { loadOpInsts.push_back(op); else if (isa(op)) storeOpInsts.push_back(op); + else { + // Checking to see if op has alloc or dealloc effect. + auto hasNoAllocOrFreeEffect = [op]() -> bool { + if (auto effectInterface = dyn_cast(op)) { + SmallVector effects; + effectInterface.getEffects(effects); + // if any of the effects of this op are alloc or free type then + // return false. + return !llvm::any_of( + effects, [](const MemoryEffects::EffectInstance &it) -> bool { + return isa( + it.getEffect()); + }); + } + return true; + }; + // TODO: Add handling of alloc/dealloc ops in MDG construction. + // Current version filters them out. 
+ if (!mlir::isMemoryEffectFree(op) && hasNoAllocOrFreeEffect()) + memoryEffectOpInsts.push_back(op); + } }); } @@ -113,6 +134,8 @@ bool MemRefDependenceGraph::init() { // Map from a memref to the set of ids of the nodes that have ops accessing // the memref. DenseMap> memrefAccesses; + // Ids of memory-effect nodes in traversal order; 'node' is copied on insert. + SmallVector<unsigned> memoryEffectNodeIds; DenseMap forToNodeMap; for (Operation &op : block) { @@ -129,13 +152,24 @@ bool MemRefDependenceGraph::init() { for (auto *opInst : collector.loadOpInsts) { node.loads.push_back(opInst); auto memref = cast(opInst).getMemRef(); + if (llvm::is_contained(memrefAccesses, memref)) { + for (unsigned effectNodeId : memoryEffectNodeIds) + memrefAccesses[memref].insert(effectNodeId); + } memrefAccesses[memref].insert(node.id); } for (auto *opInst : collector.storeOpInsts) { node.stores.push_back(opInst); auto memref = cast(opInst).getMemRef(); + if (llvm::is_contained(memrefAccesses, memref)) { + for (unsigned effectNodeId : memoryEffectNodeIds) + memrefAccesses[memref].insert(effectNodeId); + } memrefAccesses[memref].insert(node.id); } + node.memEffects = collector.memoryEffectOpInsts; + if (!node.memEffects.empty()) + memoryEffectNodeIds.push_back(node.id); forToNodeMap[&op] = node.id; nodes.insert({node.id, node}); } else if (dyn_cast(op)) { @@ -214,17 +248,23 @@ bool MemRefDependenceGraph::init() { } // Walk memref access lists and add graph edges between dependent nodes. + // Add edge between nodes accessing memrefs where at least one of them has a + // write op or has a potential side-effect op. 
for (auto &memrefAndList : memrefAccesses) { unsigned n = memrefAndList.second.size(); for (unsigned i = 0; i < n; ++i) { unsigned srcId = memrefAndList.second[i]; - bool srcHasStore = - getNode(srcId)->getStoreOpCount(memrefAndList.first) > 0; + Node *srcNode = getNode(srcId); + bool srcMayHaveMemEffOrStore = + !srcNode->memEffects.empty() || + srcNode->getStoreOpCount(memrefAndList.first) > 0; for (unsigned j = i + 1; j < n; ++j) { unsigned dstId = memrefAndList.second[j]; - bool dstHasStore = - getNode(dstId)->getStoreOpCount(memrefAndList.first) > 0; - if (srcHasStore || dstHasStore) + Node *dstNode = getNode(dstId); + bool dstMayHaveMemEffOrStore = + !dstNode->memEffects.empty() || + dstNode->getStoreOpCount(memrefAndList.first) > 0; + if (srcMayHaveMemEffOrStore || dstMayHaveMemEffOrStore) addEdge(srcId, dstId, memrefAndList.first); } } diff --git a/mlir/test/Dialect/Affine/loop-fusion-4.mlir b/mlir/test/Dialect/Affine/loop-fusion-4.mlir index 3fc31ad0d77b8..bde22c6e43d14 100644 --- a/mlir/test/Dialect/Affine/loop-fusion-4.mlir +++ b/mlir/test/Dialect/Affine/loop-fusion-4.mlir @@ -1,5 +1,6 @@ // RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{mode=producer}))' -split-input-file | FileCheck %s --check-prefix=PRODUCER-CONSUMER // RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{fusion-maximal mode=sibling}))' -split-input-file | FileCheck %s --check-prefix=SIBLING-MAXIMAL +// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{fusion-maximal}))' -split-input-file | FileCheck %s --check-prefix=FUSION-MAXIMAL // Part I of fusion tests in mlir/test/Transforms/loop-fusion.mlir. 
// Part II of fusion tests in mlir/test/Transforms/loop-fusion-2.mlir @@ -226,3 +227,29 @@ func.func @fuse_higher_dim_nest_into_lower_dim_nest() { // PRODUCER-CONSUMER: return return } + +// ----- + +// FUSION-MAXIMAL-LABEL: func @should_not_fuse_loops_with_side_effects +func.func @should_not_fuse_loops_with_side_effects() { + %arg0=memref.alloc(): memref<10xf32> + affine.for %i = 0 to 10 { + %0 = affine.load %arg0[%i] : memref<10xf32> + "bar"():()->() + } + affine.for %j = 0 to 10 { + %0 = affine.load %arg0[%j] : memref<10xf32> + "foo"():()->() + } + // Should not fuse loops with side-effects. + // FUSION-MAXIMAL: affine.for %{{.*}} = 0 to 10 + // FUSION-MAXIMAL-NEXT: affine.load + // FUSION-MAXIMAL-NEXT: "bar" + // FUSION-MAXIMAL-NEXT: } + // FUSION-MAXIMAL-NEXT: affine.for %{{.*}} = 0 to 10 + // FUSION-MAXIMAL-NEXT: affine.load + // FUSION-MAXIMAL-NEXT: "foo" + // FUSION-MAXIMAL-NEXT: } + // FUSION-MAXIMAL: return + return + } \ No newline at end of file