Skip to content
39 changes: 29 additions & 10 deletions mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1554,22 +1554,36 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
llvm::SmallSetVector<Value, 32> escapingValues;
SmallVector<Type> inputTypes;
SmallVector<Type> distTypes;
// Records `value` as escaping and appends its original type to `inputTypes`
// and its distributed type to `distTypes`, keeping the three containers
// index-aligned. Values already present in `escapingValues` are ignored so
// every value contributes exactly one entry.
auto collectEscapingValues = [&](Value value) {
if (!escapingValues.insert(value))
return;
// Non-vector values keep their type; vector values are distributed using
// the map returned by `distributionMapFn` and the warp size.
Type distType = value.getType();
if (auto vecType = dyn_cast<VectorType>(distType)) {
AffineMap map = distributionMapFn(value);
// NOTE(review): the result may be a null Type; it is still pushed below and
// the code following this lambda checks `distTypes` for `Type{}`.
distType = getDistributedType(vecType, map, warpOp.getWarpSize());
}
inputTypes.push_back(value.getType());
distTypes.push_back(distType);
};

// Walk every value that is used inside the loop body but defined above it.
// A value whose defining region belongs to an op for which
// `warpOp->isAncestor(...)` holds is recorded as escaping. The bookkeeping
// (de-duplication, type distribution, type lists) lives solely in
// `collectEscapingValues`, so it is not repeated here.
mlir::visitUsedValuesDefinedAbove(
    forOp.getBodyRegion(), [&](OpOperand *operand) {
      Operation *parent = operand->get().getParentRegion()->getParentOp();
      if (warpOp->isAncestor(parent))
        collectEscapingValues(operand->get());
    });

// Any forOp result that is not already yielded by the warpOp
// region is also considered escaping and must be returned by the
// original warpOp.
for (OpResult forResult : forOp.getResults()) {
// Check if this forResult is already yielded by the yield op.
if (llvm::is_contained(yield->getOperands(), forResult))
continue;
collectEscapingValues(forResult);
}

if (llvm::is_contained(distTypes, Type{}))
return failure();

Expand Down Expand Up @@ -1609,7 +1623,12 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
forOp.getResultTypes().end());
llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
// Build the inputs of the inner warp op from the results of the new outer
// warp op that correspond to escaping values. `warpInput` and
// `warpInputType` must grow in lockstep because `argIndexMapping` stores the
// position (taken from `warpInputType.size()`) at which each escaping value
// is passed in.
for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) {
  Value newWarpResult = newWarpOp.getResult(retIdx);
  // Unused forOp results yielded by the warpOp region are already included
  // in the new ForOp, so they must not be added as warp inputs again.
  if (llvm::is_contained(newOperands, newWarpResult))
    continue;
  warpInput.push_back(newWarpResult);
  argIndexMapping[escapingValues[i]] = warpInputType.size();
  warpInputType.push_back(inputTypes[i]);
}
Expand Down
36 changes: 36 additions & 0 deletions mlir/test/Dialect/Vector/vector-warp-distribute.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,42 @@ func.func @warp_scf_for_multiple_yield(%arg0: index, %arg1: memref<?xf32>, %arg2
return
}

// -----
// CHECK-PROP-LABEL: func.func @warp_scf_for_unused_yield(
// CHECK-PROP: %[[W0:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<4xf32>, vector<4xf32>) {
// CHECK-PROP: %[[INI0:.*]] = "some_def"() : () -> vector<128xf32>
// CHECK-PROP: %[[INI1:.*]] = "some_def"() : () -> vector<128xf32>
// CHECK-PROP: gpu.yield %[[INI0]], %[[INI1]] : vector<128xf32>, vector<128xf32>
// CHECK-PROP: }
// CHECK-PROP: %[[F:.*]]:2 = scf.for %{{.*}} iter_args(%{{.*}} = %[[W0]]#0, %{{.*}} = %[[W0]]#1) -> (vector<4xf32>, vector<4xf32>) {
// CHECK-PROP: %[[W1:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[32] args(%{{.*}} : vector<4xf32>, vector<4xf32>) -> (vector<4xf32>, vector<4xf32>) {
// CHECK-PROP: %[[ACC0:.*]] = "some_def"(%{{.*}}) : (vector<128xf32>, index) -> vector<128xf32>
// CHECK-PROP: %[[ACC1:.*]] = "some_def"(%{{.*}}) : (index, vector<128xf32>, vector<128xf32>) -> vector<128xf32>
// CHECK-PROP: gpu.yield %[[ACC1]], %[[ACC0]] : vector<128xf32>, vector<128xf32>
// CHECK-PROP: }
// CHECK-PROP: scf.yield %[[W1]]#0, %[[W1]]#1 : vector<4xf32>, vector<4xf32>
// CHECK-PROP: }
// CHECK-PROP: "some_use"(%[[F]]#0) : (vector<4xf32>) -> ()
// Test input: the scf.for inside the warp region produces two results, but
// only the first one (%3#0) is yielded by the warp op; the second (%3#1) has
// no use after the loop. Inside the loop body both iter_args are used, and the
// second yielded value (%1) also feeds the computation of the first (%acc).
func.func @warp_scf_for_unused_yield(%arg0: index) {
%c128 = arith.constant 128 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%0 = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>) {
%ini = "some_def"() : () -> (vector<128xf32>)
%ini1 = "some_def"() : () -> (vector<128xf32>)
%3:2 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini, %arg5 = %ini1) -> (vector<128xf32>, vector<128xf32>) {
%add = arith.addi %arg3, %c1 : index
%1 = "some_def"(%arg5, %add) : (vector<128xf32>, index) -> (vector<128xf32>)
%acc = "some_def"(%add, %arg4, %1) : (index, vector<128xf32>, vector<128xf32>) -> (vector<128xf32>)
scf.yield %acc, %1 : vector<128xf32>, vector<128xf32>
}
// Only the first loop result leaves the warp region.
gpu.yield %3#0 : vector<128xf32>
}
"some_use"(%0) : (vector<4xf32>) -> ()
return
}


// -----

// CHECK-PROP-LABEL: func @vector_reduction(
Expand Down
Loading