@@ -2476,17 +2476,19 @@ OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
24762476 return {};
24772477}
24782478
2479- // / Rewrite a vector.from_elements into a vector.splat if all elements are the
2480- // / same SSA value. E.g.:
2481- // /
2482- // / %0 = vector.from_elements %a, %a, %a : vector<3xf32>
2483- // / ==> rewrite to vector.splat %a : vector<3xf32>
2484- static LogicalResult rewriteFromElementsAsSplat (FromElementsOp fromElementsOp,
2485- PatternRewriter &rewriter) {
2479+ // / Rewrite vector.from_elements as vector.broadcast if the elements are the
2480+ // / same. Example:
2481+ // / %0 = vector.from_elements %a, %a, %a : vector<3xf32>
2482+ // / =>
2483+ // / %0 = vector.broadcast %a : f32 to vector<3xf32>
2484+ static LogicalResult
2485+ rewriteFromElementsAsBroadcast (FromElementsOp fromElementsOp,
2486+ PatternRewriter &rewriter) {
24862487 if (!llvm::all_equal (fromElementsOp.getElements ()))
24872488 return failure ();
2488- rewriter.replaceOpWithNewOp <SplatOp>(fromElementsOp, fromElementsOp.getType (),
2489- fromElementsOp.getElements ().front ());
2489+ rewriter.replaceOpWithNewOp <BroadcastOp>(
2490+ fromElementsOp, fromElementsOp.getType (),
2491+ fromElementsOp.getElements ().front ());
24902492 return success ();
24912493}
24922494
@@ -2517,7 +2519,7 @@ class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> {
25172519 LogicalResult matchAndRewrite (FromElementsOp fromElements,
25182520 PatternRewriter &rewriter) const override {
25192521
2520- // Handled by `rewriteFromElementsAsSplat`
2522+ // Handled by `rewriteFromElementsAsBroadcast`.
25212523 if (fromElements.getType ().getNumElements () == 1 )
25222524 return failure ();
25232525
@@ -2610,7 +2612,7 @@ class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> {
26102612
26112613void FromElementsOp::getCanonicalizationPatterns (RewritePatternSet &results,
26122614 MLIRContext *context) {
2613- results.add (rewriteFromElementsAsSplat );
2615+ results.add (rewriteFromElementsAsBroadcast );
26142616 results.add <FromElementsToShapeCast>(context);
26152617}
26162618
@@ -3058,23 +3060,50 @@ struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
30583060 }
30593061};
30603062
3061- // / Pattern to rewrite a ShuffleOp(SplatOp, SplatOp) to SplatOp.
3063+ // / Consider the defining operation `defOp` of `value`. If `defOp` is a
3064+ // / vector.splat or a vector.broadcast with a scalar operand, return the scalar
3065+ // / value that is splatted. Otherwise return null.
3066+ // /
3067+ // / Examples:
3068+ // /
3069+ // / scalar_source --> vector.splat --> value - return scalar_source
3070+ // / scalar_source --> vector.broadcast --> value - return scalar_source
3071+ static Value getScalarSplatSource (Value value) {
3072+ // Block argument:
3073+ Operation *defOp = value.getDefiningOp ();
3074+ if (!defOp)
3075+ return {};
3076+
3077+ // Splat:
3078+ if (auto splat = dyn_cast<vector::SplatOp>(defOp))
3079+ return splat.getInput ();
3080+
3081+ auto broadcast = dyn_cast<vector::BroadcastOp>(defOp);
3082+
3083+ // Not broadcast (and not splat):
3084+ if (!broadcast)
3085+ return {};
3086+
3087+ // Broadcast of a vector:
3088+ if (isa<VectorType>(broadcast.getSourceType ()))
3089+ return {};
3090+
3091+ // Broadcast of a scalar:
3092+ return broadcast.getSource ();
3093+ }
3094+
3095+ // / Pattern to rewrite shuffle(splat-like(v), splat-like(v)) as broadcast(v).
30623096class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
30633097public:
30643098 using OpRewritePattern::OpRewritePattern;
30653099
30663100 LogicalResult matchAndRewrite (ShuffleOp op,
30673101 PatternRewriter &rewriter) const override {
3068- auto v1Splat = op.getV1 ().getDefiningOp <SplatOp>();
3069- auto v2Splat = op.getV2 ().getDefiningOp <SplatOp>();
3070-
3071- if (!v1Splat || !v2Splat)
3102+ Value splat = getScalarSplatSource (op.getV1 ());
3103+ if (!splat || getScalarSplatSource (op.getV2 ()) != splat)
30723104 return failure ();
30733105
3074- if (v1Splat.getInput () != v2Splat.getInput ())
3075- return failure ();
3076-
3077- rewriter.replaceOpWithNewOp <SplatOp>(op, op.getType (), v1Splat.getInput ());
3106+ rewriter.replaceOpWithNewOp <BroadcastOp>(op, op.getType (), splat);
30783107 return success ();
30793108 }
30803109};
@@ -3230,23 +3259,19 @@ class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
32303259 }
32313260};
32323261
3233- // / Pattern to rewrite a InsertOp(SplatOp, SplatOp) to SplatOp .
3262+ // / Pattern to rewrite a insert(splat-like(v), splat-like(v)) as broadcast(v) .
32343263class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
32353264public:
32363265 using OpRewritePattern::OpRewritePattern;
32373266
32383267 LogicalResult matchAndRewrite (InsertOp op,
32393268 PatternRewriter &rewriter) const override {
3240- auto srcSplat = op.getValueToStore ().getDefiningOp <SplatOp>();
3241- auto dstSplat = op.getDest ().getDefiningOp <SplatOp>();
3242-
3243- if (!srcSplat || !dstSplat)
3244- return failure ();
32453269
3246- if (srcSplat.getInput () != dstSplat.getInput ())
3270+ Value splat = getScalarSplatSource (op.getValueToStore ());
3271+ if (!splat || getScalarSplatSource (op.getDest ()) != splat)
32473272 return failure ();
32483273
3249- rewriter.replaceOpWithNewOp <SplatOp >(op, op.getType (), srcSplat. getInput () );
3274+ rewriter.replaceOpWithNewOp <BroadcastOp >(op, op.getType (), splat );
32503275 return success ();
32513276 }
32523277};
@@ -3514,27 +3539,21 @@ LogicalResult InsertStridedSliceOp::verify() {
35143539}
35153540
35163541namespace {
3517- // / Pattern to rewrite an InsertStridedSliceOp(SplatOp(X):src_type,
3518- // / SplatOp(X):dst_type) to SplatOp(X):dst_type.
3542+ // / Rewrite insert_strided_slice(splat-like(v), splat-like(v)) as v.
35193543class FoldInsertStridedSliceSplat final
35203544 : public OpRewritePattern<InsertStridedSliceOp> {
35213545public:
35223546 using OpRewritePattern::OpRewritePattern;
35233547
35243548 LogicalResult matchAndRewrite (InsertStridedSliceOp insertStridedSliceOp,
35253549 PatternRewriter &rewriter) const override {
3526- auto srcSplatOp =
3527- insertStridedSliceOp.getValueToStore ().getDefiningOp <vector::SplatOp>();
3528- auto destSplatOp =
3529- insertStridedSliceOp.getDest ().getDefiningOp <vector::SplatOp>();
35303550
3531- if (!srcSplatOp || !destSplatOp)
3551+ auto dst = insertStridedSliceOp.getDest ();
3552+ auto splat = getScalarSplatSource (insertStridedSliceOp.getValueToStore ());
3553+ if (!splat || getScalarSplatSource (dst) != splat)
35323554 return failure ();
35333555
3534- if (srcSplatOp.getInput () != destSplatOp.getInput ())
3535- return failure ();
3536-
3537- rewriter.replaceOp (insertStridedSliceOp, insertStridedSliceOp.getDest ());
3556+ rewriter.replaceOp (insertStridedSliceOp, dst);
35383557 return success ();
35393558 }
35403559};
@@ -4189,17 +4208,18 @@ class StridedSliceBroadcast final
41894208 }
41904209};
41914210
4192- // / Pattern to rewrite an ExtractStridedSliceOp(SplatOp) to SplatOp .
4211+ // / Rewrite extract_strided_slice(splat-like(v)) with broadcast(v) .
41934212class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
41944213public:
41954214 using OpRewritePattern::OpRewritePattern;
41964215
41974216 LogicalResult matchAndRewrite (ExtractStridedSliceOp op,
41984217 PatternRewriter &rewriter) const override {
4199- auto splat = op.getVector ().getDefiningOp <SplatOp>();
4218+
4219+ Value splat = getScalarSplatSource (op.getVector ());
42004220 if (!splat)
42014221 return failure ();
4202- rewriter.replaceOpWithNewOp <SplatOp >(op, op.getType (), splat. getInput () );
4222+ rewriter.replaceOpWithNewOp <BroadcastOp >(op, op.getType (), splat);
42034223 return success ();
42044224 }
42054225};
@@ -6354,19 +6374,19 @@ class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
63546374 }
63556375};
63566376
6357- // Folds transpose(splat x : src_type) : res_type into splat x : res_type.
6377+ // / Replace transpose(splat-like(v)) with broadcast(v)
63586378class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
63596379public:
63606380 using OpRewritePattern::OpRewritePattern;
63616381
63626382 LogicalResult matchAndRewrite (TransposeOp transposeOp,
63636383 PatternRewriter &rewriter) const override {
6364- auto splatOp = transposeOp.getVector (). getDefiningOp <vector::SplatOp>( );
6365- if (!splatOp )
6384+ Value splat = getScalarSplatSource ( transposeOp.getVector ());
6385+ if (!splat )
63666386 return failure ();
63676387
6368- rewriter.replaceOpWithNewOp <vector::SplatOp >(
6369- transposeOp, transposeOp.getResultVectorType (), splatOp. getInput () );
6388+ rewriter.replaceOpWithNewOp <vector::BroadcastOp >(
6389+ transposeOp, transposeOp.getResultVectorType (), splat );
63706390 return success ();
63716391 }
63726392};
@@ -7117,6 +7137,23 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
71177137 return SplatElementsAttr::get (getType (), {constOperand});
71187138}
71197139
7140+ // Canonicalizer for vector.splat. It always gets canonicalized to a
7141+ // vector.broadcast.
7142+ class SplatToBroadcastPattern final : public OpRewritePattern<SplatOp> {
7143+ public:
7144+ using OpRewritePattern<SplatOp>::OpRewritePattern;
7145+ LogicalResult matchAndRewrite (SplatOp splatOp,
7146+ PatternRewriter &rewriter) const override {
7147+ rewriter.replaceOpWithNewOp <vector::BroadcastOp>(splatOp, splatOp.getType (),
7148+ splatOp.getOperand ());
7149+ return success ();
7150+ }
7151+ };
7152+ void SplatOp::getCanonicalizationPatterns (RewritePatternSet &results,
7153+ MLIRContext *context) {
7154+ results.add <SplatToBroadcastPattern>(context);
7155+ }
7156+
71207157void SplatOp::inferResultRanges (ArrayRef<ConstantIntRanges> argRanges,
71217158 SetIntRangeFn setResultRanges) {
71227159 setResultRanges (getResult (), argRanges.front ());
0 commit comments