@@ -839,8 +839,7 @@ mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp,
839839namespace {
840840// Fold away ForOp iter arguments when:
841841// 1) The op yields the iter arguments.
842- // 2) The iter arguments have no use and the corresponding outer region
843- // iterators (inputs) are yielded.
842+ // 2) The argument's corresponding outer region iterators (inputs) are yielded.
844843// 3) The iter arguments have no use and the corresponding (operation) results
845844// have no use.
846845//
@@ -872,30 +871,28 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
872871 newIterArgs.reserve (forOp.getInitArgs ().size ());
873872 newYieldValues.reserve (numResults);
874873 newResultValues.reserve (numResults);
875- for (auto it : llvm::zip (forOp.getInitArgs (), // iter from outside
876- forOp.getRegionIterArgs (), // iter inside region
877- forOp.getResults (), // op results
878- forOp.getYieldedValues () // iter yield
879- )) {
874+ for (auto [init, arg, result, yielded] :
875+ llvm::zip (forOp.getInitArgs (), // iter from outside
876+ forOp.getRegionIterArgs (), // iter inside region
877+ forOp.getResults (), // op results
878+ forOp.getYieldedValues () // iter yield
879+ )) {
880880 // Forwarded is `true` when:
881881 // 1) The region `iter` argument is yielded.
882- // 2) The region `iter` argument has no use, and the corresponding iter
883- // operand (input) is yielded.
882+ // 2) The region `iter` argument the corresponding input is yielded.
884883 // 3) The region `iter` argument has no use, and the corresponding op
885884 // result has no use.
886- bool forwarded = ((std::get<1 >(it) == std::get<3 >(it)) ||
887- (std::get<1 >(it).use_empty () &&
888- (std::get<0 >(it) == std::get<3 >(it) ||
889- std::get<2 >(it).use_empty ())));
885+ bool forwarded = (arg == yielded) || (init == yielded) ||
886+ (arg.use_empty () && result.use_empty ());
890887 keepMask.push_back (!forwarded);
891888 canonicalize |= forwarded;
892889 if (forwarded) {
893- newBlockTransferArgs.push_back (std::get< 0 >(it) );
894- newResultValues.push_back (std::get< 0 >(it) );
890+ newBlockTransferArgs.push_back (init );
891+ newResultValues.push_back (init );
895892 continue ;
896893 }
897- newIterArgs.push_back (std::get< 0 >(it) );
898- newYieldValues.push_back (std::get< 3 >(it) );
894+ newIterArgs.push_back (init );
895+ newYieldValues.push_back (yielded );
899896 newBlockTransferArgs.push_back (Value ()); // placeholder with null value
900897 newResultValues.push_back (Value ()); // placeholder with null value
901898 }
0 commit comments