Skip to content

Commit 671eaf8

Browse files
authored
[mlir][vector] Avoid use of vector.splat in transforms (#150279)
This is part of the vector.splat deprecation (reference: https://discourse.llvm.org/t/rfc-mlir-vector-deprecate-then-remove-vector-splat/87143/5). Instead of creating vector::SplatOp, create vector::BroadcastOp.
1 parent f4972a2 commit 671eaf8

File tree

6 files changed

+76
-53
lines changed

6 files changed

+76
-53
lines changed

mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,10 @@ using namespace mlir;
2828
using namespace mlir::vector;
2929

3030
namespace {
31-
/// Progressive lowering of BroadcastOp.
31+
32+
/// Convert a vector.broadcast with a vector operand to a lower rank
33+
/// vector.broadcast. vector.broadcast with a scalar operand is expected to be
34+
/// convertible to the lower level target dialect (LLVM, SPIR-V, etc.) directly.
3235
class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
3336
public:
3437
using OpRewritePattern::OpRewritePattern;
@@ -40,20 +43,23 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
4043
VectorType srcType = dyn_cast<VectorType>(op.getSourceType());
4144
Type eltType = dstType.getElementType();
4245

43-
// Scalar to any vector can use splat.
44-
if (!srcType) {
45-
rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, op.getSource());
46-
return success();
47-
}
46+
// A broadcast from a scalar is considered to be in the lowered form.
47+
if (!srcType)
48+
return rewriter.notifyMatchFailure(
49+
op, "broadcast from scalar already in lowered form");
4850

4951
// Determine rank of source and destination.
5052
int64_t srcRank = srcType.getRank();
5153
int64_t dstRank = dstType.getRank();
5254

53-
// Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
55+
// Here we are broadcasting to a rank-1 vector. Ensure that the source is a
56+
// scalar.
5457
if (srcRank <= 1 && dstRank == 1) {
55-
Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource());
56-
rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
58+
SmallVector<int64_t> fullRankPosition(srcRank, 0);
59+
Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource(),
60+
fullRankPosition);
61+
assert(!isa<VectorType>(ext.getType()) && "expected scalar");
62+
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, dstType, ext);
5763
return success();
5864
}
5965

mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -468,7 +468,7 @@ struct TransferReadToVectorLoadLowering
468468
read, "vector type is not rank 1, can't create masked load, needs "
469469
"VectorToSCF");
470470

