Skip to content

Commit d9ad36d

Browse files
committed
add fix
1 parent f5ca0bc commit d9ad36d

File tree

2 files changed

+53
-7
lines changed

2 files changed

+53
-7
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2038,11 +2038,19 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
20382038
}
20392039

20402040
// The newly created `WarpOp` will yield values in the following order:
2041-
// 1. All init args of the `ForOp`.
2042-
// 2. All escaping values.
2043-
// 3. All non-`ForOp` yielded values.
2041+
// 1. Loop bounds and step (lower bound, upper bound, step).
2042+
// 2. All init args of the `ForOp`.
2043+
// 3. All escaping values.
2044+
// 4. All non-`ForOp` yielded values.
20442045
SmallVector<Value> newWarpOpYieldValues;
20452046
SmallVector<Type> newWarpOpDistTypes;
2047+
newWarpOpYieldValues.insert(
2048+
newWarpOpYieldValues.end(),
2049+
{forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep()});
2050+
newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
2051+
{forOp.getLowerBound().getType(),
2052+
forOp.getUpperBound().getType(),
2053+
forOp.getStep().getType()});
20462054
for (auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) {
20472055
newWarpOpYieldValues.push_back(initArg);
20482056
// Compute the distributed type for this init arg.
@@ -2081,20 +2089,23 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
20812089

20822090
// Next, we create a new `ForOp` with the init args yielded by the new
20832091
// `WarpOp`.
2092+
const unsigned initArgsStartIdx = 3; // After loop bounds.
20842093
const unsigned escapingValuesStartIdx =
2094+
initArgsStartIdx +
20852095
forOp.getInitArgs().size(); // `ForOp` init args are positioned before
20862096
// escaping values in the new `WarpOp`.
20872097
SmallVector<Value> newForOpOperands;
2088-
for (size_t i = 0; i < escapingValuesStartIdx; ++i)
2098+
for (size_t i = initArgsStartIdx; i < escapingValuesStartIdx; ++i)
20892099
newForOpOperands.push_back(newWarpOp.getResult(i));
20902100

20912101
// Create a new `ForOp` outside the new `WarpOp` region.
20922102
OpBuilder::InsertionGuard g(rewriter);
20932103
rewriter.setInsertionPointAfter(newWarpOp);
20942104
auto newForOp = scf::ForOp::create(
2095-
rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
2096-
forOp.getStep(), newForOpOperands, /*bodyBuilder=*/nullptr,
2097-
forOp.getUnsignedCmp());
2105+
rewriter, forOp.getLoc(), /**LowerBound=**/ newWarpOp.getResult(0),
2106+
/**UpperBound=**/ newWarpOp.getResult(1),
2107+
/**Step=**/ newWarpOp.getResult(2), newForOpOperands,
2108+
/*bodyBuilder=*/nullptr, forOp.getUnsignedCmp());
20982109
// Next, we insert a new `WarpOp` (called inner `WarpOp`) inside the
20992110
// newly created `ForOp`. This `WarpOp` will contain all ops that were
21002111
// contained within the original `ForOp` body.

mlir/test/Dialect/Vector/vector-warp-distribute.mlir

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,41 @@ func.func @warp_scf_for_use_from_above(%arg0: index) {
473473
return
474474
}
475475

476+
// -----
477+
// CHECK-PROP-LABEL: func.func @warp_scf_for_local_loop_bounds
478+
// CHECK-PROP: (%{{.*}}: index, %[[ARG1:[a-zA-Z0-9]+]]: index) {
479+
// CHECK-PROP: %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] args(%[[ARG1]] : index) -> (vector<4xf32>) {
480+
// CHECK-PROP: ^bb0(%{{.*}}: index):
481+
// CHECK-PROP: %[[T2:.*]] = "some_def"() : () -> vector<128xf32>
482+
// CHECK-PROP: gpu.yield %[[T2]] : vector<128xf32>
483+
// CHECK-PROP: }
484+
// CHECK-PROP: %[[FOR:.*]] = scf.for %{{.*}} to %[[ARG1]] step %{{.*}} iter_args(%{{.*}}) -> (vector<4xf32>) {
485+
// CHECK-PROP: %[[W2:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32]
486+
// CHECK-PROP-SAME: args(%{{.*}} : vector<4xf32>) -> (vector<4xf32>) {
487+
// CHECK-PROP: ^bb0(%{{.*}}: vector<128xf32>):
488+
// CHECK-PROP: gpu.yield %{{.*}} : vector<128xf32>
489+
// CHECK-PROP: }
490+
// CHECK-PROP: scf.yield %[[W2]] : vector<4xf32>
491+
// CHECK-PROP: }
492+
// CHECK-PROP: "some_use"(%[[FOR]]) : (vector<4xf32>) -> ()
493+
// CHECK-PROP: return
494+
func.func @warp_scf_for_local_loop_bounds(%arg0: index, %bound: index) {
495+
%c1 = arith.constant 1 : index
496+
%c0 = arith.constant 0 : index
497+
%0 = gpu.warp_execute_on_lane_0(%arg0)[32]
498+
args(%bound : index) -> (vector<4xf32>) {
499+
^bb0(%arg1: index):
500+
%ini = "some_def"() : () -> (vector<128xf32>)
501+
%3 = scf.for %arg3 = %c0 to %arg1 step %c1 iter_args(%arg4 = %ini) -> (vector<128xf32>) {
502+
%acc = "some_def"(%arg4) : (vector<128xf32>) -> (vector<128xf32>)
503+
scf.yield %acc : vector<128xf32>
504+
}
505+
gpu.yield %3 : vector<128xf32>
506+
}
507+
"some_use"(%0) : (vector<4xf32>) -> ()
508+
return
509+
}
510+
476511
// -----
477512

478513
// CHECK-PROP-LABEL: func @warp_scf_for_swap(

0 commit comments

Comments
 (0)