@@ -2943,6 +2943,11 @@ static VectorType trimLeadingOneDims(VectorType oldType) {
29432943 return VectorType::get (newShape, oldType.getElementType ());
29442944}
29452945
2946+ /// Return a SmallVector of size `rank` containing all zeros; callers pass it as the all-zero position list to vector::ExtractOp to drop leading dimensions.
2947+ static SmallVector<int64_t > splatZero (int64_t rank) {
2948+ return SmallVector<int64_t >(rank, 0 );
2949+ }
2950+
29462951// Casts away leading one dimensions in vector.extract_strided_slice's vector
29472952// input by inserting vector.shape_cast.
29482953struct CastAwayExtractStridedSliceLeadingOneDim
@@ -2969,8 +2974,8 @@ struct CastAwayExtractStridedSliceLeadingOneDim
29692974
29702975 Location loc = extractOp.getLoc ();
29712976
2972- Value newSrcVector = rewriter.create <vector::ShapeCastOp >(
2973- loc, newSrcType, extractOp.vector ());
2977+ Value newSrcVector = rewriter.create <vector::ExtractOp >(
2978+ loc, extractOp.vector (), splatZero (dropCount ));
29742979
29752980 // The offsets/sizes/strides attribute can have a less number of elements
29762981 // than the input vector's rank: it is meant for the leading dimensions.
@@ -2984,7 +2989,7 @@ struct CastAwayExtractStridedSliceLeadingOneDim
29842989 auto newExtractOp = rewriter.create <vector::ExtractStridedSliceOp>(
29852990 loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides);
29862991
2987- rewriter.replaceOpWithNewOp <vector::ShapeCastOp >(extractOp, oldDstType,
2992+ rewriter.replaceOpWithNewOp <vector::BroadcastOp >(extractOp, oldDstType,
29882993 newExtractOp);
29892994
29902995 return success ();
@@ -3004,17 +3009,18 @@ struct CastAwayInsertStridedSliceLeadingOneDim
30043009 VectorType oldDstType = insertOp.getDestVectorType ();
30053010 VectorType newDstType = trimLeadingOneDims (oldDstType);
30063011
3007- if (newSrcType.getRank () == oldSrcType.getRank () &&
3008- newDstType.getRank () == oldDstType.getRank ())
3012+ int64_t srcDropCount = oldSrcType.getRank () - newSrcType.getRank ();
3013+ int64_t dstDropCount = oldDstType.getRank () - newDstType.getRank ();
3014+ if (srcDropCount == 0 && dstDropCount == 0 )
30093015 return failure ();
30103016
30113017 // Trim leading one dimensions from both operands.
30123018 Location loc = insertOp.getLoc ();
30133019
3014- Value newSrcVector = rewriter.create <vector::ShapeCastOp >(
3015- loc, newSrcType, insertOp.source ());
3016- Value newDstVector =
3017- rewriter. create <vector::ShapeCastOp>( loc, newDstType, insertOp.dest ());
3020+ Value newSrcVector = rewriter.create <vector::ExtractOp >(
3021+ loc, insertOp.source (), splatZero (srcDropCount ));
3022+ Value newDstVector = rewriter. create <vector::ExtractOp>(
3023+ loc, insertOp.dest (), splatZero (dstDropCount ));
30183024
30193025 auto newOffsets = rewriter.getArrayAttr (
30203026 insertOp.offsets ().getValue ().take_back (newDstType.getRank ()));
@@ -3024,7 +3030,7 @@ struct CastAwayInsertStridedSliceLeadingOneDim
30243030 auto newInsertOp = rewriter.create <vector::InsertStridedSliceOp>(
30253031 loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides);
30263032
3027- rewriter.replaceOpWithNewOp <vector::ShapeCastOp >(insertOp, oldDstType,
3033+ rewriter.replaceOpWithNewOp <vector::BroadcastOp >(insertOp, oldDstType,
30283034 newInsertOp);
30293035
30303036 return success ();
@@ -3068,7 +3074,7 @@ struct CastAwayTransferReadLeadingOneDim
30683074 auto newRead = rewriter.create <vector::TransferReadOp>(
30693075 read.getLoc (), newType, read.source (), read.indices (), newMap,
30703076 read.padding (), inBounds);
3071- rewriter.replaceOpWithNewOp <vector::ShapeCastOp >(read, oldType, newRead);
3077+ rewriter.replaceOpWithNewOp <vector::BroadcastOp >(read, oldType, newRead);
30723078
30733079 return success ();
30743080 }
@@ -3092,9 +3098,9 @@ struct CastAwayTransferWriteLeadingOneDim
30923098
30933099 VectorType oldType = write.getVectorType ();
30943100 VectorType newType = trimLeadingOneDims (oldType);
3095-
30963101 if (newType == oldType)
30973102 return failure ();
3103+ int64_t dropDim = oldType.getRank () - newType.getRank ();
30983104
30993105 AffineMap oldMap = write.permutation_map ();
31003106 ArrayRef<AffineExpr> newResults =
@@ -3108,44 +3114,15 @@ struct CastAwayTransferWriteLeadingOneDim
31083114 inBounds = rewriter.getArrayAttr (
31093115 write.in_boundsAttr ().getValue ().take_back (newType.getRank ()));
31103116
3111- auto newVector = rewriter.create <vector::ShapeCastOp >(
3112- write.getLoc (), newType, write.vector ());
3117+ auto newVector = rewriter.create <vector::ExtractOp >(
3118+ write.getLoc (), write.vector (), splatZero (dropDim ));
31133119 rewriter.replaceOpWithNewOp <vector::TransferWriteOp>(
31143120 write, newVector, write.source (), write.indices (), newMap, inBounds);
31153121
31163122 return success ();
31173123 }
31183124};
31193125
3120- template <typename BroadCastType>
3121- struct CastAwayBroadcastLeadingOneDim : public OpRewritePattern <BroadCastType> {
3122- using OpRewritePattern<BroadCastType>::OpRewritePattern;
3123-
3124- LogicalResult matchAndRewrite (BroadCastType broadcastOp,
3125- PatternRewriter &rewriter) const override {
3126- VectorType dstType =
3127- broadcastOp.getResult ().getType ().template dyn_cast <VectorType>();
3128- if (!dstType)
3129- return failure ();
3130- VectorType newDstType = trimLeadingOneDims (dstType);
3131- if (newDstType == dstType)
3132- return failure ();
3133- Location loc = broadcastOp.getLoc ();
3134- Value source = broadcastOp->getOperand (0 );
3135- VectorType srcVecType = source.getType ().template dyn_cast <VectorType>();
3136- if (srcVecType)
3137- srcVecType = trimLeadingOneDims (srcVecType);
3138- if (srcVecType && srcVecType != source.getType ()) {
3139- source = rewriter.create <vector::ShapeCastOp>(loc, srcVecType, source);
3140- }
3141- Value newBroadcastOp =
3142- rewriter.create <BroadCastType>(loc, newDstType, source);
3143- rewriter.replaceOpWithNewOp <vector::ShapeCastOp>(broadcastOp, dstType,
3144- newBroadcastOp);
3145- return success ();
3146- }
3147- };
3148-
31493126class CastAwayElementwiseLeadingOneDim : public RewritePattern {
31503127public:
31513128 CastAwayElementwiseLeadingOneDim (MLIRContext *context)
@@ -3161,14 +3138,12 @@ class CastAwayElementwiseLeadingOneDim : public RewritePattern {
31613138 VectorType newVecType = trimLeadingOneDims (vecType);
31623139 if (newVecType == vecType)
31633140 return failure ();
3164-
3141+ int64_t dropDim = vecType. getRank () - newVecType. getRank ();
31653142 SmallVector<Value, 4 > newOperands;
31663143 for (Value operand : op->getOperands ()) {
31673144 if (auto opVecType = operand.getType ().dyn_cast <VectorType>()) {
3168- auto newType =
3169- VectorType::get (newVecType.getShape (), opVecType.getElementType ());
3170- newOperands.push_back (rewriter.create <vector::ShapeCastOp>(
3171- op->getLoc (), newType, operand));
3145+ newOperands.push_back (rewriter.create <vector::ExtractOp>(
3146+ op->getLoc (), operand, splatZero (dropDim)));
31723147 } else {
31733148 newOperands.push_back (operand);
31743149 }
@@ -3178,69 +3153,12 @@ class CastAwayElementwiseLeadingOneDim : public RewritePattern {
31783153 state.addOperands (newOperands);
31793154 state.addTypes (newVecType);
31803155 Operation *newOp = rewriter.createOperation (state);
3181- rewriter.replaceOpWithNewOp <vector::ShapeCastOp >(op, vecType,
3156+ rewriter.replaceOpWithNewOp <vector::BroadcastOp >(op, vecType,
31823157 newOp->getResult (0 ));
31833158 return success ();
31843159 }
31853160};
31863161
3187- // If extractOp is only removing unit dimensions it can be transformed to a
3188- // shapecast.
3189- class ExtractToShapeCast final : public OpRewritePattern<ExtractOp> {
3190- public:
3191- using OpRewritePattern<ExtractOp>::OpRewritePattern;
3192-
3193- LogicalResult matchAndRewrite (ExtractOp extractOp,
3194- PatternRewriter &rewriter) const override {
3195- auto dstVecType = extractOp.getResult ().getType ().dyn_cast <VectorType>();
3196- if (!dstVecType || extractOp.getVectorType ().getNumElements () !=
3197- dstVecType.getNumElements ())
3198- return failure ();
3199- rewriter.replaceOpWithNewOp <ShapeCastOp>(extractOp, dstVecType,
3200- extractOp.vector ());
3201- return success ();
3202- }
3203- };
3204-
3205- // If insertOp is only inserting unit dimensions it can be transformed to a
3206- // shapecast.
3207- class InsertToShapeCast final : public OpRewritePattern<InsertOp> {
3208- public:
3209- using OpRewritePattern<InsertOp>::OpRewritePattern;
3210-
3211- LogicalResult matchAndRewrite (InsertOp insertOp,
3212- PatternRewriter &rewriter) const override {
3213- auto srcVecType = insertOp.getSourceType ().dyn_cast <VectorType>();
3214- if (!srcVecType || insertOp.getDestVectorType ().getNumElements () !=
3215- srcVecType.getNumElements ())
3216- return failure ();
3217- rewriter.replaceOpWithNewOp <ShapeCastOp>(
3218- insertOp, insertOp.getDestVectorType (), insertOp.source ());
3219- return success ();
3220- }
3221- };
3222-
3223- // BroadcastOp can only add dimensions or broadcast a dimension from 1 to N. In
3224- // the degenerated case where the broadcast only adds dimensions of size 1 it
3225- // can be replaced by a ShapeCastOp. This canonicalization checks if the total
3226- // number of elements is the same before and after the broadcast to detect if
3227- // the only change in the vector type are new dimensions of size 1.
3228- class BroadcastToShapeCast final : public OpRewritePattern<BroadcastOp> {
3229- public:
3230- using OpRewritePattern<BroadcastOp>::OpRewritePattern;
3231-
3232- LogicalResult matchAndRewrite (BroadcastOp broadcastOp,
3233- PatternRewriter &rewriter) const override {
3234- auto srcVecType = broadcastOp.getSourceType ().dyn_cast <VectorType>();
3235- if (!srcVecType || broadcastOp.getVectorType ().getNumElements () !=
3236- srcVecType.getNumElements ())
3237- return failure ();
3238- rewriter.replaceOpWithNewOp <ShapeCastOp>(
3239- broadcastOp, broadcastOp.getVectorType (), broadcastOp.source ());
3240- return success ();
3241- }
3242- };
3243-
32443162// Returns the values in `arrayAttr` as an integer vector.
32453163static SmallVector<int64_t , 4 > getIntValueVector (ArrayAttr arrayAttr) {
32463164 return llvm::to_vector<4 >(
@@ -3722,13 +3640,11 @@ void mlir::vector::populateShapeCastFoldingPatterns(
37223640
37233641void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns (
37243642 RewritePatternSet &patterns) {
3725- patterns.add <
3726- BroadcastToShapeCast, CastAwayExtractStridedSliceLeadingOneDim,
3727- CastAwayInsertStridedSliceLeadingOneDim,
3728- CastAwayTransferReadLeadingOneDim, CastAwayTransferWriteLeadingOneDim,
3729- CastAwayBroadcastLeadingOneDim<vector::BroadcastOp>,
3730- CastAwayBroadcastLeadingOneDim<SplatOp>, CastAwayElementwiseLeadingOneDim,
3731- ExtractToShapeCast, InsertToShapeCast>(patterns.getContext ());
3643+ patterns.add <CastAwayExtractStridedSliceLeadingOneDim,
3644+ CastAwayInsertStridedSliceLeadingOneDim,
3645+ CastAwayTransferReadLeadingOneDim,
3646+ CastAwayTransferWriteLeadingOneDim,
3647+ CastAwayElementwiseLeadingOneDim>(patterns.getContext ());
37323648 populateShapeCastFoldingPatterns (patterns);
37333649}
37343650
0 commit comments