diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index 62a24646d0662..f5a58c58e05df 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -1020,22 +1020,22 @@ struct FuncOpVectorUnroll final : OpRewritePattern { SmallVector locs(convertedTypes.size(), newFuncOp.getLoc()); entryBlock.addArguments(convertedTypes, locs); - // Replace the placeholder values with the new arguments. We assume there is - // only one block for now. + // Replace all uses of placeholders for initially legal arguments with their + // original function arguments (that were added to `newFuncOp`). + for (auto &[placeholderOp, argIdx] : tmpOps) { + if (!placeholderOp) + continue; + Value replacement = newFuncOp.getArgument(argIdx); + rewriter.replaceAllUsesWith(placeholderOp->getResult(0), replacement); + } + + // Replace dummy operands of new `vector.insert_strided_slice` ops with + // their corresponding new function arguments. The new + // `vector.insert_strided_slice` ops are inserted only into the entry block, + // so iterating over that block is sufficient. size_t unrolledInputIdx = 0; for (auto [count, op] : enumerate(entryBlock.getOperations())) { - // We first look for operands that are placeholders for initially legal - // arguments. Operation &curOp = op; - for (auto [operandIdx, operandVal] : llvm::enumerate(op.getOperands())) { - Operation *operandOp = operandVal.getDefiningOp(); - if (auto it = tmpOps.find(operandOp); it != tmpOps.end()) { - size_t idx = operandIdx; - rewriter.modifyOpInPlace(&curOp, [&curOp, &newFuncOp, it, idx] { - curOp.setOperand(idx, newFuncOp.getArgument(it->second)); - }); - } - } // Since all newly created operations are in the beginning, reaching the // end of them means that any later `vector.insert_strided_slice` should // not be touched. diff --git a/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir b/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir index c018ccb924983..211d6c90243bd 100644 --- a/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir +++ b/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir @@ -189,3 +189,76 @@ func.func @unsupported_scalable(%arg0 : vector<[8]xi32>) -> (vector<[8]xi32>) { return %arg0 : vector<[8]xi32> } +// ----- + +// Check that already legal function parameters are properly preserved across multiple blocks. + +// CHECK-LABEL: func.func @legal_params_multiple_blocks_simple +// CHECK-SAME: (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32) -> i32 +func.func @legal_params_multiple_blocks_simple(%arg0: i32, %arg1: i32) -> i32 { + // CHECK: %[[ADD0:.*]] = arith.addi %[[ARG0]], %[[ARG1]] : i32 + // CHECK: %[[ADD1:.*]] = arith.addi %[[ADD0]], %[[ARG1]] : i32 + // CHECK: return %[[ADD1]] : i32 + cf.br ^bb1(%arg0 : i32) +^bb1(%acc0: i32): + %acc1_val = arith.addi %acc0, %arg1 : i32 + cf.br ^bb2(%acc1_val : i32) +^bb2(%acc1: i32): + %acc2_val = arith.addi %acc1, %arg1 : i32 + cf.br ^bb3(%acc2_val : i32) +^bb3(%acc_final: i32): + return %acc_final : i32 +} + +// ----- + +// Check that legal parameters and existing `vector.insert_strided_slice`s are properly preserved across multiple blocks. + +// CHECK-LABEL: func.func @legal_params_with_vec_insert_multiple_blocks +// CHECK-SAME: (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32, %[[ARG2:.+]]: vector<4xi32>) -> vector<4xi32> +func.func @legal_params_with_vec_insert_multiple_blocks(%arg0: i32, %arg1: i32, %arg2: vector<4xi32>) -> vector<4xi32> { + // CHECK: %[[ADD0:.*]] = arith.addi %[[ARG0]], %[[ARG1]] : i32 + // CHECK: %[[ADD1:.*]] = arith.addi %[[ADD0]], %[[ARG1]] : i32 + // CHECK: %[[VEC1D:.*]] = vector.broadcast %[[ADD1]] : i32 to vector<1xi32> + // CHECK: %[[VEC0:.*]] = vector.insert_strided_slice %[[VEC1D]], %[[ARG2]] {offsets = [1], strides = [1]} : vector<1xi32> into vector<4xi32> + // CHECK: %[[VEC1:.*]] = vector.insert_strided_slice %[[VEC1D]], %[[VEC0]] {offsets = [2], strides = [1]} : vector<1xi32> into vector<4xi32> + // CHECK: %[[RESULT:.*]] = vector.insert_strided_slice %[[VEC1D]], %[[VEC1]] {offsets = [3], strides = [1]} : vector<1xi32> into vector<4xi32> + // CHECK: return %[[RESULT]] : vector<4xi32> + cf.br ^bb1(%arg0 : i32) +^bb1(%acc0: i32): + %acc1_val = arith.addi %acc0, %arg1 : i32 + cf.br ^bb2(%acc1_val : i32) +^bb2(%acc1: i32): + %acc2_val = arith.addi %acc1, %arg1 : i32 + cf.br ^bb3(%acc2_val : i32) +^bb3(%acc_final: i32): + %scalar_vec = vector.broadcast %acc_final : i32 to vector<1xi32> + %vec0 = vector.insert_strided_slice %scalar_vec, %arg2 {offsets = [1], strides = [1]} : vector<1xi32> into vector<4xi32> + %vec1 = vector.insert_strided_slice %scalar_vec, %vec0 {offsets = [2], strides = [1]} : vector<1xi32> into vector<4xi32> + %result = vector.insert_strided_slice %scalar_vec, %vec1 {offsets = [3], strides = [1]} : vector<1xi32> into vector<4xi32> + return %result : vector<4xi32> +} + +// ----- + +// Check that already legal function parameters are preserved across a loop (which contains multiple blocks). + +// CHECK-LABEL: @legal_params_for_loop +// CHECK-SAME: (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32, %[[ARG2:.+]]: i32) +func.func @legal_params_for_loop(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 { + // CHECK: %[[CST0:.*]] = arith.constant 0 : index + // CHECK: %[[CST1:.*]] = arith.constant 1 : index + // CHECK: %[[UB:.*]] = arith.index_cast %[[ARG2]] : i32 to index + // CHECK: %[[RESULT:.*]] = scf.for %[[STEP:.*]] = %[[CST0]] to %[[UB]] step %[[CST1]] iter_args(%[[ACC:.*]] = %[[ARG0]]) -> (i32) { + // CHECK: %[[ADD:.*]] = arith.addi %[[ACC]], %[[ARG1]] : i32 + // CHECK: scf.yield %[[ADD]] : i32 + // CHECK: return %[[RESULT]] : i32 + %zero = arith.constant 0 : index + %one = arith.constant 1 : index + %ub = arith.index_cast %arg2 : i32 to index + %result = scf.for %i = %zero to %ub step %one iter_args(%acc = %arg0) -> (i32) { + %new_acc = arith.addi %acc, %arg1 : i32 + scf.yield %new_acc : i32 + } + return %result : i32 +}