Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 27 additions & 43 deletions mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,50 +43,34 @@ static bool overrideBuffer(Operation *op, Value buffer) {
/// propagate the type change and erase old subview ops.
static void replaceUsesAndPropagateType(RewriterBase &rewriter,
Operation *oldOp, Value val) {
SmallVector<Operation *> opsToDelete;
SmallVector<OpOperand *> operandsToReplace;

// Save the operand to replace / delete later (avoid iterator invalidation).
// TODO: can we use an early_inc iterator?
for (OpOperand &use : oldOp->getUses()) {
// Non-subview ops will be replaced by `val`.
auto subviewUse = dyn_cast<memref::SubViewOp>(use.getOwner());
if (!subviewUse) {
operandsToReplace.push_back(&use);
// Iterate with early_inc to erase current user inside the loop.
for (OpOperand &use : llvm::make_early_inc_range(oldOp->getUses())) {
Operation *user = use.getOwner();
if (auto subviewUse = dyn_cast<memref::SubViewOp>(user)) {
// `subview(old_op)` is replaced by a new `subview(val)`.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(subviewUse);
MemRefType newType = memref::SubViewOp::inferRankReducedResultType(
subviewUse.getType().getShape(), cast<MemRefType>(val.getType()),
subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(),
subviewUse.getStaticStrides());
Value newSubview = memref::SubViewOp::create(
rewriter, subviewUse->getLoc(), newType, val,
subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(),
subviewUse.getMixedStrides());

// Ouch recursion ... is this really necessary?
replaceUsesAndPropagateType(rewriter, subviewUse, newSubview);

// Safe to erase.
rewriter.eraseOp(subviewUse);
continue;
}

// `subview(old_op)` is replaced by a new `subview(val)`.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(subviewUse);
MemRefType newType = memref::SubViewOp::inferRankReducedResultType(
subviewUse.getType().getShape(), cast<MemRefType>(val.getType()),
subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(),
subviewUse.getStaticStrides());
Value newSubview = memref::SubViewOp::create(
rewriter, subviewUse->getLoc(), newType, val,
subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(),
subviewUse.getMixedStrides());

// Ouch recursion ... is this really necessary?
replaceUsesAndPropagateType(rewriter, subviewUse, newSubview);

opsToDelete.push_back(use.getOwner());
// Non-subview: replace with new value.
rewriter.startOpModification(user);
use.set(val);
rewriter.finalizeOpModification(user);
}

// Perform late replacement.
// TODO: can we use an early_inc iterator?
for (OpOperand *operand : operandsToReplace) {
Operation *op = operand->getOwner();
rewriter.startOpModification(op);
operand->set(val);
rewriter.finalizeOpModification(op);
}

// Perform late op erasure.
// TODO: can we use an early_inc iterator?
for (Operation *op : opsToDelete)
rewriter.eraseOp(op);
}

// Transformation to do multi-buffering/array expansion to remove dependencies
Expand Down Expand Up @@ -216,8 +200,8 @@ mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp,
offsets, sizes, strides);
LLVM_DEBUG(DBGS() << "--multi-buffered slice: " << subview << "\n");

// 5. Due to the recursive nature of replaceUsesAndPropagateType , we need to
// handle dealloc uses separately..
// 5. Due to the recursive nature of replaceUsesAndPropagateType , we need
// to handle dealloc uses separately..
for (OpOperand &use : llvm::make_early_inc_range(allocOp->getUses())) {
auto deallocOp = dyn_cast<memref::DeallocOp>(use.getOwner());
if (!deallocOp)
Expand Down
Loading