6 changes: 5 additions & 1 deletion mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
@@ -39,6 +39,9 @@ struct LoopNestStateCollector {
SmallVector<AffineForOp, 4> forOps;
SmallVector<Operation *, 4> loadOpInsts;
SmallVector<Operation *, 4> storeOpInsts;
// Collection of possible memory side-effecting ops, other than the ops
// already accounted for like load/store/alloc/free.
SmallVector<Operation *, 4> memoryEffectOpInsts;
Author

Modified the LoopNestStateCollector structure to hold memoryEffectOps.

A code comment explaining what memory effect ops are is missing here. Loads and stores also have memory effects. Please add a doc comment for this variable.

Author

@SwapnilGhanshyala SwapnilGhanshyala Jan 30, 2024

Done.

bool hasNonAffineRegionOp = false;

// Collects load and store operations, and whether or not a region holding op
@@ -65,7 +68,8 @@ struct MemRefDependenceGraph {
SmallVector<Operation *, 4> loads;
// List of store op insts.
SmallVector<Operation *, 4> stores;

// List of memory effect op insts other than the ones already accounted for.
SmallVector<Operation *, 4> memEffects;
Node(unsigned id, Operation *op) : id(id), op(op) {}

// Returns the load op count for 'memref'.
50 changes: 45 additions & 5 deletions mlir/lib/Dialect/Affine/Analysis/Utils.cpp
@@ -50,6 +50,27 @@ void LoopNestStateCollector::collect(Operation *opToWalk) {
loadOpInsts.push_back(op);
else if (isa<AffineWriteOpInterface>(op))
storeOpInsts.push_back(op);
Author

This filters out ops with alloc and free effects; the reason is given in #13 (comment).

else {
// Returns true if op has no Allocate or Free memory effect.
auto hasNoAllocOrFreeEffect = [op]() -> bool {
if (auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op)) {
SmallVector<MemoryEffects::EffectInstance, 1> effects;
effectInterface.getEffects(effects);
// If any effect of this op is an Allocate or Free effect, return
// false.
return !llvm::any_of(
effects, [](const MemoryEffects::EffectInstance &it) -> bool {
return isa<MemoryEffects::Allocate, MemoryEffects::Free>(
it.getEffect());
});
}
return true;
};
// TODO: Add handling of alloc/dealloc ops in MDG construction.
// Current version filters them out.
if (!mlir::isMemoryEffectFree(op) && hasNoAllocOrFreeEffect())
memoryEffectOpInsts.push_back(op);
}
});
}

@@ -113,6 +134,8 @@ bool MemRefDependenceGraph::init() {
// Map from a memref to the set of ids of the nodes that have ops accessing
// the memref.
DenseMap<Value, SetVector<unsigned>> memrefAccesses;
// Vector of nodes with possible memoryEffects in order of traversal.
SmallVector<Node *> memoryEffectNodes;
Author

This vector holds the for-nodes with potential memory effects, in order of traversal.

affine.for nodes?

Author

yes


DenseMap<Operation *, unsigned> forToNodeMap;
for (Operation &op : block) {
@@ -129,13 +152,24 @@ bool MemRefDependenceGraph::init() {
for (auto *opInst : collector.loadOpInsts) {
node.loads.push_back(opInst);
auto memref = cast<AffineReadOpInterface>(opInst).getMemRef();
if (llvm::is_contained(memrefAccesses, memref)) {
for (auto *n : memoryEffectNodes)
memrefAccesses[memref].insert(n->id);
}
memrefAccesses[memref].insert(node.id);
}
for (auto *opInst : collector.storeOpInsts) {
node.stores.push_back(opInst);
auto memref = cast<AffineWriteOpInterface>(opInst).getMemRef();
if (llvm::is_contained(memrefAccesses, memref)) {
for (auto *n : memoryEffectNodes)
memrefAccesses[memref].insert(n->id);
}
memrefAccesses[memref].insert(node.id);
}
node.memEffects = collector.memoryEffectOpInsts;
Author

If the current for-node exhibits memory effects, it is appended to the sequential list of 'for-nodes' that have memory effects.

You are combining ops with any kind of memory effect here, but further down you only need the ones with write effects. See below.

Can you respond to this comment?

Author

Currently, we treat side-effecting ops as unknown entities that could potentially be either reads or writes. If we considered only side effects with a write effect, we might miss potential WAR (write-after-read) dependencies in the MDG.

See comment response #13 (comment)

if (!node.memEffects.empty())
memoryEffectNodes.push_back(&node);
forToNodeMap[&op] = node.id;
nodes.insert({node.id, node});
} else if (dyn_cast<AffineReadOpInterface>(op)) {
@@ -214,17 +248,23 @@ bool MemRefDependenceGraph::init() {
}

// Walk memref access lists and add graph edges between dependent nodes.
// Add an edge between nodes accessing the same memref when at least one of
// them has a write op or a potential side-effect op.
for (auto &memrefAndList : memrefAccesses) {
unsigned n = memrefAndList.second.size();
for (unsigned i = 0; i < n; ++i) {
unsigned srcId = memrefAndList.second[i];
- bool srcHasStore =
- getNode(srcId)->getStoreOpCount(memrefAndList.first) > 0;
Node *srcNode = getNode(srcId);
Author

An edge is introduced between the srcNode and dstNode if either has a memory effect or a store op.
The current algorithm tracks memoryEffectOps which are within a loop body. Top-level ops with memoryEffect are handled separately for fusion validity checks.

Please add code comments where suitable instead of GitHub comments.

bool srcMayHaveMemEffOrStore =
!srcNode->memEffects.empty() ||
srcNode->getStoreOpCount(memrefAndList.first) > 0;
for (unsigned j = i + 1; j < n; ++j) {
unsigned dstId = memrefAndList.second[j];
- bool dstHasStore =
- getNode(dstId)->getStoreOpCount(memrefAndList.first) > 0;
- if (srcHasStore || dstHasStore)
Node *dstNode = getNode(dstId);
bool dstMayHaveMemEffOrStore =
!dstNode->memEffects.empty() ||
dstNode->getStoreOpCount(memrefAndList.first) > 0;
if (srcMayHaveMemEffOrStore || dstMayHaveMemEffOrStore)
Author

@SwapnilGhanshyala SwapnilGhanshyala Jan 30, 2024

The check on the destination cannot be just for writes. Otherwise, it would miss cases where srcNode has a load op and dstNode has both a load and a memoryEffectOp (which has unknown side effects).

If we removed the !dstNode->memEffects.empty() check, the sibling loops below would be fused, although they should not be:

func.func @should_not_fuse_loops_with_side_effects(%arg0: memref<10xf32>) {
  affine.for %i = 0 to 10 {
    %0 = affine.load %arg0[%i] : memref<10xf32>
  }
  affine.for %j = 0 to 10 {
    %0 = affine.load %arg0[%j] : memref<10xf32>
    "foo"() : () -> ()
  }
  return
}
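
To make the argument above concrete, here is a minimal standalone sketch of the pairwise edge rule, written against the C++ standard library only. It is a toy model, not the MLIR implementation: `ToyNode`, `Edge`, `buildEdges`, and the `considerMemEffects` flag are hypothetical names introduced for illustration, and the per-memref access list is reduced to a plain vector.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical stand-in for an MDG node; this is NOT the real
// MemRefDependenceGraph::Node.
struct ToyNode {
  unsigned id;
  bool hasStore;     // The node stores to the memref under consideration.
  bool hasMemEffect; // The node contains an op with unknown memory effects.
};

using Edge = std::pair<unsigned, unsigned>;

// For every pair (i < j) of nodes that access the same memref, add an edge
// when at least one of the two nodes may write. With considerMemEffects set
// to false, this models the store-only rule that the patch replaces.
std::vector<Edge> buildEdges(const std::vector<ToyNode> &accesses,
                             bool considerMemEffects) {
  auto mayWrite = [considerMemEffects](const ToyNode &n) {
    return n.hasStore || (considerMemEffects && n.hasMemEffect);
  };
  std::vector<Edge> edges;
  for (std::size_t i = 0; i < accesses.size(); ++i)
    for (std::size_t j = i + 1; j < accesses.size(); ++j)
      if (mayWrite(accesses[i]) || mayWrite(accesses[j]))
        edges.emplace_back(accesses[i].id, accesses[j].id);
  return edges;
}
```

Modelling the two loops above as `{0, false, false}` (load only) and `{1, false, true}` (load plus "foo"), the store-only rule yields no edge between them, whereas the rule that also considers memory effects yields the single edge (0, 1). Under this toy model, that edge is what records the dependence between the two loops.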

Author

Also renamed the variables to be more descriptive, e.g. dstMayHaveMemEffOrStore.

addEdge(srcId, dstId, memrefAndList.first);
}
}
27 changes: 27 additions & 0 deletions mlir/test/Dialect/Affine/loop-fusion-4.mlir
@@ -1,5 +1,6 @@
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{mode=producer}))' -split-input-file | FileCheck %s --check-prefix=PRODUCER-CONSUMER
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{fusion-maximal mode=sibling}))' -split-input-file | FileCheck %s --check-prefix=SIBLING-MAXIMAL
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{fusion-maximal}))' -split-input-file | FileCheck %s --check-prefix=FUSION-MAXIMAL

// Part I of fusion tests in mlir/test/Transforms/loop-fusion.mlir.
// Part II of fusion tests in mlir/test/Transforms/loop-fusion-2.mlir
@@ -226,3 +227,29 @@ func.func @fuse_higher_dim_nest_into_lower_dim_nest() {
// PRODUCER-CONSUMER: return
return
}

Author

A small test demonstrating how side effects are handled when fusing loops with fusion-maximal.

// -----

// FUSION-MAXIMAL-LABEL: func @should_not_fuse_loops_with_side_effects
func.func @should_not_fuse_loops_with_side_effects() {
%arg0=memref.alloc(): memref<10xf32>
affine.for %i = 0 to 10 {
%0 = affine.load %arg0[%i] : memref<10xf32>
"bar"():()->()
}
affine.for %j = 0 to 10 {
%0 = affine.load %arg0[%j] : memref<10xf32>
"foo"():()->()
}
// Should not fuse loops with side-effects.
// FUSION-MAXIMAL: affine.for %{{.*}} = 0 to 10
// FUSION-MAXIMAL-NEXT: affine.load
// FUSION-MAXIMAL-NEXT: "bar"
// FUSION-MAXIMAL-NEXT: }
// FUSION-MAXIMAL-NEXT: affine.for %{{.*}} = 0 to 10
// FUSION-MAXIMAL-NEXT: affine.load
// FUSION-MAXIMAL-NEXT: "foo"
// FUSION-MAXIMAL-NEXT: }
// FUSION-MAXIMAL: return
return
}