@@ -3717,6 +3717,59 @@ OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
37173717 return getVector ();
37183718 if (succeeded (foldExtractStridedOpFromInsertChain (*this )))
37193719 return getResult ();
3720+
3721+ Attribute foldInput = adaptor.getVector ();
3722+ if (!foldInput) {
3723+ return {};
3724+ }
3725+
3726+ // rewrite : ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
3727+ if (auto splat = llvm::dyn_cast<SplatElementsAttr>(foldInput))
3728+ DenseElementsAttr::get (getType (), splat.getSplatValue <Attribute>());
3729+
3730+ // rewrite ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
3731+ if (auto dense = llvm::dyn_cast<DenseElementsAttr>(foldInput)) {
3732+ // TODO: Handle non-unit strides when they become available.
3733+ if (hasNonUnitStrides ())
3734+ return {};
3735+
3736+ Value sourceVector = getVector ();
3737+ auto sourceVecTy = llvm::cast<VectorType>(sourceVector.getType ());
3738+ ArrayRef<int64_t > sourceShape = sourceVecTy.getShape ();
3739+ SmallVector<int64_t , 4 > sourceStrides = computeStrides (sourceShape);
3740+
3741+ VectorType sliceVecTy = getType ();
3742+ ArrayRef<int64_t > sliceShape = sliceVecTy.getShape ();
3743+ int64_t sliceRank = sliceVecTy.getRank ();
3744+
3745+ // Expand offsets and sizes to match the vector rank.
3746+ SmallVector<int64_t , 4 > offsets (sliceRank, 0 );
3747+ copy (getI64SubArray (getOffsets ()), offsets.begin ());
3748+
3749+ SmallVector<int64_t , 4 > sizes (sourceShape);
3750+ copy (getI64SubArray (getSizes ()), sizes.begin ());
3751+
3752+ // Calculate the slice elements by enumerating all slice positions and
3753+ // linearizing them. The enumeration order is lexicographic which yields a
3754+ // sequence of monotonically increasing linearized position indices.
3755+ auto denseValuesBegin = dense.value_begin <Attribute>();
3756+ SmallVector<Attribute> sliceValues;
3757+ sliceValues.reserve (sliceVecTy.getNumElements ());
3758+ SmallVector<int64_t > currSlicePosition (offsets.begin (), offsets.end ());
3759+ do {
3760+ int64_t linearizedPosition = linearize (currSlicePosition, sourceStrides);
3761+ assert (linearizedPosition < sourceVecTy.getNumElements () &&
3762+ " Invalid index" );
3763+ sliceValues.push_back (*(denseValuesBegin + linearizedPosition));
3764+ } while (
3765+ succeeded (incSlicePosition (currSlicePosition, sliceShape, offsets)));
3766+
3767+ assert (static_cast <int64_t >(sliceValues.size ()) ==
3768+ sliceVecTy.getNumElements () &&
3769+ " Invalid number of slice elements" );
3770+ return DenseElementsAttr::get (sliceVecTy, sliceValues);
3771+ }
3772+
37203773 return {};
37213774}
37223775
@@ -3781,98 +3834,6 @@ class StridedSliceConstantMaskFolder final
37813834 }
37823835};
37833836
3784- // Pattern to rewrite a ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
3785- class StridedSliceSplatConstantFolder final
3786- : public OpRewritePattern<ExtractStridedSliceOp> {
3787- public:
3788- using OpRewritePattern::OpRewritePattern;
3789-
3790- LogicalResult matchAndRewrite (ExtractStridedSliceOp extractStridedSliceOp,
3791- PatternRewriter &rewriter) const override {
3792- // Return if 'ExtractStridedSliceOp' operand is not defined by a splat
3793- // ConstantOp.
3794- Value sourceVector = extractStridedSliceOp.getVector ();
3795- Attribute vectorCst;
3796- if (!matchPattern (sourceVector, m_Constant (&vectorCst)))
3797- return failure ();
3798-
3799- auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
3800- if (!splat)
3801- return failure ();
3802-
3803- auto newAttr = SplatElementsAttr::get (extractStridedSliceOp.getType (),
3804- splat.getSplatValue <Attribute>());
3805- rewriter.replaceOpWithNewOp <arith::ConstantOp>(extractStridedSliceOp,
3806- newAttr);
3807- return success ();
3808- }
3809- };
3810-
3811- // Pattern to rewrite a ExtractStridedSliceOp(non-splat ConstantOp) ->
3812- // ConstantOp.
3813- class StridedSliceNonSplatConstantFolder final
3814- : public OpRewritePattern<ExtractStridedSliceOp> {
3815- public:
3816- using OpRewritePattern::OpRewritePattern;
3817-
3818- LogicalResult matchAndRewrite (ExtractStridedSliceOp extractStridedSliceOp,
3819- PatternRewriter &rewriter) const override {
3820- // Return if 'ExtractStridedSliceOp' operand is not defined by a non-splat
3821- // ConstantOp.
3822- Value sourceVector = extractStridedSliceOp.getVector ();
3823- Attribute vectorCst;
3824- if (!matchPattern (sourceVector, m_Constant (&vectorCst)))
3825- return failure ();
3826-
3827- // The splat case is handled by `StridedSliceSplatConstantFolder`.
3828- auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
3829- if (!dense || dense.isSplat ())
3830- return failure ();
3831-
3832- // TODO: Handle non-unit strides when they become available.
3833- if (extractStridedSliceOp.hasNonUnitStrides ())
3834- return failure ();
3835-
3836- auto sourceVecTy = llvm::cast<VectorType>(sourceVector.getType ());
3837- ArrayRef<int64_t > sourceShape = sourceVecTy.getShape ();
3838- SmallVector<int64_t , 4 > sourceStrides = computeStrides (sourceShape);
3839-
3840- VectorType sliceVecTy = extractStridedSliceOp.getType ();
3841- ArrayRef<int64_t > sliceShape = sliceVecTy.getShape ();
3842- int64_t sliceRank = sliceVecTy.getRank ();
3843-
3844- // Expand offsets and sizes to match the vector rank.
3845- SmallVector<int64_t , 4 > offsets (sliceRank, 0 );
3846- copy (getI64SubArray (extractStridedSliceOp.getOffsets ()), offsets.begin ());
3847-
3848- SmallVector<int64_t , 4 > sizes (sourceShape);
3849- copy (getI64SubArray (extractStridedSliceOp.getSizes ()), sizes.begin ());
3850-
3851- // Calculate the slice elements by enumerating all slice positions and
3852- // linearizing them. The enumeration order is lexicographic which yields a
3853- // sequence of monotonically increasing linearized position indices.
3854- auto denseValuesBegin = dense.value_begin <Attribute>();
3855- SmallVector<Attribute> sliceValues;
3856- sliceValues.reserve (sliceVecTy.getNumElements ());
3857- SmallVector<int64_t > currSlicePosition (offsets.begin (), offsets.end ());
3858- do {
3859- int64_t linearizedPosition = linearize (currSlicePosition, sourceStrides);
3860- assert (linearizedPosition < sourceVecTy.getNumElements () &&
3861- " Invalid index" );
3862- sliceValues.push_back (*(denseValuesBegin + linearizedPosition));
3863- } while (
3864- succeeded (incSlicePosition (currSlicePosition, sliceShape, offsets)));
3865-
3866- assert (static_cast <int64_t >(sliceValues.size ()) ==
3867- sliceVecTy.getNumElements () &&
3868- " Invalid number of slice elements" );
3869- auto newAttr = DenseElementsAttr::get (sliceVecTy, sliceValues);
3870- rewriter.replaceOpWithNewOp <arith::ConstantOp>(extractStridedSliceOp,
3871- newAttr);
3872- return success ();
3873- }
3874- };
3875-
38763837// Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
38773838// BroadcastOp(ExtractStrideSliceOp).
38783839class StridedSliceBroadcast final
@@ -4016,8 +3977,7 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
40163977 RewritePatternSet &results, MLIRContext *context) {
40173978 // Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
40183979 // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
4019- results.add <StridedSliceConstantMaskFolder, StridedSliceSplatConstantFolder,
4020- StridedSliceNonSplatConstantFolder, StridedSliceBroadcast,
3980+ results.add <StridedSliceConstantMaskFolder, StridedSliceBroadcast,
40213981 StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
40223982 context);
40233983}
@@ -5657,10 +5617,8 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
56575617
56585618 // shape_cast(constant) -> constant
56595619 if (auto splatAttr =
5660- llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource ())) {
5661- return DenseElementsAttr::get (resultType,
5662- splatAttr.getSplatValue <Attribute>());
5663- }
5620+ llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource ()))
5621+ return splatAttr.reshape (getType ());
56645622
56655623 // shape_cast(poison) -> poison
56665624 if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource ())) {
@@ -6004,10 +5962,9 @@ void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
60045962
60055963OpFoldResult vector::TransposeOp::fold (FoldAdaptor adaptor) {
60065964 // Eliminate splat constant transpose ops.
6007- if (auto attr =
6008- llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getVector ()))
6009- if (attr.isSplat ())
6010- return attr.reshape (getResultVectorType ());
5965+ if (auto splat =
5966+ llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector ()))
5967+ return splat.reshape (getResultVectorType ());
60115968
60125969 // Eliminate identity transpose ops. This happens when the dimensions of the
60135970 // input vector remain in their original order after the transpose operation.
0 commit comments