Skip to content

Commit 190b9aa

Browse files
committed
Add tests and update comment
1 parent c720c6b commit 190b9aa

File tree

2 files changed

+78
-3
lines changed

2 files changed

+78
-3
lines changed

mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1035,9 +1035,11 @@ struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
10351035
}
10361036
}
10371037

1038-
// Since all newly created operations are in the beginning, reaching the
1039-
// end of them means that any later `vector.insert_strided_slice` should
1040-
// not be touched.
1038+
// Only consider `vector.insert_strided_slice` ops that were newly created
1039+
// at the beginning of the entry block. Once we encounter operations
1040+
// outside the entry block or past the `newOpCount`-th operation in the
1041+
// entry block, we skip and leave exisintg `vector.insert_strided_slice`
1042+
// ops as is.
10411043
if (op->getBlock() != &entryBlock ||
10421044
static_cast<size_t>(std::distance(entryBlock.begin(),
10431045
op->getIterator())) >= newOpCount)

mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,3 +189,76 @@ func.func @unsupported_scalable(%arg0 : vector<[8]xi32>) -> (vector<[8]xi32>) {
189189
return %arg0 : vector<[8]xi32>
190190
}
191191

192+
// -----
193+
194+
// Check that already legal function parameters are properly preserved across multiple blocks.
195+
196+
// CHECK-LABEL: func.func @legal_params_multiple_blocks_simple
197+
// CHECK-SAME: (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32) -> i32
198+
func.func @legal_params_multiple_blocks_simple(%arg0: i32, %arg1: i32) -> i32 {
199+
// CHECK: %[[ADD0:.*]] = arith.addi %[[ARG0]], %[[ARG1]] : i32
200+
// CHECK: %[[ADD1:.*]] = arith.addi %[[ADD0]], %[[ARG1]] : i32
201+
// CHECK: return %[[ADD1]] : i32
202+
cf.br ^bb1(%arg0 : i32)
203+
^bb1(%acc0: i32):
204+
%acc1_val = arith.addi %acc0, %arg1 : i32
205+
cf.br ^bb2(%acc1_val : i32)
206+
^bb2(%acc1: i32):
207+
%acc2_val = arith.addi %acc1, %arg1 : i32
208+
cf.br ^bb3(%acc2_val : i32)
209+
^bb3(%acc_final: i32):
210+
return %acc_final : i32
211+
}
212+
213+
// -----
214+
215+
// Check that legal parameters and existing `vector.insert_strided_slice`s are properly preserved across multiple blocks.
216+
217+
// CHECK-LABEL: func.func @legal_params_with_vec_insert_multiple_blocks
218+
// CHECK-SAME: (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32, %[[ARG2:.+]]: vector<4xi32>) -> vector<4xi32>
219+
func.func @legal_params_with_vec_insert_multiple_blocks(%arg0: i32, %arg1: i32, %arg2: vector<4xi32>) -> vector<4xi32> {
220+
// CHECK: %[[ADD0:.*]] = arith.addi %[[ARG0]], %[[ARG1]] : i32
221+
// CHECK: %[[ADD1:.*]] = arith.addi %[[ADD0]], %[[ARG1]] : i32
222+
// CHECK: %[[VEC1D:.*]] = vector.broadcast %[[ADD1]] : i32 to vector<1xi32>
223+
// CHECK: %[[VEC0:.*]] = vector.insert_strided_slice %[[VEC1D]], %[[ARG2]] {offsets = [1], strides = [1]} : vector<1xi32> into vector<4xi32>
224+
// CHECK: %[[VEC1:.*]] = vector.insert_strided_slice %[[VEC1D]], %[[VEC0]] {offsets = [2], strides = [1]} : vector<1xi32> into vector<4xi32>
225+
// CHECK: %[[RESULT:.*]] = vector.insert_strided_slice %[[VEC1D]], %[[VEC1]] {offsets = [3], strides = [1]} : vector<1xi32> into vector<4xi32>
226+
// CHECK: return %[[RESULT]] : vector<4xi32>
227+
cf.br ^bb1(%arg0 : i32)
228+
^bb1(%acc0: i32):
229+
%acc1_val = arith.addi %acc0, %arg1 : i32
230+
cf.br ^bb2(%acc1_val : i32)
231+
^bb2(%acc1: i32):
232+
%acc2_val = arith.addi %acc1, %arg1 : i32
233+
cf.br ^bb3(%acc2_val : i32)
234+
^bb3(%acc_final: i32):
235+
%scalar_vec = vector.broadcast %acc_final : i32 to vector<1xi32>
236+
%vec0 = vector.insert_strided_slice %scalar_vec, %arg2 {offsets = [1], strides = [1]} : vector<1xi32> into vector<4xi32>
237+
%vec1 = vector.insert_strided_slice %scalar_vec, %vec0 {offsets = [2], strides = [1]} : vector<1xi32> into vector<4xi32>
238+
%result = vector.insert_strided_slice %scalar_vec, %vec1 {offsets = [3], strides = [1]} : vector<1xi32> into vector<4xi32>
239+
return %result : vector<4xi32>
240+
}
241+
242+
// -----
243+
244+
// Check that already legal function parameters are preserved across a loop (which contains multiple blocks).
245+
246+
// CHECK-LABEL: @legal_params_for_loop
247+
// CHECK-SAME: (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32, %[[ARG2:.+]]: i32)
248+
func.func @legal_params_for_loop(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 {
249+
// CHECK: %[[CST0:.*]] = arith.constant 0 : index
250+
// CHECK: %[[CST1:.*]] = arith.constant 1 : index
251+
// CHECK: %[[UB:.*]] = arith.index_cast %[[ARG2]] : i32 to index
252+
// CHECK: %[[RESULT:.*]] = scf.for %[[STEP:.*]] = %[[CST0]] to %[[UB]] step %[[CST1]] iter_args(%[[ACC:.*]] = %[[ARG0]]) -> (i32) {
253+
// CHECK: %[[ADD:.*]] = arith.addi %[[ACC]], %[[ARG1]] : i32
254+
// CHECK: scf.yield %[[ADD]] : i32
255+
// CHECK: return %[[RESULT]] : i32
256+
%zero = arith.constant 0 : index
257+
%one = arith.constant 1 : index
258+
%ub = arith.index_cast %arg2 : i32 to index
259+
%result = scf.for %i = %zero to %ub step %one iter_args(%acc = %arg0) -> (i32) {
260+
%new_acc = arith.addi %acc, %arg1 : i32
261+
scf.yield %new_acc : i32
262+
}
263+
return %result : i32
264+
}

0 commit comments

Comments
 (0)