Skip to content

Commit ea291d0

Browse files
authored
[MLIR][Vector] Remove vector.splat (#162167)
vector.splat has been deprecated (user: please use the very similar vector.broadcast instead) with the last PR landing about 6 weeks ago. The discourse discussion is at https://discourse.llvm.org/t/rfc-mlir-vector-deprecate-then-remove-vector-splat/87143/1 The last PR was #152230 This PR completely removes vector.splat. In addition to removing vector.splat from VectorOps.td, it - Updates the few remaining places where vector::SplatOp is created (now vector::BroadcastOp is created) - Removes temporary patterns where vector.splat is replaced by vector.broadcast The only place 'vector.splat' appears is now the files https://github.com/llvm/llvm-project/blob/main/mlir/utils/tree-sitter-mlir/test/corpus/op.txt and https://github.com/llvm/llvm-project/blob/main/mlir/utils/tree-sitter-mlir/dialect/vector.js --------- Signed-off-by: James Newling <[email protected]>
1 parent 6ed18d8 commit ea291d0

File tree

28 files changed

+78
-501
lines changed

28 files changed

+78
-501
lines changed

flang/lib/Optimizer/Builder/PPCIntrinsicCall.cpp

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2383,26 +2383,26 @@ PPCIntrinsicLibrary::genVecSplat(mlir::Type resultType,
23832383
auto context{builder.getContext()};
23842384
auto argBases{getBasesForArgs(args)};
23852385

2386-
mlir::vector::SplatOp splatOp{nullptr};
2386+
mlir::vector::BroadcastOp splatOp{nullptr};
23872387
mlir::Type retTy{nullptr};
23882388
switch (vop) {
23892389
case VecOp::Splat: {
23902390
assert(args.size() == 2);
23912391
auto vecTyInfo{getVecTypeFromFir(argBases[0])};
23922392

23932393
auto extractOp{genVecExtract(resultType, args)};
2394-
splatOp =
2395-
mlir::vector::SplatOp::create(builder, loc, *(extractOp.getUnboxed()),
2396-
vecTyInfo.toMlirVectorType(context));
2394+
splatOp = mlir::vector::BroadcastOp::create(
2395+
builder, loc, vecTyInfo.toMlirVectorType(context),
2396+
*(extractOp.getUnboxed()));
23972397
retTy = vecTyInfo.toFirVectorType();
23982398
break;
23992399
}
24002400
case VecOp::Splats: {
24012401
assert(args.size() == 1);
24022402
auto vecTyInfo{getVecTypeFromEle(argBases[0])};
24032403

2404-
splatOp = mlir::vector::SplatOp::create(
2405-
builder, loc, argBases[0], vecTyInfo.toMlirVectorType(context));
2404+
splatOp = mlir::vector::BroadcastOp::create(
2405+
builder, loc, vecTyInfo.toMlirVectorType(context), argBases[0]);
24062406
retTy = vecTyInfo.toFirVectorType();
24072407
break;
24082408
}
@@ -2412,8 +2412,8 @@ PPCIntrinsicLibrary::genVecSplat(mlir::Type resultType,
24122412
auto intOp{builder.createConvert(loc, eleTy, argBases[0])};
24132413

24142414
// the intrinsic always returns vector(integer(4))
2415-
splatOp = mlir::vector::SplatOp::create(builder, loc, intOp,
2416-
mlir::VectorType::get(4, eleTy));
2415+
splatOp = mlir::vector::BroadcastOp::create(
2416+
builder, loc, mlir::VectorType::get(4, eleTy), intOp);
24172417
retTy = fir::VectorType::get(4, eleTy);
24182418
break;
24192419
}
@@ -2444,7 +2444,8 @@ PPCIntrinsicLibrary::genVecXlds(mlir::Type resultType,
24442444
auto addrConv{fir::ConvertOp::create(builder, loc, i64RefTy, addr)};
24452445

24462446
auto addrVal{fir::LoadOp::create(builder, loc, addrConv)};
2447-
auto splatRes{mlir::vector::SplatOp::create(builder, loc, addrVal, i64VecTy)};
2447+
auto splatRes{
2448+
mlir::vector::BroadcastOp::create(builder, loc, i64VecTy, addrVal)};
24482449

24492450
mlir::Value result{nullptr};
24502451
if (mlirTy != splatRes.getType()) {

mlir/docs/Dialects/Vector.md

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ Some existing Arith and Vector Dialect on `n-D` `vector` types comprise:
125125
// Produces a vector<3x7x8xf32>
126126
%b = arith.mulf %0, %1 : vector<3x7x8xf32>
127127
// Produces a vector<3x7x8xf32>
128-
%c = vector.splat %1 : vector<3x7x8xf32>
128+
%c = vector.broadcast %1 : f32 to vector<3x7x8xf32>
129129
130130
%d = vector.extract %0[1]: vector<7x8xf32> from vector<3x7x8xf32>
131131
%e = vector.extract %0[1, 5]: vector<8xf32> from vector<3x7x8xf32>
@@ -176,8 +176,6 @@ infrastructure can apply iteratively.
176176
### Virtual Vector to Hardware Vector Lowering
177177

178178
For now, `VV -> HWV` are specified in C++ (see for instance the
179-
[SplatOpLowering for n-D vectors](https://github.com/tensorflow/mlir/commit/0a0c4867c6a6fcb0a2f17ef26a791c1d551fe33d)
180-
or the
181179
[VectorOuterProductOp lowering](https://github.com/tensorflow/mlir/commit/957b1ca9680b4aacabb3a480fbc4ebd2506334b8)).
182180

183181
Simple

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 0 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -2881,53 +2881,6 @@ def Vector_PrintOp :
28812881
}];
28822882
}
28832883

2884-
//===----------------------------------------------------------------------===//
2885-
// SplatOp
2886-
//===----------------------------------------------------------------------===//
2887-
2888-
def Vector_SplatOp : Vector_Op<"splat", [
2889-
Pure,
2890-
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
2891-
TypesMatchWith<"operand type matches element type of result",
2892-
"aggregate", "input",
2893-
"::llvm::cast<VectorType>($_self).getElementType()">
2894-
]> {
2895-
let summary = "vector splat or broadcast operation";
2896-
let description = [{
2897-
Note: This operation is deprecated. Please use vector.broadcast.
2898-
2899-
Broadcast the operand to all elements of the result vector. The type of the
2900-
operand must match the element type of the vector type.
2901-
2902-
Example:
2903-
2904-
```mlir
2905-
%s = arith.constant 10.1 : f32
2906-
%t = vector.splat %s : vector<8x16xf32>
2907-
```
2908-
2909-
This operation is deprecated, the preferred representation of the above is:
2910-
2911-
```mlir
2912-
%s = arith.constant 10.1 : f32
2913-
%t = vector.broadcast %s : f32 to vector<8x16xf32>
2914-
```
2915-
}];
2916-
2917-
let arguments = (ins AnyType:$input);
2918-
let results = (outs AnyVectorOfAnyRank:$aggregate);
2919-
2920-
let builders = [
2921-
OpBuilder<(ins "Value":$element, "Type":$aggregateType),
2922-
[{ build($_builder, $_state, aggregateType, element); }]>];
2923-
let assemblyFormat = "$input attr-dict `:` type($aggregate)";
2924-
2925-
let hasFolder = 1;
2926-
2927-
// vector.splat is deprecated, and vector.broadcast should be used instead.
2928-
// Canonicalize vector.splat to vector.broadcast.
2929-
let hasCanonicalizer = 1;
2930-
}
29312884

29322885
//===----------------------------------------------------------------------===//
29332886
// VectorScaleOp

mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -432,10 +432,6 @@ static Value getOriginalVectorValue(Value value) {
432432
current = op.getSource();
433433
return false;
434434
})
435-
.Case<vector::SplatOp>([&current](auto op) {
436-
current = op.getInput();
437-
return false;
438-
})
439435
.Default([](Operation *) { return false; });
440436

441437
if (!skipOp) {

mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
236236
/// AFTER:
237237
/// ```mlir
238238
/// ...
239-
/// %pad_1d = vector.splat %pad : vector<[4]xi32>
239+
/// %pad_1d = vector.broadcast %pad : i32 to vector<[4]xi32>
240240
/// %tile = scf.for %tile_slice_idx = %c0 to %svl_s step %c1
241241
/// iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
242242
/// ...

mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp

Lines changed: 4 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -731,28 +731,14 @@ struct ExtractFromCreateMaskToPselLowering
731731
}
732732
};
733733

734-
// Convert all `vector.splat` to `vector.broadcast`. There is a path from
735-
// `vector.broadcast` to ArmSME via another pattern.
736-
struct ConvertSplatToBroadcast : public OpRewritePattern<vector::SplatOp> {
737-
using Base::Base;
738-
739-
LogicalResult matchAndRewrite(vector::SplatOp splatOp,
740-
PatternRewriter &rewriter) const final {
741-
742-
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(),
743-
splatOp.getInput());
744-
return success();
745-
}
746-
};
747-
748734
} // namespace
749735

750736
void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
751737
MLIRContext &ctx) {
752-
patterns.add<BroadcastOpToArmSMELowering, ConvertSplatToBroadcast,
753-
TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
754-
TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
755-
VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
738+
patterns.add<BroadcastOpToArmSMELowering, TransferReadToArmSMELowering,
739+
TransferWriteToArmSMELowering, TransposeOpToArmSMELowering,
740+
VectorLoadToArmSMELowering, VectorStoreToArmSMELowering,
741+
VectorOuterProductToArmSMELowering,
756742
VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
757743
VectorPrintToArmSMELowering, FoldTransferWriteOfExtractTileSlice,
758744
ExtractFromCreateMaskToPselLowering>(&ctx);

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2161,19 +2161,6 @@ class TransposeOpToMatrixTransposeOpLowering
21612161
}
21622162
};
21632163

