diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td index c0885a3763827..35d0b16628417 100644 --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -827,7 +827,6 @@ def Tensor_InsertOp : Tensor_Op<"insert", [ let hasFolder = 1; let hasVerifier = 1; - let hasCanonicalizer = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 12e8b257ce9f1..6e67377ddb6e8 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -1624,76 +1624,6 @@ OpFoldResult GatherOp::fold(FoldAdaptor adaptor) { // InsertOp //===----------------------------------------------------------------------===// -namespace { - -/// Pattern to fold an insert op of a constant destination and scalar to a new -/// constant. -/// -/// Example: -/// ``` -/// %0 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32> -/// %c0 = arith.constant 0 : index -/// %c4_f32 = arith.constant 4.0 : f32 -/// %1 = tensor.insert %c4_f32 into %0[%c0] : tensor<4xf32> -/// ``` -/// is rewritten into: -/// ``` -/// %1 = arith.constant dense<[4.0, 2.0, 3.0, 4.0]> : tensor<4xf32> -/// ``` -class InsertOpConstantFold final : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(InsertOp insertOp, - PatternRewriter &rewriter) const override { - // Requires a ranked tensor type. - auto destType = - llvm::dyn_cast(insertOp.getDest().getType()); - if (!destType) - return failure(); - - // Pattern requires constant indices - SmallVector indices; - for (OpFoldResult indice : getAsOpFoldResult(insertOp.getIndices())) { - auto indiceAttr = dyn_cast(indice); - if (!indiceAttr) - return failure(); - indices.push_back(llvm::cast(indiceAttr).getInt()); - } - - // Requires a constant scalar to insert - OpFoldResult scalar = getAsOpFoldResult(insertOp.getScalar()); - Attribute scalarAttr = dyn_cast(scalar); - if (!scalarAttr) - return failure(); - - if (auto constantOp = dyn_cast_or_null( - insertOp.getDest().getDefiningOp())) { - if (auto sourceAttr = - llvm::dyn_cast(constantOp.getValue())) { - // Update the attribute at the inserted index. - auto sourceValues = sourceAttr.getValues(); - auto flattenedIndex = sourceAttr.getFlattenedIndex(indices); - std::vector updatedValues; - updatedValues.reserve(sourceAttr.getNumElements()); - for (unsigned i = 0; i < sourceAttr.getNumElements(); ++i) { - updatedValues.push_back(i == flattenedIndex ? scalarAttr - : sourceValues[i]); - } - rewriter.replaceOpWithNewOp( - insertOp, sourceAttr.getType(), - DenseElementsAttr::get(cast(sourceAttr.getType()), - updatedValues)); - return success(); - } - } - - return failure(); - } -}; - -} // namespace - void InsertOp::getAsmResultNames( function_ref setNameFn) { setNameFn(getResult(), "inserted"); @@ -1717,11 +1647,6 @@ OpFoldResult InsertOp::fold(FoldAdaptor adaptor) { return {}; } -void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add(context); -} - //===----------------------------------------------------------------------===// // GenerateOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir index 646b2197d9aa6..f033a43c0dc24 100644 --- a/mlir/test/Dialect/Tensor/canonicalize.mlir +++ b/mlir/test/Dialect/Tensor/canonicalize.mlir @@ -231,22 +231,6 @@ func.func @fold_insert(%arg0 : index) -> (tensor<4xf32>) { return %ins_1 : tensor<4xf32> } - -// ----- - -func.func @canonicalize_insert_after_constant() -> (tensor<2x2xi32>) { - // Fold an insert into a splat. - // CHECK: %[[C4:.+]] = arith.constant dense<{{\[\[}}1, 2], [4, 4]]> : tensor<2x2xi32> - // CHECK-LITERAL: - // CHECK-NEXT: return %[[C4]] - %cst = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32> - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c4_i32 = arith.constant 4 : i32 - %inserted = tensor.insert %c4_i32 into %cst[%c1, %c0] : tensor<2x2xi32> - return %inserted : tensor<2x2xi32> -} - // ----- // CHECK-LABEL: func @extract_from_tensor.cast