@@ -3013,94 +3013,78 @@ class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
30133013 }
30143014};
30153015
3016- // Pattern to rewrite a InsertOp(ConstantOp into ConstantOp) -> ConstantOp.
3017- class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> {
3018- public:
3019- using OpRewritePattern::OpRewritePattern;
3020-
3021- // Do not create constants with more than `vectorSizeFoldThreashold` elements,
3022- // unless the source vector constant has a single use.
3023- static constexpr int64_t vectorSizeFoldThreshold = 256 ;
3024-
3025- LogicalResult matchAndRewrite (InsertOp op,
3026- PatternRewriter &rewriter) const override {
3027- // TODO: Canonicalization for dynamic position not implemented yet.
3028- if (op.hasDynamicPosition ())
3029- return failure ();
3016+ } // namespace
30303017
3031- // Return if 'InsertOp' operand is not defined by a compatible vector
3032- // ConstantOp.
3033- TypedValue<VectorType> destVector = op.getDest ();
3034- Attribute vectorDestCst;
3035- if (!matchPattern (destVector, m_Constant (&vectorDestCst)))
3036- return failure ();
3037- auto denseDest = llvm::dyn_cast<DenseElementsAttr>(vectorDestCst);
3038- if (!denseDest)
3039- return failure ();
3018+ static Attribute
3019+ foldDenseElementsAttrDestInsertOp (InsertOp insertOp, Attribute srcAttr,
3020+ Attribute dstAttr,
3021+ int64_t maxVectorSizeFoldThreshold) {
3022+ if (insertOp.hasDynamicPosition ())
3023+ return {};
30403024
3041- VectorType destTy = destVector. getType ( );
3042- if (destTy. isScalable () )
3043- return failure () ;
3025+ auto denseDst = llvm::dyn_cast_if_present<DenseElementsAttr>(dstAttr );
3026+ if (!denseDst )
3027+ return {} ;
30443028
3045- // Make sure we do not create too many large constants.
3046- if (destTy.getNumElements () > vectorSizeFoldThreshold &&
3047- !destVector.hasOneUse ())
3048- return failure ();
3029+ if (!srcAttr) {
3030+ return {};
3031+ }
30493032
3050- Value sourceValue = op.getSource ();
3051- Attribute sourceCst;
3052- if (!matchPattern (sourceValue, m_Constant (&sourceCst)))
3053- return failure ();
3033+ VectorType destTy = insertOp.getDestVectorType ();
3034+ if (destTy.isScalable ())
3035+ return {};
30543036
3055- // Calculate the linearized position of the continuous chunk of elements to
3056- // insert.
3057- llvm::SmallVector<int64_t > completePositions (destTy.getRank (), 0 );
3058- copy (op.getStaticPosition (), completePositions.begin ());
3059- int64_t insertBeginPosition =
3060- linearize (completePositions, computeStrides (destTy.getShape ()));
3061-
3062- SmallVector<Attribute> insertedValues;
3063- Type destEltType = destTy.getElementType ();
3064-
3065- // The `convertIntegerAttr` method specifically handles the case
3066- // for `llvm.mlir.constant` which can hold an attribute with a
3067- // different type than the return type.
3068- if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(sourceCst)) {
3069- for (auto value : denseSource.getValues <Attribute>())
3070- insertedValues.push_back (convertIntegerAttr (value, destEltType));
3071- } else {
3072- insertedValues.push_back (convertIntegerAttr (sourceCst, destEltType));
3073- }
3037+ // Make sure we do not create too many large constants.
3038+ if (destTy.getNumElements () > maxVectorSizeFoldThreshold &&
3039+ !insertOp->hasOneUse ())
3040+ return {};
30743041
3075- auto allValues = llvm::to_vector (denseDest.getValues <Attribute>());
3076- copy (insertedValues, allValues.begin () + insertBeginPosition);
3077- auto newAttr = DenseElementsAttr::get (destTy, allValues);
3042+ // Calculate the linearized position of the continuous chunk of elements to
3043+ // insert.
3044+ llvm::SmallVector<int64_t > completePositions (destTy.getRank (), 0 );
3045+ copy (insertOp.getStaticPosition (), completePositions.begin ());
3046+ int64_t insertBeginPosition =
3047+ linearize (completePositions, computeStrides (destTy.getShape ()));
30783048
3079- rewriter.replaceOpWithNewOp <arith::ConstantOp>(op, newAttr);
3080- return success ();
3081- }
3049+ SmallVector<Attribute> insertedValues;
3050+ Type destEltType = destTy.getElementType ();
30823051
3083- private:
30843052 /// Converts an IntegerAttr to the expected type if there's
30853053 /// a mismatch.
3086- Attribute convertIntegerAttr (Attribute attr, Type expectedType) const {
3054+ auto convertIntegerAttr = [] (Attribute attr, Type expectedType) -> Attribute {
30873055 if (auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
30883056 if (intAttr.getType () != expectedType)
30893057 return IntegerAttr::get (expectedType, intAttr.getInt ());
30903058 }
30913059 return attr;
3060+ };
3061+
3062+ // The `convertIntegerAttr` lambda specifically handles the case
3063+ // for `llvm.mlir.constant` which can hold an attribute with a
3064+ // different type than the return type.
3065+ if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
3066+ for (auto value : denseSource.getValues <Attribute>())
3067+ insertedValues.push_back (convertIntegerAttr (value, destEltType));
3068+ } else {
3069+ insertedValues.push_back (convertIntegerAttr (srcAttr, destEltType));
30923070 }
3093- };
30943071
3095- } // namespace
3072+ auto allValues = llvm::to_vector (denseDst.getValues <Attribute>());
3073+ copy (insertedValues, allValues.begin () + insertBeginPosition);
3074+ auto newAttr = DenseElementsAttr::get (destTy, allValues);
3075+
3076+ return newAttr;
3077+ }
30963078
30973079void InsertOp::getCanonicalizationPatterns (RewritePatternSet &results,
30983080 MLIRContext *context) {
3099- results.add <InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
3100- InsertOpConstantFolder>(context);
3081+ results.add <InsertToBroadcast, BroadcastFolder, InsertSplatToSplat>(context);
31013082}
31023083
31033084OpFoldResult vector::InsertOp::fold (FoldAdaptor adaptor) {
3085+ // Do not create constants with more than `vectorSizeFoldThreshold` elements,
3086+ // unless the insert op itself has a single use.
3087+ constexpr int64_t vectorSizeFoldThreshold = 256 ;
31043088 // Fold "vector.insert %v, %dest [] : vector<2x2xf32> from vector<2x2xf32>" to
31053089 // %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector<f32>"
31063090 // (type mismatch).
@@ -3112,6 +3096,11 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
31123096 if (auto res = foldPoisonIndexInsertExtractOp (
31133097 getContext (), adaptor.getStaticPosition (), kPoisonIndex ))
31143098 return res;
3099+ if (auto res = foldDenseElementsAttrDestInsertOp (*this , adaptor.getSource (),
3100+ adaptor.getDest (),
3101+ vectorSizeFoldThreshold)) {
3102+ return res;
3103+ }
31153104
31163105 return {};
31173106}
0 commit comments