2164-
/// Convert `vector.splat` to `vector.broadcast`. There is a path to LLVM from
2165-
/// `vector.broadcast` through other patterns.
2166-
struct VectorSplatToBroadcast : public ConvertOpToLLVMPattern<vector::SplatOp> {
2167-
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
2168-
LogicalResult
2169-
matchAndRewrite(vector::SplatOp splat, OpAdaptor adaptor,
2170-
ConversionPatternRewriter &rewriter) const override {
2171-
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splat, splat.getType(),
2172-
adaptor.getInput());
2173-
return success();
2174-
}
2175-
};
2176-
21772164
} // namespace
21782165

21792166
void mlir::vector::populateVectorRankReducingFMAPattern(
@@ -2212,7 +2199,7 @@ void mlir::populateVectorToLLVMConversionPatterns(
22122199
VectorInsertOpConversion, VectorPrintOpConversion,
22132200
VectorTypeCastOpConversion, VectorScaleOpConversion,
22142201
VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
2215-
VectorSplatToBroadcast, VectorBroadcastScalarToLowRankLowering,
2202+
VectorBroadcastScalarToLowRankLowering,
22162203
VectorBroadcastScalarToNdLowering,
22172204
VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
22182205
MaskedReductionOpConversion, VectorInterleaveOpLowering,

mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
#include "mlir/IR/BuiltinAttributes.h"
2323
#include "mlir/IR/BuiltinTypes.h"
2424
#include "mlir/IR/Location.h"
25-
#include "mlir/IR/Matchers.h"
2625
#include "mlir/IR/PatternMatch.h"
2726
#include "mlir/IR/TypeUtilities.h"
2827
#include "mlir/Transforms/DialectConversion.h"
@@ -79,20 +78,6 @@ struct VectorShapeCast final : public OpConversionPattern<vector::ShapeCastOp> {
7978
}
8079
};
8180

82-
// Convert `vector.splat` to `vector.broadcast`. There is a path from
83-
// `vector.broadcast` to SPIRV via other patterns.
84-
struct VectorSplatToBroadcast final
85-
: public OpConversionPattern<vector::SplatOp> {
86-
using Base::Base;
87-
LogicalResult
88-
matchAndRewrite(vector::SplatOp splat, OpAdaptor adaptor,
89-
ConversionPatternRewriter &rewriter) const override {
90-
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splat, splat.getType(),
91-
adaptor.getInput());
92-
return success();
93-
}
94-
};
95-
9681
struct VectorBitcastConvert final
9782
: public OpConversionPattern<vector::BitCastOp> {
9883
using Base::Base;
@@ -1092,10 +1077,10 @@ void mlir::populateVectorToSPIRVPatterns(
10921077
VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
10931078
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
10941079
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
1095-
VectorSplatToBroadcast, VectorInsertStridedSliceOpConvert,
1096-
VectorShuffleOpConvert, VectorInterleaveOpConvert,
1097-
VectorDeinterleaveOpConvert, VectorScalarBroadcastPattern,
1098-
VectorLoadOpConverter, VectorStoreOpConverter, VectorStepOpConvert>(
1080+
VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
1081+
VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
1082+
VectorScalarBroadcastPattern, VectorLoadOpConverter,
1083+
VectorStoreOpConverter, VectorStepOpConvert>(
10991084
typeConverter, patterns.getContext(), PatternBenefit(1));
11001085

11011086
// Make sure that the more specialized dot product pattern has higher benefit

mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,7 @@ void mlir::arith::populateEmulateUnsupportedFloatsLegality(
123123
vector::OuterProductOp, vector::ScanOp>(
124124
[&](Operation *op) { return converter.isLegal(op); });
125125
target.addLegalOp<arith::BitcastOp, arith::ExtFOp, arith::TruncFOp,
126-
arith::ConstantOp, arith::SelectOp, vector::SplatOp,
127-
vector::BroadcastOp>();
126+
arith::ConstantOp, arith::SelectOp, vector::BroadcastOp>();
128127
}
129128

130129
void EmulateUnsupportedFloatsPass::runOnOperation() {

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 6 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1665,10 +1665,10 @@ static bool hasZeroDimVectors(Operation *op) {
16651665
llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
16661666
}
16671667

1668-
/// All BroadcastOps and SplatOps, as well as ShapeCastOps that only prepend
1669-
/// 1s, are considered to be 'broadcastlike'.
1668+
/// All BroadcastOps, as well as ShapeCastOps that only prepend 1s, are
1669+
/// considered to be 'broadcastlike'.
16701670
static bool isBroadcastLike(Operation *op) {
1671-
if (isa<BroadcastOp, SplatOp>(op))
1671+
if (isa<BroadcastOp>(op))
16721672
return true;
16731673

16741674
auto shapeCast = dyn_cast<ShapeCastOp>(op);
@@ -3249,23 +3249,18 @@ struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
32493249
};
32503250

32513251
/// Consider the defining operation `defOp` of `value`. If `defOp` is a
3252-
/// vector.splat or a vector.broadcast with a scalar operand, return the scalar
3253-
/// value that is splatted. Otherwise return null.
3252+
/// vector.broadcast with a scalar operand, return the scalar value that is
3253+
/// splatted. Otherwise return null.
32543254
///
3255-
/// Examples:
3255+
/// Example:
32563256
///
3257-
/// scalar_source --> vector.splat --> value - return scalar_source
32583257
/// scalar_source --> vector.broadcast --> value - return scalar_source
32593258
static Value getScalarSplatSource(Value value) {
32603259
// Block argument:
32613260
Operation *defOp = value.getDefiningOp();
32623261
if (!defOp)
32633262
return {};
32643263

3265-
// Splat:
3266-
if (auto splat = dyn_cast<vector::SplatOp>(defOp))
3267-
return splat.getInput();
3268-
32693264
auto broadcast = dyn_cast<vector::BroadcastOp>(defOp);
32703265

32713266
// Not broadcast (and not splat):
@@ -7511,41 +7506,6 @@ void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
75117506
patterns.getContext(), benefit);
75127507
}
75137508

7514-
//===----------------------------------------------------------------------===//
7515-
// SplatOp
7516-
//===----------------------------------------------------------------------===//
7517-
7518-
OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
7519-
auto constOperand = adaptor.getInput();
7520-
if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
7521-
return {};
7522-
7523-
// SplatElementsAttr::get treats single value for second arg as being a splat.
7524-
return SplatElementsAttr::get(getType(), {constOperand});
7525-
}
7526-
7527-
// Canonicalizer for vector.splat. It always gets canonicalized to a
7528-
// vector.broadcast.
7529-
class SplatToBroadcastPattern final : public OpRewritePattern<SplatOp> {
7530-
public:
7531-
using Base::Base;
7532-
LogicalResult matchAndRewrite(SplatOp splatOp,
7533-
PatternRewriter &rewriter) const override {
7534-
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(),
7535-
splatOp.getOperand());
7536-
return success();
7537-
}
7538-
};
7539-
void SplatOp::getCanonicalizationPatterns(RewritePatternSet &results,
7540-
MLIRContext *context) {
7541-
results.add<SplatToBroadcastPattern>(context);
7542-
}
7543-
7544-
void SplatOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
7545-
SetIntRangeFn setResultRanges) {
7546-
setResultRanges(getResult(), argRanges.front());
7547-
}
7548-
75497509
Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
75507510
CombiningKind kind, Value v1, Value acc,
75517511
arith::FastMathFlagsAttr fastmath,

0 commit comments

Comments
 (0)