diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index ff323983a17c0..8e0e723cf4ed3 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -3019,94 +3019,78 @@ class InsertSplatToSplat final : public OpRewritePattern { } }; -// Pattern to rewrite a InsertOp(ConstantOp into ConstantOp) -> ConstantOp. -class InsertOpConstantFolder final : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - // Do not create constants with more than `vectorSizeFoldThreashold` elements, - // unless the source vector constant has a single use. - static constexpr int64_t vectorSizeFoldThreshold = 256; - - LogicalResult matchAndRewrite(InsertOp op, - PatternRewriter &rewriter) const override { - // TODO: Canonicalization for dynamic position not implemented yet. - if (op.hasDynamicPosition()) - return failure(); +} // namespace - // Return if 'InsertOp' operand is not defined by a compatible vector - // ConstantOp. - TypedValue destVector = op.getDest(); - Attribute vectorDestCst; - if (!matchPattern(destVector, m_Constant(&vectorDestCst))) - return failure(); - auto denseDest = llvm::dyn_cast(vectorDestCst); - if (!denseDest) - return failure(); +static Attribute +foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr, + Attribute dstAttr, + int64_t maxVectorSizeFoldThreshold) { + if (insertOp.hasDynamicPosition()) + return {}; - VectorType destTy = destVector.getType(); - if (destTy.isScalable()) - return failure(); + auto denseDst = llvm::dyn_cast_if_present(dstAttr); + if (!denseDst) + return {}; - // Make sure we do not create too many large constants. 
- if (destTy.getNumElements() > vectorSizeFoldThreshold && - !destVector.hasOneUse()) - return failure(); + if (!srcAttr) { + return {}; + } - Value sourceValue = op.getSource(); - Attribute sourceCst; - if (!matchPattern(sourceValue, m_Constant(&sourceCst))) - return failure(); + VectorType destTy = insertOp.getDestVectorType(); + if (destTy.isScalable()) + return {}; - // Calculate the linearized position of the continuous chunk of elements to - // insert. - llvm::SmallVector completePositions(destTy.getRank(), 0); - copy(op.getStaticPosition(), completePositions.begin()); - int64_t insertBeginPosition = - linearize(completePositions, computeStrides(destTy.getShape())); - - SmallVector insertedValues; - Type destEltType = destTy.getElementType(); - - // The `convertIntegerAttr` method specifically handles the case - // for `llvm.mlir.constant` which can hold an attribute with a - // different type than the return type. - if (auto denseSource = llvm::dyn_cast(sourceCst)) { - for (auto value : denseSource.getValues()) - insertedValues.push_back(convertIntegerAttr(value, destEltType)); - } else { - insertedValues.push_back(convertIntegerAttr(sourceCst, destEltType)); - } + // Make sure we do not create too many large constants. + if (destTy.getNumElements() > maxVectorSizeFoldThreshold && + !insertOp->hasOneUse()) + return {}; - auto allValues = llvm::to_vector(denseDest.getValues()); - copy(insertedValues, allValues.begin() + insertBeginPosition); - auto newAttr = DenseElementsAttr::get(destTy, allValues); + // Calculate the linearized position of the continuous chunk of elements to + // insert. 
+ llvm::SmallVector completePositions(destTy.getRank(), 0); + copy(insertOp.getStaticPosition(), completePositions.begin()); + int64_t insertBeginPosition = + linearize(completePositions, computeStrides(destTy.getShape())); - rewriter.replaceOpWithNewOp(op, newAttr); - return success(); - } + SmallVector insertedValues; + Type destEltType = destTy.getElementType(); -private: /// Converts the expected type to an IntegerAttr if there's /// a mismatch. - Attribute convertIntegerAttr(Attribute attr, Type expectedType) const { + auto convertIntegerAttr = [](Attribute attr, Type expectedType) -> Attribute { if (auto intAttr = mlir::dyn_cast(attr)) { if (intAttr.getType() != expectedType) return IntegerAttr::get(expectedType, intAttr.getInt()); } return attr; + }; + + // The `convertIntegerAttr` lambda specifically handles the case + // for `llvm.mlir.constant` which can hold an attribute with a + // different type than the return type. + if (auto denseSource = llvm::dyn_cast(srcAttr)) { + for (auto value : denseSource.getValues()) + insertedValues.push_back(convertIntegerAttr(value, destEltType)); + } else { + insertedValues.push_back(convertIntegerAttr(srcAttr, destEltType)); } -}; -} // namespace + auto allValues = llvm::to_vector(denseDst.getValues()); + copy(insertedValues, allValues.begin() + insertBeginPosition); + auto newAttr = DenseElementsAttr::get(destTy, allValues); + + return newAttr; +} void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add(context); + results.add(context); } OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) { + // Do not create constants with more than `vectorSizeFoldThreshold` elements, + // unless the `vector.insert` op itself has a single use. + constexpr int64_t vectorSizeFoldThreshold = 256; // Fold "vector.insert %v, %dest [] : vector<2x2xf32> from vector<2x2xf32>" to // %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector" // (type mismatch). 
@@ -3118,6 +3102,11 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) { if (auto res = foldPoisonIndexInsertExtractOp( getContext(), adaptor.getStaticPosition(), kPoisonIndex)) return res; + if (auto res = foldDenseElementsAttrDestInsertOp(*this, adaptor.getSource(), + adaptor.getDest(), + vectorSizeFoldThreshold)) { + return res; + } return {}; } diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index 36b37a137ac1e..1ab28b9df2d19 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1517,13 +1517,9 @@ func.func @constant_mask_2d() -> vector<4x4xi1> { } // CHECK-LABEL: func @constant_mask_2d -// CHECK: %[[VAL_0:.*]] = arith.constant dense<[true, true, false, false]> : vector<4xi1> -// CHECK: %[[VAL_1:.*]] = arith.constant dense : vector<4x4xi1> -// CHECK: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[VAL_1]] : vector<4x4xi1> to !llvm.array<4 x vector<4xi1>> -// CHECK: %[[VAL_3:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_2]][0] : !llvm.array<4 x vector<4xi1>> -// CHECK: %[[VAL_4:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_3]][1] : !llvm.array<4 x vector<4xi1>> -// CHECK: %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : !llvm.array<4 x vector<4xi1>> to vector<4x4xi1> -// CHECK: return %[[VAL_5]] : vector<4x4xi1> +// CHECK: %[[VAL_0:.*]] = arith.constant +// CHECK-SAME{LITERAL}: dense<[[true, true, false, false], [true, true, false, false], [false, false, false, false], [false, false, false, false]]> : vector<4x4xi1> +// CHECK: return %[[VAL_0]] : vector<4x4xi1> // ----- diff --git a/mlir/test/Dialect/Vector/vector-mask-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-mask-lowering-transforms.mlir index 7838543e151be..b5eb6e63f5a8d 100644 --- a/mlir/test/Dialect/Vector/vector-mask-lowering-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-mask-lowering-transforms.mlir @@ 
-10,11 +10,9 @@ func.func @genbool_1d() -> vector<8xi1> { } // CHECK-LABEL: func @genbool_2d -// CHECK: %[[C1:.*]] = arith.constant dense<[true, true, false, false]> : vector<4xi1> -// CHECK: %[[C2:.*]] = arith.constant dense : vector<4x4xi1> -// CHECK: %[[T0:.*]] = vector.insert %[[C1]], %[[C2]] [0] : vector<4xi1> into vector<4x4xi1> -// CHECK: %[[T1:.*]] = vector.insert %[[C1]], %[[T0]] [1] : vector<4xi1> into vector<4x4xi1> -// CHECK: return %[[T1]] : vector<4x4xi1> +// CHECK: %[[C0:.*]] = arith.constant +// CHECK-SAME{LITERAL}: dense<[[true, true, false, false], [true, true, false, false], [false, false, false, false], [false, false, false, false]]> : vector<4x4xi1> +// CHECK: return %[[C0]] : vector<4x4xi1> func.func @genbool_2d() -> vector<4x4xi1> { %v = vector.constant_mask [2, 2] : vector<4x4xi1> @@ -22,12 +20,9 @@ func.func @genbool_2d() -> vector<4x4xi1> { } // CHECK-LABEL: func @genbool_3d -// CHECK-DAG: %[[C1:.*]] = arith.constant dense<[true, true, true, false]> : vector<4xi1> -// CHECK-DAG: %[[C2:.*]] = arith.constant dense : vector<3x4xi1> -// CHECK-DAG: %[[C3:.*]] = arith.constant dense : vector<2x3x4xi1> -// CHECK: %[[T0:.*]] = vector.insert %[[C1]], %[[C2]] [0] : vector<4xi1> into vector<3x4xi1> -// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C3]] [0] : vector<3x4xi1> into vector<2x3x4xi1> -// CHECK: return %[[T1]] : vector<2x3x4xi1> +// CHECK: %[[C0:.*]] = arith.constant +// CHECK-SAME{LITERAL}: dense<[[[true, true, true, false], [false, false, false, false], [false, false, false, false]], [[false, false, false, false], [false, false, false, false], [false, false, false, false]]]> : vector<2x3x4xi1> +// CHECK: return %[[C0]] : vector<2x3x4xi1> func.func @genbool_3d() -> vector<2x3x4xi1> { %v = vector.constant_mask [1, 1, 3] : vector<2x3x4xi1>