[mlir][vector] Add support for yielding loop bounds in scf.for distribution. #163443
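@llvm/pr-subscribers-mlir
Author: Charitha Saumya (charithaintc)

Changes

In some cases, loop bounds (lower, upper and step) of `scf.for` can come locally from the parent warp op of the `scf.for`. The current logic does not yield the loop bounds in the new warp op generated during lowering, causing the sunk `scf.for` to have a non-dominating use.

In this PR, we have added logic to yield loop bounds by default (treating them like the other operands of `scf.for`), which fixes this bug.

Full diff: https://github.com/llvm/llvm-project/pull/163443.diff

2 Files Affected: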
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index e95338f7d18be..2ee65dc0f902a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -2038,11 +2038,19 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
}
// Newly created `WarpOp` will yield values in following order:
- // 1. All init args of the `ForOp`.
- // 2. All escaping values.
- // 3. All non-`ForOp` yielded values.
+ // 1. Loop bounds.
+ // 2. All init args of the `ForOp`.
+ // 3. All escaping values.
+ // 4. All non-`ForOp` yielded values.
SmallVector<Value> newWarpOpYieldValues;
SmallVector<Type> newWarpOpDistTypes;
+ newWarpOpYieldValues.insert(
+ newWarpOpYieldValues.end(),
+ {forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep()});
+ newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
+ {forOp.getLowerBound().getType(),
+ forOp.getUpperBound().getType(),
+ forOp.getStep().getType()});
for (auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) {
newWarpOpYieldValues.push_back(initArg);
// Compute the distributed type for this init arg.
@@ -2081,20 +2089,23 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
// Next, we create a new `ForOp` with the init args yielded by the new
// `WarpOp`.
+ const unsigned initArgsStartIdx = 3; // After loop bounds.
const unsigned escapingValuesStartIdx =
+ initArgsStartIdx +
forOp.getInitArgs().size(); // `ForOp` init args are positioned before
// escaping values in the new `WarpOp`.
SmallVector<Value> newForOpOperands;
- for (size_t i = 0; i < escapingValuesStartIdx; ++i)
+ for (size_t i = initArgsStartIdx; i < escapingValuesStartIdx; ++i)
newForOpOperands.push_back(newWarpOp.getResult(i));
// Create a new `ForOp` outside the new `WarpOp` region.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPointAfter(newWarpOp);
auto newForOp = scf::ForOp::create(
- rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
- forOp.getStep(), newForOpOperands, /*bodyBuilder=*/nullptr,
- forOp.getUnsignedCmp());
+ rewriter, forOp.getLoc(), /**LowerBound=**/ newWarpOp.getResult(0),
+ /**UpperBound=**/ newWarpOp.getResult(1),
+ /**Step=**/ newWarpOp.getResult(2), newForOpOperands,
+ /*bodyBuilder=*/nullptr, forOp.getUnsignedCmp());
// Next, we insert a new `WarpOp` (called inner `WarpOp`) inside the
// newly created `ForOp`. This `WarpOp` will contain all ops that were
// contained within the original `ForOp` body.
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index bb7639204022f..ab87684dbb01a 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -473,6 +473,41 @@ func.func @warp_scf_for_use_from_above(%arg0: index) {
return
}
+// -----
+// CHECK-PROP-LABEL: func.func @warp_scf_for_local_loop_bounds
+// CHECK-PROP: (%{{.*}}: index, %[[ARG1:[a-zA-Z0-9]+]]: index) {
+// CHECK-PROP: %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] args(%[[ARG1]] : index) -> (vector<4xf32>) {
+// CHECK-PROP: ^bb0(%{{.*}}: index):
+// CHECK-PROP: %[[T2:.*]] = "some_def"() : () -> vector<128xf32>
+// CHECK-PROP: gpu.yield %[[T2]] : vector<128xf32>
+// CHECK-PROP: }
+// CHECK-PROP: %[[FOR:.*]] = scf.for %{{.*}} to %[[ARG1]] step %{{.*}} iter_args(%{{.*}}) -> (vector<4xf32>) {
+// CHECK-PROP: %[[W2:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32]
+// CHECK-PROP-SAME: args(%{{.*}} : vector<4xf32>) -> (vector<4xf32>) {
+// CHECK-PROP: ^bb0(%{{.*}}: vector<128xf32>):
+// CHECK-PROP: gpu.yield %{{.*}} : vector<128xf32>
+// CHECK-PROP: }
+// CHECK-PROP: scf.yield %[[W2]] : vector<4xf32>
+// CHECK-PROP: }
+// CHECK-PROP: "some_use"(%[[FOR]]) : (vector<4xf32>) -> ()
+// CHECK-PROP: return
+func.func @warp_scf_for_local_loop_bounds(%arg0: index, %bound: index) {
+ %c1 = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %0 = gpu.warp_execute_on_lane_0(%arg0)[32]
+ args(%bound : index) -> (vector<4xf32>) {
+ ^bb0(%arg1: index):
+ %ini = "some_def"() : () -> (vector<128xf32>)
+ %3 = scf.for %arg3 = %c0 to %arg1 step %c1 iter_args(%arg4 = %ini) -> (vector<128xf32>) {
+ %acc = "some_def"(%arg4) : (vector<128xf32>) -> (vector<128xf32>)
+ scf.yield %acc : vector<128xf32>
+ }
+ gpu.yield %3 : vector<128xf32>
+ }
+ "some_use"(%0) : (vector<4xf32>) -> ()
+ return
+}
+
// -----
// CHECK-PROP-LABEL: func @warp_scf_for_swap(
|
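The index bookkeeping in the diff above can be sketched outside of MLIR. The following is a minimal toy model, not the code from `VectorDistribute.cpp`: `YieldLayout`, `buildYieldLayout` and `newForOpOperands` are hypothetical names, and plain strings stand in for `mlir::Value` results. It only illustrates the yield order (loop bounds, init args, escaping values, other yields) and why the new `ForOp` operands start at index 3.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Toy stand-in for the result list of the new WarpOp. In the real pattern the
// entries are mlir::Value results; plain names are used here for illustration.
struct YieldLayout {
  std::vector<std::string> values;        // Yielded values, in yield order.
  std::size_t initArgsStartIdx = 0;       // Index of the first ForOp init arg.
  std::size_t escapingValuesStartIdx = 0; // Index of the first escaping value.
};

// Hypothetical helper mirroring the order documented in the diff:
// loop bounds, init args, escaping values, then non-ForOp yielded values.
YieldLayout buildYieldLayout(const std::string &lowerBound,
                             const std::string &upperBound,
                             const std::string &step,
                             const std::vector<std::string> &initArgs,
                             const std::vector<std::string> &escapingValues,
                             const std::vector<std::string> &otherYields) {
  YieldLayout layout;
  layout.values = {lowerBound, upperBound, step};
  layout.initArgsStartIdx = layout.values.size(); // Always 3.
  layout.values.insert(layout.values.end(), initArgs.begin(), initArgs.end());
  layout.escapingValuesStartIdx = layout.values.size();
  layout.values.insert(layout.values.end(), escapingValues.begin(),
                       escapingValues.end());
  layout.values.insert(layout.values.end(), otherYields.begin(),
                       otherYields.end());
  return layout;
}

// Results in [initArgsStartIdx, escapingValuesStartIdx) seed the new ForOp;
// results 0, 1 and 2 become its lower bound, upper bound and step.
std::vector<std::string> newForOpOperands(const YieldLayout &layout) {
  assert(layout.initArgsStartIdx <= layout.escapingValuesStartIdx);
  assert(layout.escapingValuesStartIdx <= layout.values.size());
  auto first = layout.values.begin() +
               static_cast<std::ptrdiff_t>(layout.initArgsStartIdx);
  auto last = layout.values.begin() +
              static_cast<std::ptrdiff_t>(layout.escapingValuesStartIdx);
  return std::vector<std::string>(first, last);
}
```

With the values from the new test case (`%c0`, `%arg1`, `%c1` as bounds and `%ini` as the only init arg), the model places `%ini` at index 3, which matches the `initArgsStartIdx = 3` constant in the patch.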
I suppose it uses `WarpOpForwardOperand` for clean-up of warp result usage. The case is valid, although generally I'd expect scalar usages inside warpOp to be "forwarded" to outer definitions beforehand, similar to how scalar definitions are hoisted.
I get your point, but that is not always true. In flash attention, the loop bounds come from kernel args, so they cannot be hoisted like scalar constants, and compilation fails. Even the much simpler test case in this PR fails. |
Yeah, sure, I did not mean to hoist anything. I meant that scalar values coming from the outside may be referenced directly, not necessarily via the warpOp arguments. Meaning that some pass, similar to hoisting in terms of when it should be called, would clean up things beforehand, simplifying the work for the actual distribution patterns. |
Agreed. Uniform scalars/vectors are hoisted before we even begin distribution. I guess this is to reduce compile time (no need for repeated application of WarpForward + WarpDead), but unfortunately that does not help us here. |
So the test fails if the loop op references
? |
No. It fails only if `%bound` is a warp op argument. I guess it is sensitive to pattern application ordering. But with this fix, all the loop op arguments are yielded regardless of their nature, so it should always work. |
Yes, the fix is fine. I was just confused by your reply
to my comment
|
I see. Sorry about the confusion. |
LGTM
Not sure why @Groverkss wasn't added to this PR automatically. We may have messed up the CODEOWNERS file again? |
Hi @dcaballe, We made several changes to this file over the last few months and only these reviewers were added by default. |
I think it is because @Groverkss is not on the llvm-project/.github/CODEOWNERS Lines 79 to 93 in 35cd291
|
Hi @Groverkss, Do you have any comments/concerns regarding this change? Otherwise I would like to merge this PR. Thanks! |
In some cases, loop bounds (lower, upper and step) of `scf.for` can come locally from the parent warp op of the `scf.for`. The current logic does not yield the loop bounds in the new warp op generated during lowering, causing the sunk `scf.for` to have a non-dominating use.

In this PR, we have added logic to yield loop bounds by default (treating them like the other operands of `scf.for`), which fixes this bug.
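
The non-dominating use can be illustrated with a small toy check. This is only a sketch under simplifying assumptions, not MLIR's dominance analysis: `nonDominatingBounds` is a hypothetical helper, values are modelled as names, and a value counts as "local" if it is defined only inside the warp region (for example the block argument `%arg1` in the new test).

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Toy model, not MLIR: returns the loop bounds that the sunk loop would use
// without a dominating definition. `definedInsideWarp` lists the values that
// exist only inside the warp region (block arguments and local results).
std::vector<std::string> nonDominatingBounds(
    const std::vector<std::string> &definedInsideWarp,
    const std::vector<std::string> &loopBounds, bool boundsAreYielded) {
  assert(loopBounds.size() == 3 && "expected lower bound, upper bound, step");
  std::vector<std::string> offending;
  // When the bounds are yielded, the sunk loop reads warp results, which are
  // defined before the loop, so no bound can be out of scope.
  if (boundsAreYielded)
    return offending;
  for (const std::string &bound : loopBounds) {
    bool isLocal = std::find(definedInsideWarp.begin(),
                             definedInsideWarp.end(),
                             bound) != definedInsideWarp.end();
    if (isLocal)
      offending.push_back(bound);
  }
  return offending;
}
```

For the test in this PR, the region-local values are `%arg1` and `%ini`, and the bounds are `%c0`, `%arg1` and `%c1`. Without yielding, the model reports `%arg1` as the offending bound; with yielding, it reports none.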