471-
Value fill = vector::SplatOp::create(
471+
Value fill = vector::BroadcastOp::create(
472472
rewriter, read.getLoc(), unbroadcastedVectorType, read.getPadding());
473473
res = vector::MaskedLoadOp::create(
474474
rewriter, read.getLoc(), unbroadcastedVectorType, read.getBase(),

mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ class DecomposeNDExtractStridedSlice
303303
// Extract/insert on a lower ranked extract strided slice op.
304304
Value zero = arith::ConstantOp::create(rewriter, loc, elemType,
305305
rewriter.getZeroAttr(elemType));
306-
Value res = SplatOp::create(rewriter, loc, dstType, zero);
306+
Value res = BroadcastOp::create(rewriter, loc, dstType, zero);
307307
for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
308308
off += stride, ++idx) {
309309
Value one = ExtractOp::create(rewriter, loc, op.getVector(), off);

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 39 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -939,7 +939,7 @@ struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
939939

940940
Value zero = arith::ConstantOp::create(rewriter, loc, elemType,
941941
rewriter.getZeroAttr(elemType));
942-
Value res = SplatOp::create(rewriter, loc, castDstType, zero);
942+
Value res = BroadcastOp::create(rewriter, loc, castDstType, zero);
943943

944944
SmallVector<int64_t> sliceShape = {castDstLastDim};
945945
SmallVector<int64_t> strides = {1};
@@ -987,6 +987,23 @@ static Type cloneOrReplace(Type type, Type newElementType) {
987987
return newElementType;
988988
}
989989

990+
/// If `value` is the result of a splat or broadcast operation, return the input
991+
/// of the splat/broadcast operation.
992+
static Value getBroadcastLikeSource(Value value) {
993+
994+
Operation *op = value.getDefiningOp();
995+
if (!op)
996+
return {};
997+
998+
if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
999+
return broadcast.getSource();
1000+
1001+
if (auto splat = dyn_cast<vector::SplatOp>(op))
1002+
return splat.getInput();
1003+
1004+
return {};
1005+
}
1006+
9901007
/// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex:
9911008
///
9921009
/// Example:
@@ -1026,39 +1043,37 @@ struct ReorderElementwiseOpsOnBroadcast final
10261043
}
10271044

10281045
Type resultElemType = resultType.getElementType();
1046+
10291047
// Get the type of the first non-constant operand
1030-
Operation *firstBroadcastOrSplat = nullptr;
1048+
Value splatSource;
10311049
for (Value operand : op->getOperands()) {
10321050
Operation *definingOp = operand.getDefiningOp();
10331051
if (!definingOp)
10341052
return failure();
10351053
if (definingOp->hasTrait<OpTrait::ConstantLike>())
10361054
continue;
1037-
if (!isa<vector::BroadcastOp, vector::SplatOp>(*definingOp))
1038-
return failure();
1039-
firstBroadcastOrSplat = definingOp;
1055+
splatSource = getBroadcastLikeSource(operand);
10401056
break;
10411057
}
1042-
if (!firstBroadcastOrSplat)
1058+
if (!splatSource)
10431059
return failure();
1044-
Type unbroadcastResultType = cloneOrReplace(
1045-
firstBroadcastOrSplat->getOperand(0).getType(), resultElemType);
1060+
Type unbroadcastResultType =
1061+
cloneOrReplace(splatSource.getType(), resultElemType);
10461062

10471063
// Make sure that all operands are broadcast from identically-shaped types:
10481064
// * scalar (`vector.broadcast` + `vector.splat`), or
10491065
// * vector (`vector.broadcast`).
10501066
// Otherwise the re-ordering wouldn't be safe.
1051-
if (!llvm::all_of(op->getOperands(), [&unbroadcastResultType](Value val) {
1052-
if (auto bcastOp = val.getDefiningOp<vector::BroadcastOp>())
1053-
return haveSameShapeAndScaling(bcastOp.getOperand().getType(),
1054-
unbroadcastResultType);
1055-
if (auto splatOp = val.getDefiningOp<vector::SplatOp>())
1056-
return haveSameShapeAndScaling(splatOp.getOperand().getType(),
1057-
unbroadcastResultType);
1067+
if (!llvm::all_of(op->getOperands(), [splatSource](Value val) {
1068+
if (auto source = getBroadcastLikeSource(val))
1069+
return haveSameShapeAndScaling(source.getType(),
1070+
splatSource.getType());
10581071
SplatElementsAttr splatConst;
10591072
return matchPattern(val, m_Constant(&splatConst));
10601073
})) {
1061-
return failure();
1074+
return rewriter.notifyMatchFailure(
1075+
op,
1076+
"not all operands are constants or broadcasts from the same type");
10621077
}
10631078

10641079
// Collect the source values before broadcasting
@@ -1287,15 +1302,17 @@ class StoreOpFromSplatOrBroadcast final
12871302
return rewriter.notifyMatchFailure(
12881303
op, "only 1-element vectors are supported");
12891304

1290-
Operation *splat = op.getValueToStore().getDefiningOp();
1291-
if (!isa_and_present<vector::BroadcastOp, vector::SplatOp>(splat))
1292-
return rewriter.notifyMatchFailure(op, "neither a splat nor a broadcast");
1305+
Value toStore = op.getValueToStore();
1306+
Value source = getBroadcastLikeSource(toStore);
1307+
if (!source)
1308+
return rewriter.notifyMatchFailure(
1309+
op, "value to store is not from a broadcast");
12931310

12941311
// Checking for single use so we can remove splat.
1312+
Operation *splat = toStore.getDefiningOp();
12951313
if (!splat->hasOneUse())
12961314
return rewriter.notifyMatchFailure(op, "expected single op use");
12971315

1298-
Value source = splat->getOperand(0);
12991316
Value base = op.getBase();
13001317
ValueRange indices = op.getIndices();
13011318

@@ -1345,13 +1362,13 @@ static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
13451362
// Add in an offset if requested.
13461363
if (off) {
13471364
Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off);
1348-
Value ov = vector::SplatOp::create(rewriter, loc, indices.getType(), o);
1365+
Value ov = vector::BroadcastOp::create(rewriter, loc, indices.getType(), o);
13491366
indices = arith::AddIOp::create(rewriter, loc, ov, indices);
13501367
}
13511368
// Construct the vector comparison.
13521369
Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b);
13531370
Value bounds =
1354-
vector::SplatOp::create(rewriter, loc, indices.getType(), bound);
1371+
vector::BroadcastOp::create(rewriter, loc, indices.getType(), bound);
13551372
return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
13561373
indices, bounds);
13571374
}

mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
// CHECK-LABEL: func @broadcast_vec1d_from_scalar
44
// CHECK-SAME: %[[A:.*0]]: f32
5-
// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2xf32>
5+
// CHECK: %[[T0:.*]] = vector.broadcast %[[A]] : f32 to vector<2xf32>
66
// CHECK: return %[[T0]] : vector<2xf32>
77

88
func.func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> {
@@ -12,7 +12,7 @@ func.func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> {
1212

1313
// CHECK-LABEL: func @broadcast_vec2d_from_scalar
1414
// CHECK-SAME: %[[A:.*0]]: f32
15-
// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2x3xf32>
15+
// CHECK: %[[T0:.*]] = vector.broadcast %[[A]] : f32 to vector<2x3xf32>
1616
// CHECK: return %[[T0]] : vector<2x3xf32>
1717

1818
func.func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> {
@@ -22,7 +22,7 @@ func.func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> {
2222

2323
// CHECK-LABEL: func @broadcast_vec3d_from_scalar
2424
// CHECK-SAME: %[[A:.*0]]: f32
25-
// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2x3x4xf32>
25+
// CHECK: %[[T0:.*]] = vector.broadcast %[[A]] : f32 to vector<2x3x4xf32>
2626
// CHECK: return %[[T0]] : vector<2x3x4xf32>
2727

2828
func.func @broadcast_vec3d_from_scalar(%arg0: f32) -> vector<2x3x4xf32> {
@@ -87,7 +87,7 @@ func.func @broadcast_vec3d_from_vec2d(%arg0: vector<3x2xf32>) -> vector<4x3x2xf3
8787
// CHECK-LABEL: func @broadcast_stretch
8888
// CHECK-SAME: %[[A:.*0]]: vector<1xf32>
8989
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<1xf32>
90-
// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<4xf32>
90+
// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<4xf32>
9191
// CHECK: return %[[T1]] : vector<4xf32>
9292

9393
func.func @broadcast_stretch(%arg0: vector<1xf32>) -> vector<4xf32> {
@@ -113,16 +113,16 @@ func.func @broadcast_stretch_at_start(%arg0: vector<1x4xf32>) -> vector<3x4xf32>
113113
// CHECK-SAME: %[[A:.*0]]: vector<4x1xf32>
114114
// CHECK: %[[U0:.*]] = ub.poison : vector<4x3xf32>
115115
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : f32 from vector<4x1xf32>
116-
// CHECK: %[[T2:.*]] = vector.splat %[[T0]] : vector<3xf32>
116+
// CHECK: %[[T2:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32>
117117
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[U0]] [0] : vector<3xf32> into vector<4x3xf32>
118118
// CHECK: %[[T4:.*]] = vector.extract %[[A]][1, 0] : f32 from vector<4x1xf32>
119-
// CHECK: %[[T6:.*]] = vector.splat %[[T4]] : vector<3xf32>
119+
// CHECK: %[[T6:.*]] = vector.broadcast %[[T4]] : f32 to vector<3xf32>
120120
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<4x3xf32>
121121
// CHECK: %[[T8:.*]] = vector.extract %[[A]][2, 0] : f32 from vector<4x1xf32>
122-
// CHECK: %[[T10:.*]] = vector.splat %[[T8]] : vector<3xf32>
122+
// CHECK: %[[T10:.*]] = vector.broadcast %[[T8]] : f32 to vector<3xf32>
123123
// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T7]] [2] : vector<3xf32> into vector<4x3xf32>
124124
// CHECK: %[[T12:.*]] = vector.extract %[[A]][3, 0] : f32 from vector<4x1xf32>
125-
// CHECK: %[[T14:.*]] = vector.splat %[[T12]] : vector<3xf32>
125+
// CHECK: %[[T14:.*]] = vector.broadcast %[[T12]] : f32 to vector<3xf32>
126126
// CHECK: %[[T15:.*]] = vector.insert %[[T14]], %[[T11]] [3] : vector<3xf32> into vector<4x3xf32>
127127
// CHECK: return %[[T15]] : vector<4x3xf32>
128128

mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55
// CHECK-SAME: %[[B:.*1]]: vector<3xf32>
66
// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
77
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<2xf32>
8-
// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32>
8+
// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32>
99
// CHECK: %[[T2:.*]] = arith.mulf %[[T1]], %[[B]] : vector<3xf32>
1010
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32>
1111
// CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : f32 from vector<2xf32>
12-
// CHECK: %[[T5:.*]] = vector.splat %[[T4]] : vector<3xf32>
12+
// CHECK: %[[T5:.*]] = vector.broadcast %[[T4]] : f32 to vector<3xf32>
1313
// CHECK: %[[T6:.*]] = arith.mulf %[[T5]], %[[B]] : vector<3xf32>
1414
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<2x3xf32>
1515
// CHECK: return %[[T7]] : vector<2x3xf32>
@@ -26,12 +26,12 @@ func.func @outerproduct_noacc(%arg0: vector<2xf32>,
2626
// CHECK-SAME: %[[C:.*2]]: vector<2x3xf32>
2727
// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
2828
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<2xf32>
29-
// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32>
29+
// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32>
3030
// CHECK: %[[T2:.*]] = vector.extract %[[C]][0] : vector<3xf32> from vector<2x3xf32>
3131
// CHECK: %[[T3:.*]] = vector.fma %[[T1]], %[[B]], %[[T2]] : vector<3xf32>
3232
// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32>
3333
// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : f32 from vector<2xf32>
34-
// CHECK: %[[T6:.*]] = vector.splat %[[T5]] : vector<3xf32>
34+
// CHECK: %[[T6:.*]] = vector.broadcast %[[T5]] : f32 to vector<3xf32>
3535
// CHECK: %[[T7:.*]] = vector.extract %[[C]][1] : vector<3xf32> from vector<2x3xf32>
3636
// CHECK: %[[T8:.*]] = vector.fma %[[T6]], %[[B]], %[[T7]] : vector<3xf32>
3737
// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : vector<3xf32> into vector<2x3xf32>
@@ -49,11 +49,11 @@ func.func @outerproduct_acc(%arg0: vector<2xf32>,
4949
// CHECK-SAME: %[[B:.*1]]: vector<3xi32>
5050
// CHECK: %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32>
5151
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : i32 from vector<2xi32>
52-
// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32>
52+
// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : i32 to vector<3xi32>
5353
// CHECK: %[[T2:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32>
5454
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32>
5555
// CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : i32 from vector<2xi32>
56-
// CHECK: %[[T5:.*]] = vector.splat %[[T4]] : vector<3xi32>
56+
// CHECK: %[[T5:.*]] = vector.broadcast %[[T4]] : i32 to vector<3xi32>
5757
// CHECK: %[[T6:.*]] = arith.muli %[[T5]], %[[B]] : vector<3xi32>
5858
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xi32> into vector<2x3xi32>
5959
// CHECK: return %[[T7]] : vector<2x3xi32>
@@ -69,13 +69,13 @@ func.func @outerproduct_noacc_int(%arg0: vector<2xi32>,
6969
// CHECK-SAME: %[[C:.*2]]: vector<2x3xi32>
7070
// CHECK: %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32>
7171
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : i32 from vector<2xi32>
72-
// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32>
72+
// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : i32 to vector<3xi32>
7373
// CHECK: %[[T2:.*]] = vector.extract %[[C]][0] : vector<3xi32> from vector<2x3xi32>
7474
// CHECK: %[[T3:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32>
7575
// CHECK: %[[T4:.*]] = arith.addi %[[T3]], %[[T2]] : vector<3xi32>
7676
// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32>
7777
// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : i32 from vector<2xi32>
78-
// CHECK: %[[T7:.*]] = vector.splat %[[T6]] : vector<3xi32>
78+
// CHECK: %[[T7:.*]] = vector.broadcast %[[T6]] : i32 to vector<3xi32>
7979
// CHECK: %[[T8:.*]] = vector.extract %[[C]][1] : vector<3xi32> from vector<2x3xi32>
8080
// CHECK: %[[T9:.*]] = arith.muli %[[T7]], %[[B]] : vector<3xi32>
8181
// CHECK: %[[T10:.*]] = arith.addi %[[T9]], %[[T8]] : vector<3xi32>
@@ -91,7 +91,7 @@ func.func @outerproduct_acc_int(%arg0: vector<2xi32>,
9191
// CHECK-LABEL: func @axpy_fp(
9292
// CHECK-SAME: %[[A:.*0]]: vector<16xf32>,
9393
// CHECK-SAME: %[[B:.*1]]: f32)
94-
// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32>
94+
// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : f32 to vector<16xf32>
9595
// CHECK: %[[T1:.*]] = arith.mulf %[[A]], %[[T0]] : vector<16xf32>
9696
// CHECK: return %[[T1]] : vector<16xf32>
9797
func.func @axpy_fp(%arg0: vector<16xf32>, %arg1: f32) -> vector<16xf32> {
@@ -103,7 +103,7 @@ func.func @axpy_fp(%arg0: vector<16xf32>, %arg1: f32) -> vector<16xf32> {
103103
// CHECK-SAME: %[[A:.*0]]: vector<16xf32>,
104104
// CHECK-SAME: %[[B:.*1]]: f32,
105105
// CHECK-SAME: %[[C:.*2]]: vector<16xf32>)
106-
// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32>
106+
// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : f32 to vector<16xf32>
107107
// CHECK: %[[T1:.*]] = vector.fma %[[A]], %[[T0]], %[[C]] : vector<16xf32>
108108
// CHECK: return %[[T1]] : vector<16xf32>
109109
func.func @axpy_fp_add(%arg0: vector<16xf32>, %arg1: f32, %arg2 : vector<16xf32>) -> vector<16xf32> {
@@ -114,7 +114,7 @@ func.func @axpy_fp_add(%arg0: vector<16xf32>, %arg1: f32, %arg2 : vector<16xf32>
114114
// CHECK-LABEL: func @axpy_int(
115115
// CHECK-SAME: %[[A:.*0]]: vector<16xi32>,
116116
// CHECK-SAME: %[[B:.*1]]: i32)
117-
// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32>
117+
// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : i32 to vector<16xi32>
118118
// CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32>
119119
// CHECK: return %[[T1]] : vector<16xi32>
120120
func.func @axpy_int(%arg0: vector<16xi32>, %arg1: i32) -> vector<16xi32> {
@@ -126,7 +126,7 @@ func.func @axpy_int(%arg0: vector<16xi32>, %arg1: i32) -> vector<16xi32> {
126126
// CHECK-SAME: %[[A:.*0]]: vector<16xi32>,
127127
// CHECK-SAME: %[[B:.*1]]: i32,
128128
// CHECK-SAME: %[[C:.*2]]: vector<16xi32>)
129-
// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32>
129+
// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : i32 to vector<16xi32>
130130
// CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32>
131131
// CHECK: %[[T2:.*]] = arith.addi %[[T1]], %[[C]] : vector<16xi32>
132132
// CHECK: return %[[T2]] : vector<16xi32>

0 commit comments

Comments
 (0)