Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1020,22 +1020,22 @@ struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
SmallVector<Location> locs(convertedTypes.size(), newFuncOp.getLoc());
entryBlock.addArguments(convertedTypes, locs);

// Replace the placeholder values with the new arguments. We assume there is
// only one block for now.
// Replace all uses of placeholders for initially legal arguments with their
// original function arguments (that were added to `newFuncOp`).
for (auto &[placeholderOp, argIdx] : tmpOps) {
if (!placeholderOp)
continue;
Value replacement = newFuncOp.getArgument(argIdx);
rewriter.replaceAllUsesWith(placeholderOp->getResult(0), replacement);
}

// Replace dummy operands of new `vector.insert_strided_slice` ops with
// their corresponding new function arguments. The new
// `vector.insert_strided_slice` ops are inserted only into the entry block,
// so iterating over that block is sufficient.
size_t unrolledInputIdx = 0;
for (auto [count, op] : enumerate(entryBlock.getOperations())) {
// We first look for operands that are placeholders for initially legal
// arguments.
Operation &curOp = op;
for (auto [operandIdx, operandVal] : llvm::enumerate(op.getOperands())) {
Operation *operandOp = operandVal.getDefiningOp();
if (auto it = tmpOps.find(operandOp); it != tmpOps.end()) {
size_t idx = operandIdx;
rewriter.modifyOpInPlace(&curOp, [&curOp, &newFuncOp, it, idx] {
curOp.setOperand(idx, newFuncOp.getArgument(it->second));
});
}
}
// Since all newly created operations are in the beginning, reaching the
// end of them means that any later `vector.insert_strided_slice` should
// not be touched.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,3 +189,76 @@ func.func @unsupported_scalable(%arg0 : vector<[8]xi32>) -> (vector<[8]xi32>) {
return %arg0 : vector<[8]xi32>
}

// -----

// Check that already legal function parameters are properly preserved across multiple blocks.

// CHECK-LABEL: func.func @legal_params_multiple_blocks_simple
// CHECK-SAME: (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32) -> i32
func.func @legal_params_multiple_blocks_simple(%arg0: i32, %arg1: i32) -> i32 {
// Both parameters are scalar i32 values. The body is a straight-line chain of
// four blocks connected by unconditional branches, and %arg1 is used in blocks
// other than the entry block.
// The expectations below match two chained additions followed by the return,
// with every operand expressed in terms of the captured function arguments.
// CHECK: %[[ADD0:.*]] = arith.addi %[[ARG0]], %[[ARG1]] : i32
// CHECK: %[[ADD1:.*]] = arith.addi %[[ADD0]], %[[ARG1]] : i32
// CHECK: return %[[ADD1]] : i32
// Entry block, which forwards %arg0 as the initial accumulator.
cf.br ^bb1(%arg0 : i32)
^bb1(%acc0: i32):
// First use of %arg1 outside the entry block.
%acc1_val = arith.addi %acc0, %arg1 : i32
cf.br ^bb2(%acc1_val : i32)
^bb2(%acc1: i32):
// Second use of %arg1 outside the entry block.
%acc2_val = arith.addi %acc1, %arg1 : i32
cf.br ^bb3(%acc2_val : i32)
^bb3(%acc_final: i32):
return %acc_final : i32
}

// -----

// Check that legal parameters and existing `vector.insert_strided_slice`s are properly preserved across multiple blocks.

