Skip to content

Commit 5aa7757

Browse files
committed
make broadcast the canonical form
1 parent f1acd69 commit 5aa7757

File tree

6 files changed

+307
-103
lines changed

6 files changed

+307
-103
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2919,6 +2919,8 @@ def Vector_SplatOp : Vector_Op<"splat", [
29192919
]> {
29202920
let summary = "vector splat or broadcast operation";
29212921
let description = [{
2922+
Note: This operation is deprecated. Please use vector.broadcast.
2923+
29222924
Broadcast the operand to all elements of the result vector. The type of the
29232925
operand must match the element type of the vector type.
29242926

@@ -2928,6 +2930,13 @@ def Vector_SplatOp : Vector_Op<"splat", [
29282930
%s = arith.constant 10.1 : f32
29292931
%t = vector.splat %s : vector<8x16xf32>
29302932
```
2933+
2934+
This operation is deprecated, the preferred representation of the above is:
2935+
2936+
```mlir
2937+
%s = arith.constant 10.1 : f32
2938+
%t = vector.broadcast %s : f32 to vector<8x16xf32>
2939+
```
29312940
}];
29322941

29332942
let arguments = (ins AnyType:$input);
@@ -2939,6 +2948,9 @@ def Vector_SplatOp : Vector_Op<"splat", [
29392948
let assemblyFormat = "$input attr-dict `:` type($aggregate)";
29402949

29412950
let hasFolder = 1;
2951+
2952+
// As vector.splat is deprecated, it is canonicalized to vector.broadcast.
2953+
let hasCanonicalizer = 1;
29422954
}
29432955

29442956
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 85 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1289,19 +1289,49 @@ LogicalResult vector::ExtractElementOp::verify() {
12891289
return success();
12901290
}
12911291

1292+
/// Consider the defining operation `defOp` of \p value. If `defOp` is a
1293+
/// vector.splat or a vector.broadcast with a scalar operand, return the scalar
1294+
/// value that is splatted. Otherwise return a null Value.
1295+
///
1296+
/// Example:
1297+
///
1298+
/// scalar_source --> vector.splat --> value > return scalar_source
1299+
/// scalar_source --> vector.broadcast --> value > return scalar_source
1300+
/// vector_source --> vector.broadcast --> value > return {}
1301+
/// * --> some.other.op --> value > return {}
1302+
static Value getSplatSource(Value value) {
1303+
1304+
// Block argument:
1305+
Operation *defOp = value.getDefiningOp();
1306+
if (!defOp)
1307+
return {};
1308+
1309+
// Splat:
1310+
auto splat = dyn_cast<vector::SplatOp>(defOp);
1311+
if (splat)
1312+
return splat.getInput();
1313+
1314+
auto broadcast = dyn_cast<vector::BroadcastOp>(defOp);
1315+
1316+
// Not broadcast (and not splat):
1317+
if (!broadcast)
1318+
return {};
1319+
1320+
// Broadcast of a vector:
1321+
if (isa<VectorType>(broadcast.getSourceType()))
1322+
return {};
1323+
1324+
// Broadcast of a scalar:
1325+
return broadcast.getSource();
1326+
}
1327+
12921328
OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
12931329
// Skip the 0-D vector here now.
12941330
if (!adaptor.getPosition())
12951331
return {};
12961332

1297-
// Fold extractelement (splat X) -> X.
1298-
if (auto splat = getVector().getDefiningOp<vector::SplatOp>())
1299-
return splat.getInput();
1300-
1301-
// Fold extractelement(broadcast(X)) -> X.
1302-
if (auto broadcast = getVector().getDefiningOp<vector::BroadcastOp>())
1303-
if (!llvm::isa<VectorType>(broadcast.getSource().getType()))
1304-
return broadcast.getSource();
1333+
if (auto splatValue = getSplatSource(getVector()))
1334+
return splatValue;
13051335

13061336
auto src = dyn_cast_or_null<DenseElementsAttr>(adaptor.getVector());
13071337
auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
@@ -2514,12 +2544,14 @@ OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
25142544
///
25152545
/// %0 = vector.from_elements %a, %a, %a : vector<3xf32>
25162546
/// ==> rewrite to vector.splat %a : vector<3xf32>
2517-
static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
2518-
PatternRewriter &rewriter) {
2547+
static LogicalResult
2548+
rewriteFromElementsAsBroadcast(FromElementsOp fromElementsOp,
2549+
PatternRewriter &rewriter) {
25192550
if (!llvm::all_equal(fromElementsOp.getElements()))
25202551
return failure();
2521-
rewriter.replaceOpWithNewOp<SplatOp>(fromElementsOp, fromElementsOp.getType(),
2522-
fromElementsOp.getElements().front());
2552+
rewriter.replaceOpWithNewOp<BroadcastOp>(
2553+
fromElementsOp, fromElementsOp.getType(),
2554+
fromElementsOp.getElements().front());
25232555
return success();
25242556
}
25252557

@@ -2550,7 +2582,7 @@ class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> {
25502582
LogicalResult matchAndRewrite(FromElementsOp fromElements,
25512583
PatternRewriter &rewriter) const override {
25522584

2553-
// Handled by `rewriteFromElementsAsSplat`
2585+
// Handled by `rewriteFromElementsAsBroadcast`
25542586
if (fromElements.getType().getNumElements() == 1)
25552587
return failure();
25562588

@@ -2644,7 +2676,7 @@ class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> {
26442676

26452677
void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
26462678
MLIRContext *context) {
2647-
results.add(rewriteFromElementsAsSplat);
2679+
results.add(rewriteFromElementsAsBroadcast);
26482680
results.add<FromElementsToShapeCast>(context);
26492681
}
26502682

@@ -3088,23 +3120,18 @@ struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
30883120
}
30893121
};
30903122

3091-
/// Pattern to rewrite a ShuffleOp(SplatOp, SplatOp) to SplatOp.
3123+
/// Pattern to rewrite shuffle(splat-like(v), splat-like(v)) as broadcast(v)
30923124
class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
30933125
public:
30943126
using OpRewritePattern::OpRewritePattern;
30953127

30963128
LogicalResult matchAndRewrite(ShuffleOp op,
30973129
PatternRewriter &rewriter) const override {
3098-
auto v1Splat = op.getV1().getDefiningOp<SplatOp>();
3099-
auto v2Splat = op.getV2().getDefiningOp<SplatOp>();
3100-
3101-
if (!v1Splat || !v2Splat)
3130+
Value splat = getSplatSource(op.getV1());
3131+
if (!splat || getSplatSource(op.getV2()) != splat)
31023132
return failure();
31033133

3104-
if (v1Splat.getInput() != v2Splat.getInput())
3105-
return failure();
3106-
3107-
rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), v1Splat.getInput());
3134+
rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
31083135
return success();
31093136
}
31103137
};
@@ -3314,23 +3341,19 @@ class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
33143341
}
33153342
};
33163343

3317-
/// Pattern to rewrite a InsertOp(SplatOp, SplatOp) to SplatOp.
3344+
/// Pattern to rewrite a insert(splat-like(v), splat-like(v)) as broadcast(v)
33183345
class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
33193346
public:
33203347
using OpRewritePattern::OpRewritePattern;
33213348

33223349
LogicalResult matchAndRewrite(InsertOp op,
33233350
PatternRewriter &rewriter) const override {
3324-
auto srcSplat = op.getValueToStore().getDefiningOp<SplatOp>();
3325-
auto dstSplat = op.getDest().getDefiningOp<SplatOp>();
33263351

3327-
if (!srcSplat || !dstSplat)
3352+
Value splat = getSplatSource(op.getValueToStore());
3353+
if (!splat || getSplatSource(op.getDest()) != splat)
33283354
return failure();
33293355

3330-
if (srcSplat.getInput() != dstSplat.getInput())
3331-
return failure();
3332-
3333-
rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), srcSplat.getInput());
3356+
rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
33343357
return success();
33353358
}
33363359
};
@@ -3598,27 +3621,21 @@ LogicalResult InsertStridedSliceOp::verify() {
35983621
}
35993622

36003623
namespace {
3601-
/// Pattern to rewrite an InsertStridedSliceOp(SplatOp(X):src_type,
3602-
/// SplatOp(X):dst_type) to SplatOp(X):dst_type.
3624+
/// Rewrite insert_strided_slice(splat-like(v), splat-like(v)) as v
36033625
class FoldInsertStridedSliceSplat final
36043626
: public OpRewritePattern<InsertStridedSliceOp> {
36053627
public:
36063628
using OpRewritePattern::OpRewritePattern;
36073629

36083630
LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
36093631
PatternRewriter &rewriter) const override {
3610-
auto srcSplatOp =
3611-
insertStridedSliceOp.getValueToStore().getDefiningOp<vector::SplatOp>();
3612-
auto destSplatOp =
3613-
insertStridedSliceOp.getDest().getDefiningOp<vector::SplatOp>();
36143632

3615-
if (!srcSplatOp || !destSplatOp)
3633+
auto dst = insertStridedSliceOp.getDest();
3634+
auto splat = getSplatSource(insertStridedSliceOp.getValueToStore());
3635+
if (!splat || getSplatSource(dst) != splat)
36163636
return failure();
36173637

3618-
if (srcSplatOp.getInput() != destSplatOp.getInput())
3619-
return failure();
3620-
3621-
rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
3638+
rewriter.replaceOp(insertStridedSliceOp, dst);
36223639
return success();
36233640
}
36243641
};
@@ -4197,17 +4214,18 @@ class StridedSliceBroadcast final
41974214
}
41984215
};
41994216

4200-
/// Pattern to rewrite an ExtractStridedSliceOp(SplatOp) to SplatOp.
4217+
/// Rewrite extract_strided_slice(splat-like(v)) with broadcast(v)
42014218
class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
42024219
public:
42034220
using OpRewritePattern::OpRewritePattern;
42044221

42054222
LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
42064223
PatternRewriter &rewriter) const override {
4207-
auto splat = op.getVector().getDefiningOp<SplatOp>();
4224+
4225+
Value splat = getSplatSource(op.getVector());
42084226
if (!splat)
42094227
return failure();
4210-
rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), splat.getInput());
4228+
rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
42114229
return success();
42124230
}
42134231
};
@@ -6357,19 +6375,19 @@ class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
63576375
}
63586376
};
63596377

6360-
// Folds transpose(splat x : src_type) : res_type into splat x : res_type.
6378+
/// Replace transpose(splat-like(v)) with broadcast(v)
63616379
class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
63626380
public:
63636381
using OpRewritePattern::OpRewritePattern;
63646382

63656383
LogicalResult matchAndRewrite(TransposeOp transposeOp,
63666384
PatternRewriter &rewriter) const override {
6367-
auto splatOp = transposeOp.getVector().getDefiningOp<vector::SplatOp>();
6368-
if (!splatOp)
6385+
Value splat = getSplatSource(transposeOp.getVector());
6386+
if (!splat)
63696387
return failure();
63706388

6371-
rewriter.replaceOpWithNewOp<vector::SplatOp>(
6372-
transposeOp, transposeOp.getResultVectorType(), splatOp.getInput());
6389+
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
6390+
transposeOp, transposeOp.getResultVectorType(), splat);
63736391
return success();
63746392
}
63756393
};
@@ -7120,6 +7138,23 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
71207138
return SplatElementsAttr::get(getType(), {constOperand});
71217139
}
71227140

7141+
// Canonicalizer for vector.splat. It always gets canonicalized to a vector.broadcast.
7142+
class SplatToBroadcastPattern : public OpRewritePattern<SplatOp> {
7143+
public:
7144+
using OpRewritePattern<SplatOp>::OpRewritePattern;
7145+
LogicalResult matchAndRewrite(SplatOp splatOp,
7146+
PatternRewriter &rewriter) const override {
7147+
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
7148+
splatOp, splatOp.getType(), splatOp.getOperand());
7149+
return success();
7150+
}
7151+
};
7152+
void SplatOp::getCanonicalizationPatterns(RewritePatternSet &results,
7153+
MLIRContext *context) {
7154+
results.add<SplatToBroadcastPattern>(context);
7155+
}
7156+
7157+
71237158
void SplatOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
71247159
SetIntRangeFn setResultRanges) {
71257160
setResultRanges(getResult(), argRanges.front());

0 commit comments

Comments
 (0)