@@ -3714,12 +3714,67 @@ foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
37143714 return failure ();
37153715}
37163716
3717+ // ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
3718+ static OpFoldResult
3719+ foldExtractStridedSliceNonSplatConstant (ExtractStridedSliceOp op,
3720+ Attribute foldInput) {
3721+
3722+ auto dense = llvm::dyn_cast_if_present<DenseElementsAttr>(foldInput);
3723+ if (!dense)
3724+ return {};
3725+
3726+ // TODO: Handle non-unit strides when they become available.
3727+ if (op.hasNonUnitStrides ())
3728+ return {};
3729+
3730+ VectorType sourceVecTy = op.getSourceVectorType ();
3731+ ArrayRef<int64_t > sourceShape = sourceVecTy.getShape ();
3732+ SmallVector<int64_t , 4 > sourceStrides = computeStrides (sourceShape);
3733+
3734+ VectorType sliceVecTy = op.getType ();
3735+ ArrayRef<int64_t > sliceShape = sliceVecTy.getShape ();
3736+ int64_t rank = sliceVecTy.getRank ();
3737+
3738+ // Expand offsets and sizes to match the vector rank.
3739+ SmallVector<int64_t , 4 > offsets (rank, 0 );
3740+ copy (getI64SubArray (op.getOffsets ()), offsets.begin ());
3741+
3742+ SmallVector<int64_t , 4 > sizes (sourceShape);
3743+ copy (getI64SubArray (op.getSizes ()), sizes.begin ());
3744+
3745+ // Calculate the slice elements by enumerating all slice positions and
3746+ // linearizing them. The enumeration order is lexicographic which yields a
3747+ // sequence of monotonically increasing linearized position indices.
3748+ const auto denseValuesBegin = dense.value_begin <Attribute>();
3749+ SmallVector<Attribute> sliceValues;
3750+ sliceValues.reserve (sliceVecTy.getNumElements ());
3751+ SmallVector<int64_t > currSlicePosition (offsets.begin (), offsets.end ());
3752+ do {
3753+ int64_t linearizedPosition = linearize (currSlicePosition, sourceStrides);
3754+ assert (linearizedPosition < sourceVecTy.getNumElements () &&
3755+ " Invalid index" );
3756+ sliceValues.push_back (*(denseValuesBegin + linearizedPosition));
3757+ } while (succeeded (incSlicePosition (currSlicePosition, sliceShape, offsets)));
3758+
3759+ assert (static_cast <int64_t >(sliceValues.size ()) ==
3760+ sliceVecTy.getNumElements () &&
3761+ " Invalid number of slice elements" );
3762+ return DenseElementsAttr::get (sliceVecTy, sliceValues);
3763+ }
3764+
37173765OpFoldResult ExtractStridedSliceOp::fold (FoldAdaptor adaptor) {
37183766 if (getSourceVectorType () == getResult ().getType ())
37193767 return getVector ();
37203768 if (succeeded (foldExtractStridedOpFromInsertChain (*this )))
37213769 return getResult ();
3722- return {};
3770+
3771+ // ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
3772+ if (auto splat =
3773+ llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector ()))
3774+ DenseElementsAttr::get (getType (), splat.getSplatValue <Attribute>());
3775+
3776+ // ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
3777+ return foldExtractStridedSliceNonSplatConstant (*this , adaptor.getVector ());
37233778}
37243779
37253780void ExtractStridedSliceOp::getOffsets (SmallVectorImpl<int64_t > &results) {
@@ -3783,98 +3838,6 @@ class StridedSliceConstantMaskFolder final
37833838 }
37843839};
37853840
3786- // Pattern to rewrite a ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
3787- class StridedSliceSplatConstantFolder final
3788- : public OpRewritePattern<ExtractStridedSliceOp> {
3789- public:
3790- using OpRewritePattern::OpRewritePattern;
3791-
3792- LogicalResult matchAndRewrite (ExtractStridedSliceOp extractStridedSliceOp,
3793- PatternRewriter &rewriter) const override {
3794- // Return if 'ExtractStridedSliceOp' operand is not defined by a splat
3795- // ConstantOp.
3796- Value sourceVector = extractStridedSliceOp.getVector ();
3797- Attribute vectorCst;
3798- if (!matchPattern (sourceVector, m_Constant (&vectorCst)))
3799- return failure ();
3800-
3801- auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
3802- if (!splat)
3803- return failure ();
3804-
3805- auto newAttr = SplatElementsAttr::get (extractStridedSliceOp.getType (),
3806- splat.getSplatValue <Attribute>());
3807- rewriter.replaceOpWithNewOp <arith::ConstantOp>(extractStridedSliceOp,
3808- newAttr);
3809- return success ();
3810- }
3811- };
3812-
3813- // Pattern to rewrite a ExtractStridedSliceOp(non-splat ConstantOp) ->
3814- // ConstantOp.
3815- class StridedSliceNonSplatConstantFolder final
3816- : public OpRewritePattern<ExtractStridedSliceOp> {
3817- public:
3818- using OpRewritePattern::OpRewritePattern;
3819-
3820- LogicalResult matchAndRewrite (ExtractStridedSliceOp extractStridedSliceOp,
3821- PatternRewriter &rewriter) const override {
3822- // Return if 'ExtractStridedSliceOp' operand is not defined by a non-splat
3823- // ConstantOp.
3824- Value sourceVector = extractStridedSliceOp.getVector ();
3825- Attribute vectorCst;
3826- if (!matchPattern (sourceVector, m_Constant (&vectorCst)))
3827- return failure ();
3828-
3829- // The splat case is handled by `StridedSliceSplatConstantFolder`.
3830- auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
3831- if (!dense || dense.isSplat ())
3832- return failure ();
3833-
3834- // TODO: Handle non-unit strides when they become available.
3835- if (extractStridedSliceOp.hasNonUnitStrides ())
3836- return failure ();
3837-
3838- auto sourceVecTy = llvm::cast<VectorType>(sourceVector.getType ());
3839- ArrayRef<int64_t > sourceShape = sourceVecTy.getShape ();
3840- SmallVector<int64_t , 4 > sourceStrides = computeStrides (sourceShape);
3841-
3842- VectorType sliceVecTy = extractStridedSliceOp.getType ();
3843- ArrayRef<int64_t > sliceShape = sliceVecTy.getShape ();
3844- int64_t sliceRank = sliceVecTy.getRank ();
3845-
3846- // Expand offsets and sizes to match the vector rank.
3847- SmallVector<int64_t , 4 > offsets (sliceRank, 0 );
3848- copy (getI64SubArray (extractStridedSliceOp.getOffsets ()), offsets.begin ());
3849-
3850- SmallVector<int64_t , 4 > sizes (sourceShape);
3851- copy (getI64SubArray (extractStridedSliceOp.getSizes ()), sizes.begin ());
3852-
3853- // Calculate the slice elements by enumerating all slice positions and
3854- // linearizing them. The enumeration order is lexicographic which yields a
3855- // sequence of monotonically increasing linearized position indices.
3856- auto denseValuesBegin = dense.value_begin <Attribute>();
3857- SmallVector<Attribute> sliceValues;
3858- sliceValues.reserve (sliceVecTy.getNumElements ());
3859- SmallVector<int64_t > currSlicePosition (offsets.begin (), offsets.end ());
3860- do {
3861- int64_t linearizedPosition = linearize (currSlicePosition, sourceStrides);
3862- assert (linearizedPosition < sourceVecTy.getNumElements () &&
3863- " Invalid index" );
3864- sliceValues.push_back (*(denseValuesBegin + linearizedPosition));
3865- } while (
3866- succeeded (incSlicePosition (currSlicePosition, sliceShape, offsets)));
3867-
3868- assert (static_cast <int64_t >(sliceValues.size ()) ==
3869- sliceVecTy.getNumElements () &&
3870- " Invalid number of slice elements" );
3871- auto newAttr = DenseElementsAttr::get (sliceVecTy, sliceValues);
3872- rewriter.replaceOpWithNewOp <arith::ConstantOp>(extractStridedSliceOp,
3873- newAttr);
3874- return success ();
3875- }
3876- };
3877-
38783841// Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
38793842// BroadcastOp(ExtractStrideSliceOp).
38803843class StridedSliceBroadcast final
@@ -4018,8 +3981,7 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
40183981 RewritePatternSet &results, MLIRContext *context) {
40193982 // Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
40203983 // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
4021- results.add <StridedSliceConstantMaskFolder, StridedSliceSplatConstantFolder,
4022- StridedSliceNonSplatConstantFolder, StridedSliceBroadcast,
3984+ results.add <StridedSliceConstantMaskFolder, StridedSliceBroadcast,
40233985 StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
40243986 context);
40253987}
@@ -5659,10 +5621,8 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
56595621
56605622 // shape_cast(constant) -> constant
56615623 if (auto splatAttr =
5662- llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource ())) {
5663- return DenseElementsAttr::get (resultType,
5664- splatAttr.getSplatValue <Attribute>());
5665- }
5624+ llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource ()))
5625+ return splatAttr.reshape (getType ());
56665626
56675627 // shape_cast(poison) -> poison
56685628 if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource ())) {
@@ -6006,10 +5966,9 @@ void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
60065966
60075967OpFoldResult vector::TransposeOp::fold (FoldAdaptor adaptor) {
60085968 // Eliminate splat constant transpose ops.
6009- if (auto attr =
6010- llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getVector ()))
6011- if (attr.isSplat ())
6012- return attr.reshape (getResultVectorType ());
5969+ if (auto splat =
5970+ llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector ()))
5971+ return splat.reshape (getResultVectorType ());
60135972
60145973 // Eliminate poison transpose ops.
60155974 if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getVector ()))
0 commit comments