// CHECK-LABEL: func.func @legal_params_with_vec_insert_multiple_blocks
// CHECK-SAME: (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32, %[[ARG2:.+]]: vector<4xi32>) -> vector<4xi32>
func.func @legal_params_with_vec_insert_multiple_blocks(%arg0: i32, %arg1: i32, %arg2: vector<4xi32>) -> vector<4xi32> {
// Same block structure as the scalar-only case, plus a vector<4xi32> parameter
// that is only consumed in the last block by user-written
// `vector.insert_strided_slice` ops.
// The expectations below match the two additions, the broadcast, and the three
// inserts at offsets 1, 2 and 3, with the first insert writing into the
// captured third function argument.
// CHECK: %[[ADD0:.*]] = arith.addi %[[ARG0]], %[[ARG1]] : i32
// CHECK: %[[ADD1:.*]] = arith.addi %[[ADD0]], %[[ARG1]] : i32
// CHECK: %[[VEC1D:.*]] = vector.broadcast %[[ADD1]] : i32 to vector<1xi32>
// CHECK: %[[VEC0:.*]] = vector.insert_strided_slice %[[VEC1D]], %[[ARG2]] {offsets = [1], strides = [1]} : vector<1xi32> into vector<4xi32>
// CHECK: %[[VEC1:.*]] = vector.insert_strided_slice %[[VEC1D]], %[[VEC0]] {offsets = [2], strides = [1]} : vector<1xi32> into vector<4xi32>
// CHECK: %[[RESULT:.*]] = vector.insert_strided_slice %[[VEC1D]], %[[VEC1]] {offsets = [3], strides = [1]} : vector<1xi32> into vector<4xi32>
// CHECK: return %[[RESULT]] : vector<4xi32>
// Entry block, which forwards %arg0 as the initial accumulator.
cf.br ^bb1(%arg0 : i32)
^bb1(%acc0: i32):
%acc1_val = arith.addi %acc0, %arg1 : i32
cf.br ^bb2(%acc1_val : i32)
^bb2(%acc1: i32):
%acc2_val = arith.addi %acc1, %arg1 : i32
cf.br ^bb3(%acc2_val : i32)
^bb3(%acc_final: i32):
// Splat the final accumulator into a one-element vector and insert it into
// %arg2 at offsets 1, 2 and 3. These inserts live outside the entry block.
%scalar_vec = vector.broadcast %acc_final : i32 to vector<1xi32>
%vec0 = vector.insert_strided_slice %scalar_vec, %arg2 {offsets = [1], strides = [1]} : vector<1xi32> into vector<4xi32>
%vec1 = vector.insert_strided_slice %scalar_vec, %vec0 {offsets = [2], strides = [1]} : vector<1xi32> into vector<4xi32>
%result = vector.insert_strided_slice %scalar_vec, %vec1 {offsets = [3], strides = [1]} : vector<1xi32> into vector<4xi32>
return %result : vector<4xi32>
}

// -----

// Check that already legal function parameters are preserved across a loop (which contains multiple blocks).

// CHECK-LABEL: @legal_params_for_loop
// CHECK-SAME: (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32, %[[ARG2:.+]]: i32)
func.func @legal_params_for_loop(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 {
// All three parameters are scalar i32 values: %arg0 seeds the loop-carried
// accumulator, %arg1 is added on each iteration inside the loop body, and
// %arg2 (cast to index) is the loop upper bound.
// The expectations below match the same ops as the input, with the loop and
// its body still referring to the captured function arguments.
// CHECK: %[[CST0:.*]] = arith.constant 0 : index
// CHECK: %[[CST1:.*]] = arith.constant 1 : index
// CHECK: %[[UB:.*]] = arith.index_cast %[[ARG2]] : i32 to index
// CHECK: %[[RESULT:.*]] = scf.for %[[STEP:.*]] = %[[CST0]] to %[[UB]] step %[[CST1]] iter_args(%[[ACC:.*]] = %[[ARG0]]) -> (i32) {
// CHECK: %[[ADD:.*]] = arith.addi %[[ACC]], %[[ARG1]] : i32
// CHECK: scf.yield %[[ADD]] : i32
// CHECK: return %[[RESULT]] : i32
%zero = arith.constant 0 : index
%one = arith.constant 1 : index
%ub = arith.index_cast %arg2 : i32 to index
// Iterate from 0 to %ub in steps of 1, carrying the running sum in %acc.
%result = scf.for %i = %zero to %ub step %one iter_args(%acc = %arg0) -> (i32) {
// %arg1 is used inside the loop region, not in the function's entry block.
%new_acc = arith.addi %acc, %arg1 : i32
scf.yield %new_acc : i32
}
return %result : i32
}
Loading