Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 19 additions & 15 deletions mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1020,35 +1020,39 @@ struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
SmallVector<Location> locs(convertedTypes.size(), newFuncOp.getLoc());
entryBlock.addArguments(convertedTypes, locs);

// Replace the placeholder values with the new arguments. We assume there is
// only one block for now.
// Replace the placeholder values with the new arguments.
size_t unrolledInputIdx = 0;
for (auto [count, op] : enumerate(entryBlock.getOperations())) {
newFuncOp.walk([&](Operation *op) {
// We first look for operands that are placeholders for initially legal
// arguments.
Operation &curOp = op;
for (auto [operandIdx, operandVal] : llvm::enumerate(op.getOperands())) {
for (auto [operandIdx, operandVal] : llvm::enumerate(op->getOperands())) {
Operation *operandOp = operandVal.getDefiningOp();
if (auto it = tmpOps.find(operandOp); it != tmpOps.end()) {
size_t idx = operandIdx;
rewriter.modifyOpInPlace(&curOp, [&curOp, &newFuncOp, it, idx] {
curOp.setOperand(idx, newFuncOp.getArgument(it->second));
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code originally only processes one block, but the placeholder values for the function arguments were applied with replaceAllUsesWith and may replace the function arguments across multiple blocks. The current implementation leaves the placeholder values in the other blocks (except for the first one) as is instead of the original function arguments, which is incorrect.

rewriter.modifyOpInPlace(op, [&] {
op->setOperand(idx, newFuncOp.getArgument(it->second));
});
}
}
// Since all newly created operations are in the beginning, reaching the
// end of them means that any later `vector.insert_strided_slice` should
// not be touched.
if (count >= newOpCount)
continue;

// Only consider `vector.insert_strided_slice` ops that were newly created
// at the beginning of the entry block. Once we encounter operations
// outside the entry block or past the `newOpCount`-th operation in the
// entry block, we skip and leave exisintg `vector.insert_strided_slice`
// ops as is.
if (op->getBlock() != &entryBlock ||
static_cast<size_t>(std::distance(entryBlock.begin(),
op->getIterator())) >= newOpCount)
return;

if (auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op)) {
size_t unrolledInputNo = unrolledInputNums[unrolledInputIdx];
rewriter.modifyOpInPlace(&curOp, [&] {
curOp.setOperand(0, newFuncOp.getArgument(unrolledInputNo));
rewriter.modifyOpInPlace(op, [&] {
op->setOperand(0, newFuncOp.getArgument(unrolledInputNo));
});
++unrolledInputIdx;
}
}
});

// Erase the original funcOp. The `tmpOps` do not need to be erased since
// they have no uses and will be handled by dead-code elimination.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,3 +189,76 @@ func.func @unsupported_scalable(%arg0 : vector<[8]xi32>) -> (vector<[8]xi32>) {
return %arg0 : vector<[8]xi32>
}

// -----

// Check that already legal function parameters are properly preserved across multiple blocks.

// CHECK-LABEL: func.func @legal_params_multiple_blocks_simple
// CHECK-SAME: (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32) -> i32
func.func @legal_params_multiple_blocks_simple(%arg0: i32, %arg1: i32) -> i32 {
// CHECK: %[[ADD0:.*]] = arith.addi %[[ARG0]], %[[ARG1]] : i32
// CHECK: %[[ADD1:.*]] = arith.addi %[[ADD0]], %[[ARG1]] : i32
// CHECK: return %[[ADD1]] : i32
cf.br ^bb1(%arg0 : i32)
^bb1(%acc0: i32):
%acc1_val = arith.addi %acc0, %arg1 : i32
cf.br ^bb2(%acc1_val : i32)
^bb2(%acc1: i32):
%acc2_val = arith.addi %acc1, %arg1 : i32
cf.br ^bb3(%acc2_val : i32)
^bb3(%acc_final: i32):
return %acc_final : i32
}

// -----

// Check that legal parameters and existing `vector.insert_strided_slice`s are properly preserved across multiple blocks.

// CHECK-LABEL: func.func @legal_params_with_vec_insert_multiple_blocks
// CHECK-SAME: (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32, %[[ARG2:.+]]: vector<4xi32>) -> vector<4xi32>
func.func @legal_params_with_vec_insert_multiple_blocks(%arg0: i32, %arg1: i32, %arg2: vector<4xi32>) -> vector<4xi32> {
// CHECK: %[[ADD0:.*]] = arith.addi %[[ARG0]], %[[ARG1]] : i32
// CHECK: %[[ADD1:.*]] = arith.addi %[[ADD0]], %[[ARG1]] : i32
// CHECK: %[[VEC1D:.*]] = vector.broadcast %[[ADD1]] : i32 to vector<1xi32>
// CHECK: %[[VEC0:.*]] = vector.insert_strided_slice %[[VEC1D]], %[[ARG2]] {offsets = [1], strides = [1]} : vector<1xi32> into vector<4xi32>
// CHECK: %[[VEC1:.*]] = vector.insert_strided_slice %[[VEC1D]], %[[VEC0]] {offsets = [2], strides = [1]} : vector<1xi32> into vector<4xi32>
// CHECK: %[[RESULT:.*]] = vector.insert_strided_slice %[[VEC1D]], %[[VEC1]] {offsets = [3], strides = [1]} : vector<1xi32> into vector<4xi32>
// CHECK: return %[[RESULT]] : vector<4xi32>
cf.br ^bb1(%arg0 : i32)
^bb1(%acc0: i32):
%acc1_val = arith.addi %acc0, %arg1 : i32
cf.br ^bb2(%acc1_val : i32)
^bb2(%acc1: i32):
%acc2_val = arith.addi %acc1, %arg1 : i32
cf.br ^bb3(%acc2_val : i32)
^bb3(%acc_final: i32):
%scalar_vec = vector.broadcast %acc_final : i32 to vector<1xi32>
%vec0 = vector.insert_strided_slice %scalar_vec, %arg2 {offsets = [1], strides = [1]} : vector<1xi32> into vector<4xi32>
%vec1 = vector.insert_strided_slice %scalar_vec, %vec0 {offsets = [2], strides = [1]} : vector<1xi32> into vector<4xi32>
%result = vector.insert_strided_slice %scalar_vec, %vec1 {offsets = [3], strides = [1]} : vector<1xi32> into vector<4xi32>
return %result : vector<4xi32>
}

// -----

// Check that already legal function parameters are preserved across a loop (which contains multiple blocks).

// CHECK-LABEL: @legal_params_for_loop
// CHECK-SAME: (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32, %[[ARG2:.+]]: i32)
func.func @legal_params_for_loop(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 {
// CHECK: %[[CST0:.*]] = arith.constant 0 : index
// CHECK: %[[CST1:.*]] = arith.constant 1 : index
// CHECK: %[[UB:.*]] = arith.index_cast %[[ARG2]] : i32 to index
// CHECK: %[[RESULT:.*]] = scf.for %[[STEP:.*]] = %[[CST0]] to %[[UB]] step %[[CST1]] iter_args(%[[ACC:.*]] = %[[ARG0]]) -> (i32) {
// CHECK: %[[ADD:.*]] = arith.addi %[[ACC]], %[[ARG1]] : i32
// CHECK: scf.yield %[[ADD]] : i32
// CHECK: return %[[RESULT]] : i32
%zero = arith.constant 0 : index
%one = arith.constant 1 : index
%ub = arith.index_cast %arg2 : i32 to index
%result = scf.for %i = %zero to %ub step %one iter_args(%acc = %arg0) -> (i32) {
%new_acc = arith.addi %acc, %arg1 : i32
scf.yield %new_acc : i32
}
return %result : i32
}
Loading