Skip to content

Conversation

@Mogball
Copy link
Contributor

@Mogball Mogball commented Jan 3, 2025

This alters the condition in ForOpIterArgsFolder to always remove iter args when their initial value equals the yielded value, not just when the arg has no use.

This alters the condition in ForOpIterArgsFolder to always remove iter
args when their initial value equals the yielded value, not just when
the arg has no use.
@llvmbot
Copy link
Member

llvmbot commented Jan 3, 2025

@llvm/pr-subscribers-mlir-scf

@llvm/pr-subscribers-mlir

Author: Jeff Niu (Mogball)

Changes

This alters the condition in ForOpIterArgsFolder to always remove iter args when their initial value equals the yielded value, not just when the arg has no use.


Full diff: https://github.com/llvm/llvm-project/pull/121555.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+12-13)
  • (modified) mlir/test/Dialect/SCF/canonicalize.mlir (+18-4)
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index eded1c394f126c..872d34de4495bf 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -872,30 +872,29 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
     newIterArgs.reserve(forOp.getInitArgs().size());
     newYieldValues.reserve(numResults);
     newResultValues.reserve(numResults);
-    for (auto it : llvm::zip(forOp.getInitArgs(),       // iter from outside
-                             forOp.getRegionIterArgs(), // iter inside region
-                             forOp.getResults(),        // op results
-                             forOp.getYieldedValues()   // iter yield
-                             )) {
+    for (auto [init, arg, result, yielded] :
+         llvm::zip(forOp.getInitArgs(),       // iter from outside
+                   forOp.getRegionIterArgs(), // iter inside region
+                   forOp.getResults(),        // op results
+                   forOp.getYieldedValues()   // iter yield
+                   )) {
       // Forwarded is `true` when:
       // 1) The region `iter` argument is yielded.
       // 2) The region `iter` argument has no use, and the corresponding iter
       // operand (input) is yielded.
       // 3) The region `iter` argument has no use, and the corresponding op
       // result has no use.
-      bool forwarded = ((std::get<1>(it) == std::get<3>(it)) ||
-                        (std::get<1>(it).use_empty() &&
-                         (std::get<0>(it) == std::get<3>(it) ||
-                          std::get<2>(it).use_empty())));
+      bool forwarded = (arg == yielded) || (init == yielded) ||
+                       (arg.use_empty() && result.use_empty());
       keepMask.push_back(!forwarded);
       canonicalize |= forwarded;
       if (forwarded) {
-        newBlockTransferArgs.push_back(std::get<0>(it));
-        newResultValues.push_back(std::get<0>(it));
+        newBlockTransferArgs.push_back(init);
+        newResultValues.push_back(init);
         continue;
       }
-      newIterArgs.push_back(std::get<0>(it));
-      newYieldValues.push_back(std::get<3>(it));
+      newIterArgs.push_back(init);
+      newYieldValues.push_back(yielded);
       newBlockTransferArgs.push_back(Value()); // placeholder with null value
       newResultValues.push_back(Value());      // placeholder with null value
     }
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 8c4e7a41ee6bc4..828758df6d31c0 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -408,6 +408,20 @@ func.func @for_yields_4() -> i32 {
 
 // -----
 
+// CHECK-LABEL: @constant_iter_arg
+func.func @constant_iter_arg(%arg0: index, %arg1: index, %arg2: index) {
+  %c0_i32 = arith.constant 0 : i32
+  // CHECK: scf.for %arg3 = %arg0 to %arg1 step %arg2 {
+  %0 = scf.for %i = %arg0 to %arg1 step %arg2 iter_args(%arg3 = %c0_i32) -> i32 {
+    // CHECK-NEXT: "test.use"(%c0_i32)
+    "test.use"(%arg3) : (i32) -> ()
+    scf.yield %c0_i32 : i32
+  }
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @replace_true_if
 func.func @replace_true_if() {
   %true = arith.constant true
@@ -1789,7 +1803,7 @@ module {
 }
 // CHECK-LABEL: @fold_iter_args_not_being_modified_within_scfforall
 //  CHECK-SAME:   (%{{.*}}: index, %[[ARG1:.*]]: tensor<?xf32>, %[[ARG2:.*]]: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
-//       CHECK:    %[[RESULT:.*]] = scf.forall 
+//       CHECK:    %[[RESULT:.*]] = scf.forall
 //  CHECK-SAME:                       shared_outs(%[[ITER_ARG_5:.*]] = %[[ARG2]]) -> (tensor<?xf32>) {
 //       CHECK:      %[[OPERAND0:.*]] = tensor.extract_slice %[[ARG1]]
 //       CHECK:      %[[OPERAND1:.*]] = tensor.extract_slice %[[ITER_ARG_5]]
@@ -1832,7 +1846,7 @@ module {
 }
 // CHECK-LABEL: @fold_iter_args_with_no_use_of_result_scfforall
 //  CHECK-SAME:   (%{{.*}}: index, %[[ARG1:.*]]: tensor<?xf32>, %[[ARG2:.*]]: tensor<?xf32>, %[[ARG3:.*]]: tensor<?xf32>) -> tensor<?xf32> {
-//       CHECK:    %[[RESULT:.*]] = scf.forall 
+//       CHECK:    %[[RESULT:.*]] = scf.forall
 //  CHECK-SAME:                       shared_outs(%[[ITER_ARG_6:.*]] = %[[ARG2]]) -> (tensor<?xf32>) {
 //       CHECK:      %[[OPERAND0:.*]] = tensor.extract_slice %[[ARG1]]
 //       CHECK:      %[[OPERAND1:.*]] = tensor.extract_slice %[[ARG3]]
@@ -1856,7 +1870,7 @@ func.func @index_switch_fold() -> (f32, f32) {
     %y = arith.constant 42.0 : f32
     scf.yield %y : f32
   }
-  
+
   %switch_cst_2 = arith.constant 2: index
   %1 = scf.index_switch %switch_cst_2 -> f32
   case 0 {
@@ -1867,7 +1881,7 @@ func.func @index_switch_fold() -> (f32, f32) {
     %y = arith.constant 42.0 : f32
     scf.yield %y : f32
   }
-  
+
   return %0, %1 : f32, f32
 }
 

@Mogball Mogball merged commit 9d8e634 into main Jan 3, 2025
5 of 7 checks passed
@Mogball Mogball deleted the users/mogball/for_args branch January 3, 2025 19:44
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants