@@ -3019,94 +3019,78 @@ class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
30193019 }
30203020};
30213021
3022- // Pattern to rewrite a InsertOp(ConstantOp into ConstantOp) -> ConstantOp.
3023- class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> {
3024- public:
3025- using OpRewritePattern::OpRewritePattern;
3026-
3027- // Do not create constants with more than `vectorSizeFoldThreashold` elements,
3028- // unless the source vector constant has a single use.
3029- static constexpr int64_t vectorSizeFoldThreshold = 256 ;
3030-
3031- LogicalResult matchAndRewrite (InsertOp op,
3032- PatternRewriter &rewriter) const override {
3033- // TODO: Canonicalization for dynamic position not implemented yet.
3034- if (op.hasDynamicPosition ())
3035- return failure ();
3022+ } // namespace
30363023
3037- // Return if 'InsertOp' operand is not defined by a compatible vector
3038- // ConstantOp.
3039- TypedValue<VectorType> destVector = op.getDest ();
3040- Attribute vectorDestCst;
3041- if (!matchPattern (destVector, m_Constant (&vectorDestCst)))
3042- return failure ();
3043- auto denseDest = llvm::dyn_cast<DenseElementsAttr>(vectorDestCst);
3044- if (!denseDest)
3045- return failure ();
3024+ static Attribute
3025+ foldDenseElementsAttrDestInsertOp (InsertOp insertOp, Attribute srcAttr,
3026+ Attribute dstAttr,
3027+ int64_t maxVectorSizeFoldThreshold) {
3028+ if (insertOp.hasDynamicPosition ())
3029+ return {};
30463030
3047- VectorType destTy = destVector. getType ( );
3048- if (destTy. isScalable () )
3049- return failure () ;
3031+ auto denseDst = llvm::dyn_cast_if_present<DenseElementsAttr>(dstAttr );
3032+ if (!denseDst )
3033+ return {} ;
30503034
3051- // Make sure we do not create too many large constants.
3052- if (destTy.getNumElements () > vectorSizeFoldThreshold &&
3053- !destVector.hasOneUse ())
3054- return failure ();
3035+ if (!srcAttr) {
3036+ return {};
3037+ }
30553038
3056- Value sourceValue = op.getSource ();
3057- Attribute sourceCst;
3058- if (!matchPattern (sourceValue, m_Constant (&sourceCst)))
3059- return failure ();
3039+ VectorType destTy = insertOp.getDestVectorType ();
3040+ if (destTy.isScalable ())
3041+ return {};
30603042
3061- // Calculate the linearized position of the continuous chunk of elements to
3062- // insert.
3063- llvm::SmallVector<int64_t > completePositions (destTy.getRank (), 0 );
3064- copy (op.getStaticPosition (), completePositions.begin ());
3065- int64_t insertBeginPosition =
3066- linearize (completePositions, computeStrides (destTy.getShape ()));
3067-
3068- SmallVector<Attribute> insertedValues;
3069- Type destEltType = destTy.getElementType ();
3070-
3071- // The `convertIntegerAttr` method specifically handles the case
3072- // for `llvm.mlir.constant` which can hold an attribute with a
3073- // different type than the return type.
3074- if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(sourceCst)) {
3075- for (auto value : denseSource.getValues <Attribute>())
3076- insertedValues.push_back (convertIntegerAttr (value, destEltType));
3077- } else {
3078- insertedValues.push_back (convertIntegerAttr (sourceCst, destEltType));
3079- }
3043+ // Make sure we do not create too many large constants.
3044+ if (destTy.getNumElements () > maxVectorSizeFoldThreshold &&
3045+ !insertOp->hasOneUse ())
3046+ return {};
30803047
3081- auto allValues = llvm::to_vector (denseDest.getValues <Attribute>());
3082- copy (insertedValues, allValues.begin () + insertBeginPosition);
3083- auto newAttr = DenseElementsAttr::get (destTy, allValues);
3048+ // Calculate the linearized position of the continuous chunk of elements to
3049+ // insert.
3050+ llvm::SmallVector<int64_t > completePositions (destTy.getRank (), 0 );
3051+ copy (insertOp.getStaticPosition (), completePositions.begin ());
3052+ int64_t insertBeginPosition =
3053+ linearize (completePositions, computeStrides (destTy.getShape ()));
30843054
3085- rewriter.replaceOpWithNewOp <arith::ConstantOp>(op, newAttr);
3086- return success ();
3087- }
3055+ SmallVector<Attribute> insertedValues;
3056+ Type destEltType = destTy.getElementType ();
30883057
3089- private:
30903058 // / Converts the expected type to an IntegerAttr if there's
30913059 // / a mismatch.
3092- Attribute convertIntegerAttr (Attribute attr, Type expectedType) const {
3060+ auto convertIntegerAttr = [] (Attribute attr, Type expectedType) -> Attribute {
30933061 if (auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
30943062 if (intAttr.getType () != expectedType)
30953063 return IntegerAttr::get (expectedType, intAttr.getInt ());
30963064 }
30973065 return attr;
3066+ };
3067+
3068+ // The `convertIntegerAttr` lambda specifically handles the case
3069+ // for `llvm.mlir.constant` which can hold an attribute with a
3070+ // different type than the return type.
3071+ if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
3072+ for (auto value : denseSource.getValues <Attribute>())
3073+ insertedValues.push_back (convertIntegerAttr (value, destEltType));
3074+ } else {
3075+ insertedValues.push_back (convertIntegerAttr (srcAttr, destEltType));
30983076 }
3099- };
31003077
3101- } // namespace
3078+ auto allValues = llvm::to_vector (denseDst.getValues <Attribute>());
3079+ copy (insertedValues, allValues.begin () + insertBeginPosition);
3080+ auto newAttr = DenseElementsAttr::get (destTy, allValues);
3081+
3082+ return newAttr;
3083+ }
31023084
31033085void InsertOp::getCanonicalizationPatterns (RewritePatternSet &results,
31043086 MLIRContext *context) {
3105- results.add <InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
3106- InsertOpConstantFolder>(context);
3087+ results.add <InsertToBroadcast, BroadcastFolder, InsertSplatToSplat>(context);
31073088}
31083089
31093090OpFoldResult vector::InsertOp::fold (FoldAdaptor adaptor) {
3091+ // Do not create constants with more than `vectorSizeFoldThreshold` elements,
3092+ // unless the insert op being folded has a single use.
3093+ constexpr int64_t vectorSizeFoldThreshold = 256 ;
31103094 // Fold "vector.insert %v, %dest [] : vector<2x2xf32> from vector<2x2xf32>" to
31113095 // %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector<f32>"
31123096 // (type mismatch).
@@ -3118,6 +3102,11 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
31183102 if (auto res = foldPoisonIndexInsertExtractOp (
31193103 getContext (), adaptor.getStaticPosition (), kPoisonIndex ))
31203104 return res;
3105+ if (auto res = foldDenseElementsAttrDestInsertOp (*this , adaptor.getSource (),
3106+ adaptor.getDest (),
3107+ vectorSizeFoldThreshold)) {
3108+ return res;
3109+ }
31213110
31223111 return {};
31233112}
0 commit comments