@@ -385,35 +385,31 @@ struct ShardingPropagation
385385 shardingOp.printLoopTypesAndIndexingMaps (llvm::dbgs ());
386386 });
387387
388- // 1. propagate in reversed order
389- if (traversal == TraversalOrder::Backward ||
390- traversal == TraversalOrder::BackwardForward) {
391- for (Operation &op : llvm::make_early_inc_range (llvm::reverse (block)))
392- if (failed (visitOp (&op, builder)))
393- return signalPassFailure ();
394- if (traversal == TraversalOrder::BackwardForward) {
395- LLVM_DEBUG (DBGS () << " After backward order propagation:\n "
396- << funcOp << " \n " );
397- LLVM_DEBUG (assert (succeeded (mlir::verify (funcOp))));
388+ auto traverse = [&](auto &&range, OpBuilder &builder,
389+ const char *order) -> bool {
390+ for (Operation &op : range) {
391+ if (failed (visitOp (&op, builder))) {
392+ signalPassFailure ();
393+ return true ;
394+ }
398395 }
399- }
396+ LLVM_DEBUG (DBGS () << " After " << order << " order propagation:\n "
397+ << funcOp << " \n " );
398+ LLVM_DEBUG (assert (succeeded (mlir::verify (funcOp))));
399+ return false ;
400+ };
400401
401- // 2. propagate in original order
402- if (traversal != TraversalOrder::Backward) {
403- for (Operation &op : llvm::make_early_inc_range (block))
404- if (failed (visitOp (&op, builder)))
405- return signalPassFailure ();
406- if (traversal == TraversalOrder::ForwardBackward) {
407- LLVM_DEBUG (DBGS () << " After forward order propagation:\n "
408- << funcOp << " \n " );
409- LLVM_DEBUG (assert (succeeded (mlir::verify (funcOp))));
410- }
411- }
402+ // 1. Propagate in reversed order.
403+ if (traversal == TraversalOrder::Backward ||
404+ traversal == TraversalOrder::BackwardForward)
405+ traverse (llvm::reverse (block), builder, " backward" );
406+
407+ // 2. Propagate in original order.
408+ if (traversal != TraversalOrder::Backward)
409+ traverse (block, builder, " forward" );
412410
413- // 3. propagate in backward order if needed
411+ // 3. Propagate in backward order if needed.
414412 if (traversal == TraversalOrder::ForwardBackward)
415- for (Operation &op : llvm::make_early_inc_range (llvm::reverse (block)))
416- if (failed (visitOp (&op, builder)))
417- return signalPassFailure ();
413+ traverse (llvm::reverse (block), builder, " backward" );
418414 }
419415};
0 commit comments