Skip to content

Commit 40556d0

Browse files
authored
[MLIR][Tensor] Fix out-of-bounds FoldEmptyTensorWithDimOp crash (#112196)
Fixes #111270
1 parent f87484d commit 40556d0

File tree

2 files changed

+43
-1
lines changed

2 files changed

+43
-1
lines changed

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -980,7 +980,10 @@ struct FoldEmptyTensorWithDimOp : public OpRewritePattern<DimOp> {
980980
auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
981981
if (!emptyTensorOp || !maybeConstantIndex)
982982
return failure();
983-
if (!emptyTensorOp.getType().isDynamicDim(*maybeConstantIndex))
983+
auto emptyTensorType = emptyTensorOp.getType();
984+
if (*maybeConstantIndex < 0 ||
985+
*maybeConstantIndex >= emptyTensorType.getRank() ||
986+
!emptyTensorType.isDynamicDim(*maybeConstantIndex))
984987
return failure();
985988
rewriter.replaceOp(dimOp,
986989
emptyTensorOp.getDynamicSize(*maybeConstantIndex));

mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1130,3 +1130,42 @@ module {
11301130
return %1 : tensor<?x1x61x1xf32>
11311131
}
11321132
}
1133+
1134+
// -----
1135+
1136+
func.func @no_fold_empty_tensor_dim_out_of_bounds(%arg0: tensor<1x?x10xf32>) -> tensor<1x?xf32> {
1137+
%cst = arith.constant 1.000000e+00 : f32
1138+
%cst7 = arith.constant 7 : index
1139+
%dim = tensor.dim %arg0, %cst7 : tensor<1x?x10xf32>
1140+
%0 = tensor.empty(%dim) : tensor<1x?xf32>
1141+
%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1x?xf32>) -> tensor<1x?xf32>
1142+
return %1 : tensor<1x?xf32>
1143+
}
1144+
// CHECK-LABEL: func.func @no_fold_empty_tensor_dim_out_of_bounds
1145+
// CHECK-SAME: %[[ARG0:.*]]: tensor<1x?x10xf32>) -> tensor<1x?xf32> {
1146+
// CHECK: %[[CST:.*]] = arith.constant 1.000000e+00 : f32
1147+
// CHECK: %[[C7:.*]] = arith.constant 7
1148+
// CHECK: %[[DIM:.*]] = tensor.dim %[[ARG0]], %[[C7]] : tensor<1x?x10xf32>
1149+
// CHECK: %[[VAL_0:.*]] = tensor.empty(%[[DIM]]) : tensor<1x?xf32>
1150+
// CHECK: %[[VAL_1:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[VAL_0]] : tensor<1x?xf32>) -> tensor<1x?xf32>
1151+
// CHECK: return %[[VAL_1]] : tensor<1x?xf32>
1152+
// CHECK: }
1153+
1154+
// -----
1155+
1156+
func.func @fold_empty_tensor_dim_op(%arg0: tensor<1x?x10xf32>) -> tensor<1x?xf32> {
1157+
%cst = arith.constant 1.000000e+00 : f32
1158+
%cst2 = index.constant 2
1159+
%dim10 = tensor.dim %arg0, %cst2 : tensor<1x?x10xf32>
1160+
%0 = tensor.empty(%dim10) : tensor<1x?xf32>
1161+
%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1x?xf32>) -> tensor<1x?xf32>
1162+
return %1 : tensor<1x?xf32>
1163+
}
1164+
// CHECK-LABEL: func.func @fold_empty_tensor_dim_op
1165+
// CHECK-SAME: %[[ARG0:.*]]: tensor<1x?x10xf32>) -> tensor<1x?xf32> {
1166+
// CHECK: %[[CST:.*]] = arith.constant 1.000000e+00 : f32
1167+
// CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<1x10xf32>
1168+
// CHECK: %[[VAL_1:.*]] = tensor.cast %[[VAL_0]] : tensor<1x10xf32> to tensor<1x?xf32>
1169+
// CHECK: %[[VAL_2:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[VAL_1]] : tensor<1x?xf32>) -> tensor<1x?xf32>
1170+
// CHECK: return %[[VAL_2]] : tensor<1x?xf32>
1171+
// CHECK: }

0 commit comments

Comments
 (0)