Skip to content

Commit 73e1dca

Browse files
committed
code deduplication
1 parent 6af6d65 commit 73e1dca

File tree

1 file changed

+22
-26
lines changed

1 file changed

+22
-26
lines changed

mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp

Lines changed: 22 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -385,35 +385,31 @@ struct ShardingPropagation
385385
shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());
386386
});
387387

388-
// 1. propagate in reversed order
389-
if (traversal == TraversalOrder::Backward ||
390-
traversal == TraversalOrder::BackwardForward) {
391-
for (Operation &op : llvm::make_early_inc_range(llvm::reverse(block)))
392-
if (failed(visitOp(&op, builder)))
393-
return signalPassFailure();
394-
if (traversal == TraversalOrder::BackwardForward) {
395-
LLVM_DEBUG(DBGS() << "After backward order propagation:\n"
396-
<< funcOp << "\n");
397-
LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
388+
auto traverse = [&](auto &&range, OpBuilder &builder,
389+
const char *order) -> bool {
390+
for (Operation &op : range) {
391+
if (failed(visitOp(&op, builder))) {
392+
signalPassFailure();
393+
return true;
394+
}
398395
}
399-
}
396+
LLVM_DEBUG(DBGS() << "After " << order << " order propagation:\n"
397+
<< funcOp << "\n");
398+
LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
399+
return false;
400+
};
400401

401-
// 2. propagate in original order
402-
if (traversal != TraversalOrder::Backward) {
403-
for (Operation &op : llvm::make_early_inc_range(block))
404-
if (failed(visitOp(&op, builder)))
405-
return signalPassFailure();
406-
if (traversal == TraversalOrder::ForwardBackward) {
407-
LLVM_DEBUG(DBGS() << "After forward order propagation:\n"
408-
<< funcOp << "\n");
409-
LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
410-
}
411-
}
402+
// 1. Propagate in reversed order.
403+
if (traversal == TraversalOrder::Backward ||
404+
traversal == TraversalOrder::BackwardForward)
405+
traverse(llvm::reverse(block), builder, "backward");
406+
407+
// 2. Propagate in original order.
408+
if (traversal != TraversalOrder::Backward)
409+
traverse(block, builder, "forward");
412410

413-
// 3. propagate in backward order if needed
411+
// 3. Propagate in backward order if needed.
414412
if (traversal == TraversalOrder::ForwardBackward)
415-
for (Operation &op : llvm::make_early_inc_range(llvm::reverse(block)))
416-
if (failed(visitOp(&op, builder)))
417-
return signalPassFailure();
413+
traverse(llvm::reverse(block), builder, "backward");
418414
}
419415
};

0 commit comments

Comments
 (0)