diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 83ae79ce48266..448141735ba7f 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -843,9 +843,8 @@ namespace { // 3) The iter arguments have no use and the corresponding (operation) results // have no use. // -// These arguments must be defined outside of -// the ForOp region and can just be forwarded after simplifying the op inits, -// yields and returns. +// These arguments must be defined outside of the ForOp region and can just be +// forwarded after simplifying the op inits, yields and returns. // // The implementation uses `inlineBlockBefore` to steal the content of the // original ForOp and avoid cloning. @@ -871,6 +870,7 @@ struct ForOpIterArgsFolder : public OpRewritePattern { newIterArgs.reserve(forOp.getInitArgs().size()); newYieldValues.reserve(numResults); newResultValues.reserve(numResults); + DenseMap, std::pair> initYieldToArg; for (auto [init, arg, result, yielded] : llvm::zip(forOp.getInitArgs(), // iter from outside forOp.getRegionIterArgs(), // iter inside region @@ -884,13 +884,32 @@ struct ForOpIterArgsFolder : public OpRewritePattern { // result has no use. bool forwarded = (arg == yielded) || (init == yielded) || (arg.use_empty() && result.use_empty()); - keepMask.push_back(!forwarded); - canonicalize |= forwarded; if (forwarded) { + canonicalize = true; + keepMask.push_back(false); newBlockTransferArgs.push_back(init); newResultValues.push_back(init); continue; } + + // Check whether an already-kept argument has the same init and yielded + // values; the two are then always equal, so reuse the kept one. + if (auto it = initYieldToArg.find({init, yielded}); + it != initYieldToArg.end()) { + canonicalize = true; + keepMask.push_back(false); + auto [sameArg, sameResult] = it->second; + rewriter.replaceAllUsesWith(arg, sameArg); + rewriter.replaceAllUsesWith(result, sameResult); + // No uses remain after the RAUW above, so any placeholder value works. 
+ newBlockTransferArgs.push_back(init); + newResultValues.push_back(init); + continue; + } + + // First occurrence of this (init, yielded) pair: record it and keep it. + initYieldToArg.insert({{init, yielded}, {arg, result}}); + keepMask.push_back(true); newIterArgs.push_back(init); newYieldValues.push_back(yielded); newBlockTransferArgs.push_back(Value()); // placeholder with null value diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir index 828758df6d31c..c18bd617216f1 100644 --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -821,6 +821,24 @@ func.func @fold_away_iter_and_result_with_no_use(%arg0 : i32, // ----- +// CHECK-LABEL: @replace_duplicate_iter_args +// CHECK-SAME: [[LB:%arg[0-9]]]: index, [[UB:%arg[0-9]]]: index, [[STEP:%arg[0-9]]]: index, [[A:%arg[0-9]]]: index, [[B:%arg[0-9]]]: index +func.func @replace_duplicate_iter_args(%lb: index, %ub: index, %step: index, %a: index, %b: index) -> (index, index, index, index) { + // CHECK-NEXT: [[RES:%.*]]:2 = scf.for {{.*}} iter_args([[K0:%.*]] = [[A]], [[K1:%.*]] = [[B]]) + %0:4 = scf.for %i = %lb to %ub step %step iter_args(%k0 = %a, %k1 = %b, %k2 = %b, %k3 = %a) -> (index, index, index, index) { + // CHECK-NEXT: [[V0:%.*]] = arith.addi [[K0]], [[K1]] + %1 = arith.addi %k0, %k1 : index + // CHECK-NEXT: [[V1:%.*]] = arith.addi [[K1]], [[K0]] + %2 = arith.addi %k2, %k3 : index + // CHECK-NEXT: yield [[V0]], [[V1]] + scf.yield %1, %2, %2, %1 : index, index, index, index + } + // CHECK: return [[RES]]#0, [[RES]]#1, [[RES]]#1, [[RES]]#0 + return %0#0, %0#1, %0#2, %0#3 : index, index, index, index +} + +// ----- + func.func private @do(%arg0: tensor) -> tensor func.func @matmul_on_tensors(%t0: tensor<32x1024xf32>) -> tensor {