@charithaintc
Contributor

In some cases, the loop bounds (lower bound, upper bound, and step) of an scf.for are defined locally in the parent warp op of the scf.for. The current logic does not yield these loop bounds from the new warp op generated during lowering, so the sunk scf.for ends up using values that do not dominate it.

This PR adds logic to yield the loop bounds by default (treating them like the other operands of scf.for), which fixes this bug.
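
To make the failure mode concrete, here is a minimal standalone sketch of the visibility rule involved. It is a toy model, not MLIR code: `RegionTree` and `sunkLoopCanUse` are hypothetical names, and the check only looks at region nesting, which is a simplification of MLIR's actual dominance rules. It assumes the scenario of the new test case, where the upper bound is the warp op block argument `%arg1`.

```cpp
#include <cassert>
#include <vector>

// Toy model, not the MLIR API: each region records its parent region
// (-1 for the outermost one). A value defined in region D is usable from
// region U only if D is U itself or one of U's ancestors.
struct RegionTree {
  std::vector<int> parent;

  // Registers a new region nested in `parentId` and returns its id.
  int addRegion(int parentId) {
    parent.push_back(parentId);
    return static_cast<int>(parent.size()) - 1;
  }

  // Walks from the using region outwards, looking for the defining region.
  bool isVisible(int defRegion, int useRegion) const {
    for (int r = useRegion; r != -1; r = parent[r])
      if (r == defRegion)
        return true;
    return false;
  }
};

// Hypothetical helper: reports whether the loop, once sunk out of the warp
// op into the function body, may still use its upper bound.
inline bool sunkLoopCanUse(bool boundIsWarpBlockArg) {
  RegionTree tree;
  int funcBody = tree.addRegion(-1);       // %c0, %c1, %bound, warp op results
  int warpBody = tree.addRegion(funcBody); // ^bb0(%arg1) and the original loop
  int defRegion = boundIsWarpBlockArg ? warpBody : funcBody;
  // After the rewrite, the new scf.for sits in the function body.
  return tree.isVisible(defRegion, funcBody);
}
```

In this model, `sunkLoopCanUse(true)` is false: a bound that lives inside the warp region cannot be referenced by the loop after it moves out. Yielding the bound from the new warp op turns it into a result defined in the function body, which corresponds to the `sunkLoopCanUse(false)` case.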

@llvmbot
Member

llvmbot commented Oct 14, 2025

@llvm/pr-subscribers-mlir

Author: Charitha Saumya (charithaintc)

Changes

In some cases, the loop bounds (lower bound, upper bound, and step) of an scf.for are defined locally in the parent warp op of the scf.for. The current logic does not yield these loop bounds from the new warp op generated during lowering, so the sunk scf.for ends up using values that do not dominate it.

This PR adds logic to yield the loop bounds by default (treating them like the other operands of scf.for), which fixes this bug.


Full diff: https://github.com/llvm/llvm-project/pull/163443.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+18-7)
  • (modified) mlir/test/Dialect/Vector/vector-warp-distribute.mlir (+35)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index e95338f7d18be..2ee65dc0f902a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -2038,11 +2038,19 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
     }
 
     // Newly created `WarpOp` will yield values in following order:
-    // 1. All init args of the `ForOp`.
-    // 2. All escaping values.
-    // 3. All non-`ForOp` yielded values.
+    // 1. Loop bounds.
+    // 2. All init args of the `ForOp`.
+    // 3. All escaping values.
+    // 4. All non-`ForOp` yielded values.
     SmallVector<Value> newWarpOpYieldValues;
     SmallVector<Type> newWarpOpDistTypes;
+    newWarpOpYieldValues.insert(
+        newWarpOpYieldValues.end(),
+        {forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep()});
+    newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
+                              {forOp.getLowerBound().getType(),
+                               forOp.getUpperBound().getType(),
+                               forOp.getStep().getType()});
     for (auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) {
       newWarpOpYieldValues.push_back(initArg);
       // Compute the distributed type for this init arg.
@@ -2081,20 +2089,23 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
 
     // Next, we create a new `ForOp` with the init args yielded by the new
     // `WarpOp`.
+    const unsigned initArgsStartIdx = 3; // After loop bounds.
     const unsigned escapingValuesStartIdx =
+        initArgsStartIdx +
         forOp.getInitArgs().size(); // `ForOp` init args are positioned before
                                     // escaping values in the new `WarpOp`.
     SmallVector<Value> newForOpOperands;
-    for (size_t i = 0; i < escapingValuesStartIdx; ++i)
+    for (size_t i = initArgsStartIdx; i < escapingValuesStartIdx; ++i)
       newForOpOperands.push_back(newWarpOp.getResult(i));
 
     // Create a new `ForOp` outside the new `WarpOp` region.
     OpBuilder::InsertionGuard g(rewriter);
     rewriter.setInsertionPointAfter(newWarpOp);
     auto newForOp = scf::ForOp::create(
-        rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
-        forOp.getStep(), newForOpOperands, /*bodyBuilder=*/nullptr,
-        forOp.getUnsignedCmp());
+        rewriter, forOp.getLoc(), /**LowerBound=**/ newWarpOp.getResult(0),
+        /**UpperBound=**/ newWarpOp.getResult(1),
+        /**Step=**/ newWarpOp.getResult(2), newForOpOperands,
+        /*bodyBuilder=*/nullptr, forOp.getUnsignedCmp());
     // Next, we insert a new `WarpOp` (called inner `WarpOp`) inside the
     // newly created `ForOp`. This `WarpOp` will contain all ops that were
     // contained within the original `ForOp` body.
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index bb7639204022f..ab87684dbb01a 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -473,6 +473,41 @@ func.func @warp_scf_for_use_from_above(%arg0: index) {
   return
 }
 
+// -----
+// CHECK-PROP-LABEL:  func.func @warp_scf_for_local_loop_bounds
+// CHECK-PROP:          (%{{.*}}: index, %[[ARG1:[a-zA-Z0-9]+]]: index) {
+// CHECK-PROP:          %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] args(%[[ARG1]] : index) -> (vector<4xf32>) {
+// CHECK-PROP:          ^bb0(%{{.*}}: index):
+// CHECK-PROP:            %[[T2:.*]] = "some_def"() : () -> vector<128xf32>
+// CHECK-PROP:            gpu.yield %[[T2]] : vector<128xf32>
+// CHECK-PROP:          }
+// CHECK-PROP:          %[[FOR:.*]] = scf.for %{{.*}} to %[[ARG1]] step %{{.*}} iter_args(%{{.*}}) -> (vector<4xf32>) {
+// CHECK-PROP:            %[[W2:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32]
+// CHECK-PROP-SAME:          args(%{{.*}} : vector<4xf32>) -> (vector<4xf32>) {
+// CHECK-PROP:            ^bb0(%{{.*}}: vector<128xf32>):
+// CHECK-PROP:              gpu.yield %{{.*}} : vector<128xf32>
+// CHECK-PROP:            }
+// CHECK-PROP:            scf.yield %[[W2]] : vector<4xf32>
+// CHECK-PROP:          }
+// CHECK-PROP:          "some_use"(%[[FOR]]) : (vector<4xf32>) -> ()
+// CHECK-PROP:          return
+func.func @warp_scf_for_local_loop_bounds(%arg0: index, %bound: index) {
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %0 = gpu.warp_execute_on_lane_0(%arg0)[32]
+    args(%bound : index) -> (vector<4xf32>) {
+    ^bb0(%arg1: index):
+    %ini = "some_def"() : () -> (vector<128xf32>)
+    %3 = scf.for %arg3 = %c0 to %arg1 step %c1 iter_args(%arg4 = %ini) -> (vector<128xf32>) {
+      %acc = "some_def"(%arg4) : (vector<128xf32>) -> (vector<128xf32>)
+      scf.yield %acc : vector<128xf32>
+    }
+    gpu.yield %3 : vector<128xf32>
+  }
+  "some_use"(%0) : (vector<4xf32>) -> ()
+  return
+}
+
 // -----
 
 // CHECK-PROP-LABEL:   func @warp_scf_for_swap(

@llvmbot
Member

llvmbot commented Oct 14, 2025

@llvm/pr-subscribers-mlir-vector

Author: Charitha Saumya (charithaintc)


Contributor

@akroviakov akroviakov left a comment

I suppose it uses WarpOpForwardOperand for clean-up of warp result usage. The case is valid, although generally I'd expect scalar usages inside warpOp to be "forwarded" to outer definitions beforehand, similar to how scalar definitions are hoisted.

@charithaintc
Contributor Author

The case is valid, although generally I'd expect scalar usages inside warpOp to be "forwarded" to outer definitions beforehand, similar to how scalar definitions are hoisted.

I get your point, but that is not always true. In flash attention, the loop bounds come from kernel arguments, so they cannot be hoisted like scalar constants, and compilation fails. Even the (much simpler) test case in this PR fails.

@akroviakov
Contributor

Yeah, sure, I did not mean to hoist anything. I meant that scalar values coming from the outside may be referenced directly, not necessarily via the warpOp arguments. Meaning that some pass, similar to hoisting in terms of when it should be called, would clean up things beforehand, simplifying the work for the actual distribution patterns.

@charithaintc
Contributor Author

Yeah, sure, I did not mean to hoist anything. I meant that scalar values coming from the outside may be referenced directly, not necessarily via the warpOp arguments. Meaning that some pass, similar to hoisting in terms of when it should be called, would clean up things beforehand, simplifying the work for the actual distribution patterns.

Agreed. Uniform scalars/vectors are hoisted even before we begin distribution. I guess this is to reduce compile time (no need for repeated application of WarpForward + WarpDead), but unfortunately that does not help us here.

@akroviakov
Contributor

but unfortunately that does not help us here.

So the test fails if the loop op references %bound directly:

func.func @warp_scf_for_local_loop_bounds(%arg0: index, %bound: index) {
...
  %0 = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>) {
    %ini = "some_def"() : () -> (vector<128xf32>)
    %3 = scf.for %arg3 = %c0 to %bound step %c1 iter_args(%arg4 = %ini) -> (vector<128xf32>) {

?

@charithaintc
Contributor Author

but unfortunately that does not help us here.

So the test fails if the loop op references %bound directly:

func.func @warp_scf_for_local_loop_bounds(%arg0: index, %bound: index) {
...
  %0 = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>) {
    %ini = "some_def"() : () -> (vector<128xf32>)
    %3 = scf.for %arg3 = %c0 to %bound step %c1 iter_args(%arg4 = %ini) -> (vector<128xf32>) {

?

No. It fails only if %bound is a warp op argument. I guess it is sensitive to pattern application ordering. With this fix, all the loop op operands are yielded regardless of their nature, so it should always work.
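
As a rough illustration of what "all the loop op operands are yielded" means for indexing, the sketch below models the result ordering described in the patch comments: loop bounds first, then init args, then escaping values, then the remaining yielded values. `WarpResultLayout` and `initArgResultIndices` are hypothetical names for this standalone model; only the constant 3 and the ordering are taken from the diff.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Illustrative model of the new warp op result ordering:
//   [lower bound, upper bound, step | init args | escaping values | others]
// `WarpResultLayout` is a hypothetical name, not a type in MLIR.
struct WarpResultLayout {
  std::size_t numInitArgs;
  std::size_t numEscapingValues;

  // The three loop bounds always occupy result slots 0, 1 and 2.
  static constexpr std::size_t numLoopBounds() { return 3; }

  constexpr std::size_t lowerBoundIdx() const { return 0; }
  constexpr std::size_t upperBoundIdx() const { return 1; }
  constexpr std::size_t stepIdx() const { return 2; }

  // Init args start right after the loop bounds.
  constexpr std::size_t initArgsStartIdx() const { return numLoopBounds(); }
  // Escaping values start right after the init args.
  constexpr std::size_t escapingValuesStartIdx() const {
    return initArgsStartIdx() + numInitArgs;
  }
  // All other yielded values come last.
  constexpr std::size_t otherResultsStartIdx() const {
    return escapingValuesStartIdx() + numEscapingValues;
  }
};

// Mirrors the loop in the patch that collects the new loop's init operands:
// it now starts at initArgsStartIdx instead of 0.
inline std::vector<std::size_t>
initArgResultIndices(const WarpResultLayout &layout) {
  std::vector<std::size_t> indices;
  for (std::size_t i = layout.initArgsStartIdx();
       i < layout.escapingValuesStartIdx(); ++i)
    indices.push_back(i);
  return indices;
}
```

For the test case in this PR (one init arg, no escaping values), this model gives bounds at indices 0 to 2 and the single init arg at index 3. Because the bounds are always taken from the warp op results, the new loop does not depend on where the original bounds were defined.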

@akroviakov
Contributor

Yes, the fix is fine. I was just confused by your reply

but unfortunately that does not help us here.

to my comment

I'd expect scalar usages inside warpOp to be "forwarded" to outer definitions beforehand

I meant that scalar values coming from the outside may be referenced directly, not necessarily via the warpOp arguments. Meaning that some pass, similar to hoisting in terms of when it should be called, would clean up things beforehand

@charithaintc
Contributor Author

Yes, the fix is fine. I was just confused by your reply

but unfortunately that does not help us here.

to my comment

I'd expect scalar usages inside warpOp to be "forwarded" to outer definitions beforehand

I meant that scalar values coming from the outside may be referenced directly, not necessarily via the warpOp arguments. Meaning that some pass, similar to hoisting in terms of when it should be called, would clean up things beforehand

I see. Sorry about the confusion.

Contributor

@Jianhui-Li Jianhui-Li left a comment

LGTM

@dcaballe dcaballe requested a review from Groverkss October 15, 2025 21:33
@dcaballe
Contributor

Not sure why @Groverkss wasn't added to this PR automatically. We may have messed up the CODEOWNERS file again?

@charithaintc
Contributor Author

Not sure why @Groverkss wasn't added to this PR automatically. We may have messed up the CODEOWNERS file again?

Hi @dcaballe, we made several changes to this file over the last few months, and only these reviewers were added by default.

@hanhanW
Contributor

hanhanW commented Oct 16, 2025

Not sure why @Groverkss wasn't added to this PR automatically. We may have messed up the CODEOWNERS file again?

https://docs.github.com/en/repositories/managing-your-repositorys-settings-and-features/customizing-your-repository/about-code-owners#example-of-a-codeowners-file

# Order is important; the last matching pattern takes the most
# precedence. When someone opens a pull request that only
# modifies JS files, only @js-owner and not the global
# owner(s) will be requested for a review.
*.js    @js-owner #This is an inline comment.

I think it is because @Groverkss is not on the /mlir/lib/Dialect/Vector/Transforms/* list, so the "global" owners do not get the request. It is a known issue to me, and the solution is to add your handle to all the paths that you care about.

# Vector Dialect in MLIR.
/mlir/**/*AMX* @aartbik @dcaballe
/mlir/**/*Neon* @banach-space @dcaballe @nicolasvasilache
/mlir/**/*SME* @banach-space @dcaballe @nicolasvasilache
/mlir/**/*SVE* @banach-space @dcaballe @nicolasvasilache
/mlir/**/*VectorInterfaces* @dcaballe @nicolasvasilache
/mlir/**/*VectorToSCF* @banach-space @dcaballe @matthias-springer @nicolasvasilache
/mlir/**/*VectorToLLVM* @banach-space @dcaballe @nicolasvasilache
/mlir/**/*X86Vector* @aartbik @dcaballe @nicolasvasilache
/mlir/include/mlir/Dialect/Vector @banach-space @dcaballe @nicolasvasilache @Groverkss
/mlir/include/mlir/Dialect/Vector/IR @kuhar
/mlir/lib/Dialect/Vector @banach-space @dcaballe @nicolasvasilache @Groverkss
/mlir/lib/Dialect/Vector/Transforms/* @banach-space @dcaballe @hanhanW @nicolasvasilache
/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp @banach-space @dcaballe @MaheshRavishankar @nicolasvasilache
/mlir/**/*EmulateNarrowType* @dcaballe @hanhanW
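
The "last matching pattern wins" behaviour can be sketched in a few lines. This is a simplified standalone model, not GitHub's implementation: `OwnerRule`, `matches`, `resolveOwners`, `vectorRules` and `contains` are hypothetical names, and the matcher is assumed to handle only the two pattern shapes relevant here (a plain directory prefix and a trailing `/*` for direct children), not `**` or name globs such as `*AMX*`.

```cpp
#include <cassert>
#include <string>
#include <vector>

struct OwnerRule {
  std::string pattern;
  std::vector<std::string> owners;
};

// Simplified matcher (an assumption for illustration): supports a directory
// prefix such as "/a/b" and a direct-children glob such as "/a/b/*".
inline bool matches(const std::string &pattern, const std::string &path) {
  if (pattern.size() >= 2 &&
      pattern.compare(pattern.size() - 2, 2, "/*") == 0) {
    // Drop the '*' but keep the trailing '/'.
    const std::string dir = pattern.substr(0, pattern.size() - 1);
    return path.size() > dir.size() && path.compare(0, dir.size(), dir) == 0 &&
           path.find('/', dir.size()) == std::string::npos;
  }
  const std::string dir = pattern + "/";
  return path == pattern || path.compare(0, dir.size(), dir) == 0;
}

// The last matching rule wins; earlier matches are replaced, not merged.
inline std::vector<std::string>
resolveOwners(const std::vector<OwnerRule> &rules, const std::string &path) {
  std::vector<std::string> owners;
  for (const OwnerRule &rule : rules)
    if (matches(rule.pattern, path))
      owners = rule.owners;
  return owners;
}

// The two rules from the list above that apply to VectorDistribute.cpp.
inline std::vector<OwnerRule> vectorRules() {
  return {
      {"/mlir/lib/Dialect/Vector",
       {"@banach-space", "@dcaballe", "@nicolasvasilache", "@Groverkss"}},
      {"/mlir/lib/Dialect/Vector/Transforms/*",
       {"@banach-space", "@dcaballe", "@hanhanW", "@nicolasvasilache"}},
  };
}

inline bool contains(const std::vector<std::string> &owners,
                     const std::string &handle) {
  for (const std::string &owner : owners)
    if (owner == handle)
      return true;
  return false;
}
```

Under these assumptions, resolving `/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp` matches both rules, and the later `Transforms/*` rule replaces the earlier one, so @Groverkss is not in the resulting owner list. This is consistent with the explanation above.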

@charithaintc
Contributor Author

Not sure why @Groverkss wasn't added to this PR automatically. We may have messed up the CODEOWNERS file again?

Hi @Groverkss, do you have any comments/concerns regarding this change? Otherwise, I would like to merge this PR. Thanks!

@charithaintc charithaintc merged commit f7a5264 into llvm:main Oct 17, 2025
10 checks passed