From 6f58d06271adf72408261992dbfc0526f771163b Mon Sep 17 00:00:00 2001 From: brod4910 Date: Mon, 14 Oct 2024 07:08:51 -0600 Subject: [PATCH] Fix out-of-bounds FoldEmptyTensorWithDimOp crash #111270 --- mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 5 ++- .../Dialect/Linalg/drop-unit-extent-dims.mlir | 39 +++++++++++++++++++ 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index c2d6bc610cd92..8f1e034cb15b9 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -980,7 +980,10 @@ struct FoldEmptyTensorWithDimOp : public OpRewritePattern { auto emptyTensorOp = dimOp.getSource().getDefiningOp(); if (!emptyTensorOp || !maybeConstantIndex) return failure(); - if (!emptyTensorOp.getType().isDynamicDim(*maybeConstantIndex)) + auto emptyTensorType = emptyTensorOp.getType(); + if (*maybeConstantIndex < 0 || + *maybeConstantIndex >= emptyTensorType.getRank() || + !emptyTensorType.isDynamicDim(*maybeConstantIndex)) return failure(); rewriter.replaceOp(dimOp, emptyTensorOp.getDynamicSize(*maybeConstantIndex)); diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir index 9a00b19aae400..3256daa8e0b59 100644 --- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir +++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir @@ -1130,3 +1130,42 @@ module { return %1 : tensor } } + +// ----- + +func.func @no_fold_empty_tensor_dim_out_of_bounds(%arg0: tensor<1x?x10xf32>) -> tensor<1x?xf32> { + %cst = arith.constant 1.000000e+00 : f32 + %cst7 = arith.constant 7 : index + %dim = tensor.dim %arg0, %cst7 : tensor<1x?x10xf32> + %0 = tensor.empty(%dim) : tensor<1x?xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1x?xf32>) -> tensor<1x?xf32> + return %1 : tensor<1x?xf32> +} +// CHECK-LABEL: func.func @no_fold_empty_tensor_dim_out_of_bounds +// CHECK-SAME: %[[ARG0:.*]]: tensor<1x?x10xf32>) -> tensor<1x?xf32> { +// CHECK: %[[CST:.*]] = arith.constant 1.000000e+00 : f32 +// CHECK: %[[C7:.*]] = arith.constant 7 +// CHECK: %[[DIM:.*]] = tensor.dim %[[ARG0]], %[[C7]] : tensor<1x?x10xf32> +// CHECK: %[[VAL_0:.*]] = tensor.empty(%[[DIM]]) : tensor<1x?xf32> +// CHECK: %[[VAL_1:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[VAL_0]] : tensor<1x?xf32>) -> tensor<1x?xf32> +// CHECK: return %[[VAL_1]] : tensor<1x?xf32> +// CHECK: } + +// ----- + +func.func @fold_empty_tensor_dim_op(%arg0: tensor<1x?x10xf32>) -> tensor<1x?xf32> { + %cst = arith.constant 1.000000e+00 : f32 + %cst2 = index.constant 2 + %dim10 = tensor.dim %arg0, %cst2 : tensor<1x?x10xf32> + %0 = tensor.empty(%dim10) : tensor<1x?xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1x?xf32>) -> tensor<1x?xf32> + return %1 : tensor<1x?xf32> +} +// CHECK-LABEL: func.func @fold_empty_tensor_dim_op +// CHECK-SAME: %[[ARG0:.*]]: tensor<1x?x10xf32>) -> tensor<1x?xf32> { +// CHECK: %[[CST:.*]] = arith.constant 1.000000e+00 : f32 +// CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<1x10xf32> +// CHECK: %[[VAL_1:.*]] = tensor.cast %[[VAL_0]] : tensor<1x10xf32> to tensor<1x?xf32> +// CHECK: %[[VAL_2:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[VAL_1]] : tensor<1x?xf32>) -> tensor<1x?xf32> +// CHECK: return %[[VAL_2]] : tensor<1x?xf32> +// CHECK: }