@@ -1289,19 +1289,49 @@ LogicalResult vector::ExtractElementOp::verify() {
12891289 return success ();
12901290}
12911291
1292+ // / Consider the defining operation `defOp` of \p value. If `defOp` is a
1293+ // / vector.splat or a vector.broadcast with a scalar operand, return the scalar
1294+ // / value that is splatted. Otherwise return a null Value.
1295+ // /
1296+ // / Example:
1297+ // /
1298+ // / scalar_source --> vector.splat --> value > return scalar_source
1299+ // / scalar_source --> vector.broadcast --> value > return scalar_source
1300+ // / vector_source --> vector.broadcast --> value > return {}
1301+ // / * --> some.other.op --> value > return {}
1302+ static Value getSplatSource (Value value) {
1303+
1304+ // Block argument:
1305+ Operation *defOp = value.getDefiningOp ();
1306+ if (!defOp)
1307+ return {};
1308+
1309+ // Splat:
1310+ auto splat = dyn_cast<vector::SplatOp>(defOp);
1311+ if (splat)
1312+ return splat.getInput ();
1313+
1314+ auto broadcast = dyn_cast<vector::BroadcastOp>(defOp);
1315+
1316+ // Not broadcast (and not splat):
1317+ if (!broadcast)
1318+ return {};
1319+
1320+ // Broadcast of a vector:
1321+ if (isa<VectorType>(broadcast.getSourceType ()))
1322+ return {};
1323+
1324+ // Broadcast of a scalar:
1325+ return broadcast.getSource ();
1326+ }
1327+
12921328OpFoldResult vector::ExtractElementOp::fold (FoldAdaptor adaptor) {
12931329 // Skip the 0-D vector here now.
12941330 if (!adaptor.getPosition ())
12951331 return {};
12961332
1297- // Fold extractelement (splat X) -> X.
1298- if (auto splat = getVector ().getDefiningOp <vector::SplatOp>())
1299- return splat.getInput ();
1300-
1301- // Fold extractelement(broadcast(X)) -> X.
1302- if (auto broadcast = getVector ().getDefiningOp <vector::BroadcastOp>())
1303- if (!llvm::isa<VectorType>(broadcast.getSource ().getType ()))
1304- return broadcast.getSource ();
1333+ if (auto splatValue = getSplatSource (getVector ()))
1334+ return splatValue;
13051335
13061336 auto src = dyn_cast_or_null<DenseElementsAttr>(adaptor.getVector ());
13071337 auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition ());
@@ -2514,12 +2544,14 @@ OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
25142544// /
25152545// / %0 = vector.from_elements %a, %a, %a : vector<3xf32>
25162546// / ==> rewrite to vector.splat %a : vector<3xf32>
2517- static LogicalResult rewriteFromElementsAsSplat (FromElementsOp fromElementsOp,
2518- PatternRewriter &rewriter) {
2547+ static LogicalResult
2548+ rewriteFromElementsAsBroadcast (FromElementsOp fromElementsOp,
2549+ PatternRewriter &rewriter) {
25192550 if (!llvm::all_equal (fromElementsOp.getElements ()))
25202551 return failure ();
2521- rewriter.replaceOpWithNewOp <SplatOp>(fromElementsOp, fromElementsOp.getType (),
2522- fromElementsOp.getElements ().front ());
2552+ rewriter.replaceOpWithNewOp <BroadcastOp>(
2553+ fromElementsOp, fromElementsOp.getType (),
2554+ fromElementsOp.getElements ().front ());
25232555 return success ();
25242556}
25252557
@@ -2550,7 +2582,7 @@ class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> {
25502582 LogicalResult matchAndRewrite (FromElementsOp fromElements,
25512583 PatternRewriter &rewriter) const override {
25522584
2553- // Handled by `rewriteFromElementsAsSplat `
2585+ // Handled by `rewriteFromElementsAsBroadcast `
25542586 if (fromElements.getType ().getNumElements () == 1 )
25552587 return failure ();
25562588
@@ -2644,7 +2676,7 @@ class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> {
26442676
26452677void FromElementsOp::getCanonicalizationPatterns (RewritePatternSet &results,
26462678 MLIRContext *context) {
2647- results.add (rewriteFromElementsAsSplat );
2679+ results.add (rewriteFromElementsAsBroadcast );
26482680 results.add <FromElementsToShapeCast>(context);
26492681}
26502682
@@ -3088,23 +3120,18 @@ struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
30883120 }
30893121};
30903122
3091- // / Pattern to rewrite a ShuffleOp(SplatOp, SplatOp) to SplatOp.
3123+ // / Pattern to rewrite shuffle(splat-like(v), splat-like(v)) as broadcast(v)
30923124class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
30933125public:
30943126 using OpRewritePattern::OpRewritePattern;
30953127
30963128 LogicalResult matchAndRewrite (ShuffleOp op,
30973129 PatternRewriter &rewriter) const override {
3098- auto v1Splat = op.getV1 ().getDefiningOp <SplatOp>();
3099- auto v2Splat = op.getV2 ().getDefiningOp <SplatOp>();
3100-
3101- if (!v1Splat || !v2Splat)
3130+ Value splat = getSplatSource (op.getV1 ());
3131+ if (!splat || getSplatSource (op.getV2 ()) != splat)
31023132 return failure ();
31033133
3104- if (v1Splat.getInput () != v2Splat.getInput ())
3105- return failure ();
3106-
3107- rewriter.replaceOpWithNewOp <SplatOp>(op, op.getType (), v1Splat.getInput ());
3134+ rewriter.replaceOpWithNewOp <BroadcastOp>(op, op.getType (), splat);
31083135 return success ();
31093136 }
31103137};
@@ -3314,23 +3341,19 @@ class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
33143341 }
33153342};
33163343
3317- // / Pattern to rewrite a InsertOp(SplatOp, SplatOp) to SplatOp.
3344+ // / Pattern to rewrite a insert(splat-like(v), splat-like(v)) as broadcast(v)
33183345class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
33193346public:
33203347 using OpRewritePattern::OpRewritePattern;
33213348
33223349 LogicalResult matchAndRewrite (InsertOp op,
33233350 PatternRewriter &rewriter) const override {
3324- auto srcSplat = op.getValueToStore ().getDefiningOp <SplatOp>();
3325- auto dstSplat = op.getDest ().getDefiningOp <SplatOp>();
33263351
3327- if (!srcSplat || !dstSplat)
3352+ Value splat = getSplatSource (op.getValueToStore ());
3353+ if (!splat || getSplatSource (op.getDest ()) != splat)
33283354 return failure ();
33293355
3330- if (srcSplat.getInput () != dstSplat.getInput ())
3331- return failure ();
3332-
3333- rewriter.replaceOpWithNewOp <SplatOp>(op, op.getType (), srcSplat.getInput ());
3356+ rewriter.replaceOpWithNewOp <BroadcastOp>(op, op.getType (), splat);
33343357 return success ();
33353358 }
33363359};
@@ -3598,27 +3621,21 @@ LogicalResult InsertStridedSliceOp::verify() {
35983621}
35993622
36003623namespace {
3601- // / Pattern to rewrite an InsertStridedSliceOp(SplatOp(X):src_type,
3602- // / SplatOp(X):dst_type) to SplatOp(X):dst_type.
3624+ // / Rewrite insert_strided_slice(splat-like(v), splat-like(v)) as v
36033625class FoldInsertStridedSliceSplat final
36043626 : public OpRewritePattern<InsertStridedSliceOp> {
36053627public:
36063628 using OpRewritePattern::OpRewritePattern;
36073629
36083630 LogicalResult matchAndRewrite (InsertStridedSliceOp insertStridedSliceOp,
36093631 PatternRewriter &rewriter) const override {
3610- auto srcSplatOp =
3611- insertStridedSliceOp.getValueToStore ().getDefiningOp <vector::SplatOp>();
3612- auto destSplatOp =
3613- insertStridedSliceOp.getDest ().getDefiningOp <vector::SplatOp>();
36143632
3615- if (!srcSplatOp || !destSplatOp)
3633+ auto dst = insertStridedSliceOp.getDest ();
3634+ auto splat = getSplatSource (insertStridedSliceOp.getValueToStore ());
3635+ if (!splat || getSplatSource (dst) != splat)
36163636 return failure ();
36173637
3618- if (srcSplatOp.getInput () != destSplatOp.getInput ())
3619- return failure ();
3620-
3621- rewriter.replaceOp (insertStridedSliceOp, insertStridedSliceOp.getDest ());
3638+ rewriter.replaceOp (insertStridedSliceOp, dst);
36223639 return success ();
36233640 }
36243641};
@@ -4197,17 +4214,18 @@ class StridedSliceBroadcast final
41974214 }
41984215};
41994216
4200- // / Pattern to rewrite an ExtractStridedSliceOp(SplatOp) to SplatOp.
4217+ // / Rewrite extract_strided_slice(splat-like(v)) with broadcast(v)
42014218class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
42024219public:
42034220 using OpRewritePattern::OpRewritePattern;
42044221
42054222 LogicalResult matchAndRewrite (ExtractStridedSliceOp op,
42064223 PatternRewriter &rewriter) const override {
4207- auto splat = op.getVector ().getDefiningOp <SplatOp>();
4224+
4225+ Value splat = getSplatSource (op.getVector ());
42084226 if (!splat)
42094227 return failure ();
4210- rewriter.replaceOpWithNewOp <SplatOp >(op, op.getType (), splat. getInput () );
4228+ rewriter.replaceOpWithNewOp <BroadcastOp >(op, op.getType (), splat);
42114229 return success ();
42124230 }
42134231};
@@ -6357,19 +6375,19 @@ class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
63576375 }
63586376};
63596377
6360- // Folds transpose(splat x : src_type) : res_type into splat x : res_type.
6378+ // / Replace transpose(splat-like(v)) with broadcast(v)
63616379class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
63626380public:
63636381 using OpRewritePattern::OpRewritePattern;
63646382
63656383 LogicalResult matchAndRewrite (TransposeOp transposeOp,
63666384 PatternRewriter &rewriter) const override {
6367- auto splatOp = transposeOp.getVector (). getDefiningOp <vector::SplatOp>( );
6368- if (!splatOp )
6385+ Value splat = getSplatSource ( transposeOp.getVector ());
6386+ if (!splat )
63696387 return failure ();
63706388
6371- rewriter.replaceOpWithNewOp <vector::SplatOp >(
6372- transposeOp, transposeOp.getResultVectorType (), splatOp. getInput () );
6389+ rewriter.replaceOpWithNewOp <vector::BroadcastOp >(
6390+ transposeOp, transposeOp.getResultVectorType (), splat );
63736391 return success ();
63746392 }
63756393};
@@ -7120,6 +7138,23 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
71207138 return SplatElementsAttr::get (getType (), {constOperand});
71217139}
71227140
7141+ // Canonicalizer for vector.splat. It always gets canonicalized to a vector.broadcast.
7142+ class SplatToBroadcastPattern : public OpRewritePattern <SplatOp> {
7143+ public:
7144+ using OpRewritePattern<SplatOp>::OpRewritePattern;
7145+ LogicalResult matchAndRewrite (SplatOp splatOp,
7146+ PatternRewriter &rewriter) const override {
7147+ rewriter.replaceOpWithNewOp <vector::BroadcastOp>(
7148+ splatOp, splatOp.getType (), splatOp.getOperand ());
7149+ return success ();
7150+ }
7151+ };
7152+ void SplatOp::getCanonicalizationPatterns (RewritePatternSet &results,
7153+ MLIRContext *context) {
7154+ results.add <SplatToBroadcastPattern>(context);
7155+ }
7156+
7157+
71237158void SplatOp::inferResultRanges (ArrayRef<ConstantIntRanges> argRanges,
71247159 SetIntRangeFn setResultRanges) {
71257160 setResultRanges (getResult (), argRanges.front ());
0 commit comments