Skip to content

Commit a8f3860

Browse files
authored
[mlir][tensor] Fix bug in tensor.extract(tensor.from_elements) folder (llvm#75109)
The folder for `tensor.extract` is not operating correctly when it is consuming the result of a `tensor.from_elements` operation. The existing unit test named `@extract_from_tensor.from_elements_3d` in `mlir/test/Dialect/Tensor/canonicalize.mlir` seems an attempt to stress this code. However, this unit tests creates a `tensor.from_elements` op exclusively from constants, which gets folded away into a single constant tensor. Therefore, the buggy code was never executed in unit tests. I have added a new unit test named `@extract_from_tensor.from_elements_variable_3d` that makes sure the `tensor.from_elements` op is not folded away by having its input operands come directly from function arguments. The original folder code would have made this test fail. This bug was notably affecting the lowering of the `tosa.pad` op in the `tosa-to-tensor` pass, where the generated code is likely to contain a `tensor.from_elements` + `tensor.extract` op sequence.
1 parent c873f77 commit a8f3860

File tree

2 files changed

+45
-2
lines changed

2 files changed

+45
-2
lines changed

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1116,9 +1116,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
11161116
int flatIndex = 0;
11171117
int stride = 1;
11181118
for (int i = rank - 1; i >= 0; --i) {
1119-
if (i < rank - 1)
1120-
stride *= tensorType.getDimSize(i);
11211119
flatIndex += indices[i] * stride;
1120+
stride *= tensorType.getDimSize(i);
11221121
}
11231122
// Prevent out of bounds accesses. This can happen in invalid code that
11241123
// will never execute.

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,50 @@ func.func @extract_from_tensor.from_elements_3d()
242242

243243
// -----
244244

245+
// CHECK-LABEL: func @extract_from_tensor.from_elements_variable_3d
246+
// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: f32
247+
// CHECK-SAME: %[[ARG_1:[a-zA-Z0-9_]+]]: f32
248+
// CHECK-SAME: %[[ARG_2:[a-zA-Z0-9_]+]]: f32
249+
// CHECK-SAME: %[[ARG_3:[a-zA-Z0-9_]+]]: f32
250+
// CHECK-SAME: %[[ARG_4:[a-zA-Z0-9_]+]]: f32
251+
// CHECK-SAME: %[[ARG_5:[a-zA-Z0-9_]+]]: f32
252+
// CHECK-SAME: %[[ARG_6:[a-zA-Z0-9_]+]]: f32
253+
// CHECK-SAME: %[[ARG_7:[a-zA-Z0-9_]+]]: f32
254+
// CHECK-SAME: %[[ARG_8:[a-zA-Z0-9_]+]]: f32
255+
// CHECK-SAME: %[[ARG_9:[a-zA-Z0-9_]+]]: f32
256+
// CHECK-SAME: %[[ARG_10:[a-zA-Z0-9_]+]]: f32
257+
// CHECK-SAME: %[[ARG_11:[a-zA-Z0-9_]+]]: f32
258+
func.func @extract_from_tensor.from_elements_variable_3d(
259+
%f0: f32, %f1: f32, %f2: f32, %f3: f32, %f4: f32, %f5: f32,
260+
%f6: f32, %f7: f32, %f8: f32, %f9: f32, %f10: f32, %f11: f32)
261+
-> (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32) {
262+
263+
%tensor = tensor.from_elements %f0,%f1,%f2,%f3,%f4,%f5,%f6,%f7,%f8,%f9,%f10,%f11
264+
: tensor<3x2x2xf32>
265+
%c0 = arith.constant 0 : index
266+
%c1 = arith.constant 1 : index
267+
%c2 = arith.constant 2 : index
268+
269+
%r0 = tensor.extract %tensor[%c0, %c0, %c0] : tensor<3x2x2xf32>
270+
%r1 = tensor.extract %tensor[%c0, %c0, %c1] : tensor<3x2x2xf32>
271+
%r2 = tensor.extract %tensor[%c0, %c1, %c0] : tensor<3x2x2xf32>
272+
%r3 = tensor.extract %tensor[%c0, %c1, %c1] : tensor<3x2x2xf32>
273+
%r4 = tensor.extract %tensor[%c1, %c0, %c0] : tensor<3x2x2xf32>
274+
%r5 = tensor.extract %tensor[%c1, %c0, %c1] : tensor<3x2x2xf32>
275+
%r6 = tensor.extract %tensor[%c1, %c1, %c0] : tensor<3x2x2xf32>
276+
%r7 = tensor.extract %tensor[%c1, %c1, %c1] : tensor<3x2x2xf32>
277+
%r8 = tensor.extract %tensor[%c2, %c0, %c0] : tensor<3x2x2xf32>
278+
%r9 = tensor.extract %tensor[%c2, %c0, %c1] : tensor<3x2x2xf32>
279+
%r10 = tensor.extract %tensor[%c2, %c1, %c0] : tensor<3x2x2xf32>
280+
%r11 = tensor.extract %tensor[%c2, %c1, %c1] : tensor<3x2x2xf32>
281+
return %r0,%r1,%r2,%r3,%r4,%r5,%r6,%r7,%r8,%r9,%r10,%r11
282+
: f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
283+
}
284+
// CHECK: return %[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[ARG_3]], %[[ARG_4]], %[[ARG_5]],
285+
// CHECK-SAME: %[[ARG_6]], %[[ARG_7]], %[[ARG_8]], %[[ARG_9]], %[[ARG_10]], %[[ARG_11]]
286+
287+
// -----
288+
245289
// CHECK-LABEL: func.func @extract_from_elements_complex_i() -> tensor<3xcomplex<i32>> {
246290
// CHECK-NEXT: %cst = arith.constant dense<[(1,2), (3,2), (1,2)]> : tensor<3xcomplex<i32>>
247291
// CHECK-NEXT: return %cst : tensor<3xcomplex<i32>>

0 commit comments

Comments
 (0)