Skip to content

Commit f7a5264

Browse files
authored
[mlir][vector] Add support for yielding loop bounds in scf.for distribution. (llvm#163443)
In some cases, loop bounds (lower, upper and step) of `scf.for` can come locally from the parent warp op the `scf.for`. Current logic will not yield the loop bounds in the new warp op generated during lowering causing sinked `scf.for` to have non dominating use. In this PR, we have added logic to yield loop bounds by default (treat them as other operands of `scf.for`) which fixes this bug.
1 parent 404099d commit f7a5264

File tree

2 files changed

+54
-7
lines changed

2 files changed

+54
-7
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2032,11 +2032,19 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
20322032
}
20332033

20342034
// Newly created `WarpOp` will yield values in following order:
2035-
// 1. All init args of the `ForOp`.
2036-
// 2. All escaping values.
2037-
// 3. All non-`ForOp` yielded values.
2035+
// 1. Loop bounds.
2036+
// 2. All init args of the `ForOp`.
2037+
// 3. All escaping values.
2038+
// 4. All non-`ForOp` yielded values.
20382039
SmallVector<Value> newWarpOpYieldValues;
20392040
SmallVector<Type> newWarpOpDistTypes;
2041+
newWarpOpYieldValues.insert(
2042+
newWarpOpYieldValues.end(),
2043+
{forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep()});
2044+
newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
2045+
{forOp.getLowerBound().getType(),
2046+
forOp.getUpperBound().getType(),
2047+
forOp.getStep().getType()});
20402048
for (auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) {
20412049
newWarpOpYieldValues.push_back(initArg);
20422050
// Compute the distributed type for this init arg.
@@ -2072,20 +2080,24 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
20722080

20732081
// Next, we create a new `ForOp` with the init args yielded by the new
20742082
// `WarpOp`.
2083+
const unsigned initArgsStartIdx = 3; // After loop bounds.
20752084
const unsigned escapingValuesStartIdx =
2085+
initArgsStartIdx +
20762086
forOp.getInitArgs().size(); // `ForOp` init args are positioned before
20772087
// escaping values in the new `WarpOp`.
20782088
SmallVector<Value> newForOpOperands;
2079-
for (size_t i = 0; i < escapingValuesStartIdx; ++i)
2089+
for (size_t i = initArgsStartIdx; i < escapingValuesStartIdx; ++i)
20802090
newForOpOperands.push_back(newWarpOp.getResult(newIndices[i]));
20812091

20822092
// Create a new `ForOp` outside the new `WarpOp` region.
20832093
OpBuilder::InsertionGuard g(rewriter);
20842094
rewriter.setInsertionPointAfter(newWarpOp);
20852095
auto newForOp = scf::ForOp::create(
2086-
rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
2087-
forOp.getStep(), newForOpOperands, /*bodyBuilder=*/nullptr,
2088-
forOp.getUnsignedCmp());
2096+
rewriter, forOp.getLoc(),
2097+
/**LowerBound=**/ newWarpOp.getResult(newIndices[0]),
2098+
/**UpperBound=**/ newWarpOp.getResult(newIndices[1]),
2099+
/**Step=**/ newWarpOp.getResult(newIndices[2]), newForOpOperands,
2100+
/*bodyBuilder=*/nullptr, forOp.getUnsignedCmp());
20892101
// Next, we insert a new `WarpOp` (called inner `WarpOp`) inside the
20902102
// newly created `ForOp`. This `WarpOp` will contain all ops that were
20912103
// contained within the original `ForOp` body.

mlir/test/Dialect/Vector/vector-warp-distribute.mlir

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,41 @@ func.func @warp_scf_for_use_from_above(%arg0: index) {
473473
return
474474
}
475475

476+
// -----
477+
// CHECK-PROP-LABEL: func.func @warp_scf_for_local_loop_bounds
478+
// CHECK-PROP: (%{{.*}}: index, %[[ARG1:[a-zA-Z0-9]+]]: index) {
479+
// CHECK-PROP: %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] args(%[[ARG1]] : index) -> (vector<4xf32>) {
480+
// CHECK-PROP: ^bb0(%{{.*}}: index):
481+
// CHECK-PROP: %[[T2:.*]] = "some_def"() : () -> vector<128xf32>
482+
// CHECK-PROP: gpu.yield %[[T2]] : vector<128xf32>
483+
// CHECK-PROP: }
484+
// CHECK-PROP: %[[FOR:.*]] = scf.for %{{.*}} to %[[ARG1]] step %{{.*}} iter_args(%{{.*}}) -> (vector<4xf32>) {
485+
// CHECK-PROP: %[[W2:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32]
486+
// CHECK-PROP-SAME: args(%{{.*}} : vector<4xf32>) -> (vector<4xf32>) {
487+
// CHECK-PROP: ^bb0(%{{.*}}: vector<128xf32>):
488+
// CHECK-PROP: gpu.yield %{{.*}} : vector<128xf32>
489+
// CHECK-PROP: }
490+
// CHECK-PROP: scf.yield %[[W2]] : vector<4xf32>
491+
// CHECK-PROP: }
492+
// CHECK-PROP: "some_use"(%[[FOR]]) : (vector<4xf32>) -> ()
493+
// CHECK-PROP: return
494+
func.func @warp_scf_for_local_loop_bounds(%arg0: index, %bound: index) {
495+
%c1 = arith.constant 1 : index
496+
%c0 = arith.constant 0 : index
497+
%0 = gpu.warp_execute_on_lane_0(%arg0)[32]
498+
args(%bound : index) -> (vector<4xf32>) {
499+
^bb0(%arg1: index):
500+
%ini = "some_def"() : () -> (vector<128xf32>)
501+
%3 = scf.for %arg3 = %c0 to %arg1 step %c1 iter_args(%arg4 = %ini) -> (vector<128xf32>) {
502+
%acc = "some_def"(%arg4) : (vector<128xf32>) -> (vector<128xf32>)
503+
scf.yield %acc : vector<128xf32>
504+
}
505+
gpu.yield %3 : vector<128xf32>
506+
}
507+
"some_use"(%0) : (vector<4xf32>) -> ()
508+
return
509+
}
510+
476511
// -----
477512

478513
// CHECK-PROP-LABEL: func @warp_scf_for_swap(

0 commit comments

Comments
 (0)