Skip to content

Commit 330a7e1

Browse files
authored
[mlir][Vector] Make elementwise-on-broadcast sinking handle splat consts (llvm#150867)
There is a pattern that rewrites elementwise_op(broadcast(x1 : T to U), broadcast(x2 : T to U), ...) to broadcast(elementwise_op(x1, x2, ...) : T to U). This pattern did not, however, account for the case where a broadcast constant is represented as a SplatElementsAttr, which can safely be reshaped or scalarized but is not a `vector.broadcast` or `vector.splat` operation. This patch fixes this oversight, prenting premature broadcasting. This did result in the need to update some linalg dialect tests, which now feature a less-broadcast computation and/or more constant folding.
1 parent a1aba84 commit 330a7e1

File tree

3 files changed

+129
-46
lines changed

3 files changed

+129
-46
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 44 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1005,40 +1005,68 @@ struct ReorderElementwiseOpsOnBroadcast final
10051005
"might be a scalar");
10061006
}
10071007

1008-
// Get the type of the lhs operand
1009-
auto *lhsBcastOrSplat = op->getOperand(0).getDefiningOp();
1010-
if (!lhsBcastOrSplat ||
1011-
!isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat))
1008+
// Get the type of the first non-constant operand
1009+
Operation *firstBroadcastOrSplat = nullptr;
1010+
for (Value operand : op->getOperands()) {
1011+
Operation *definingOp = operand.getDefiningOp();
1012+
if (!definingOp)
1013+
return failure();
1014+
if (definingOp->hasTrait<OpTrait::ConstantLike>())
1015+
continue;
1016+
if (!isa<vector::BroadcastOp, vector::SplatOp>(*definingOp))
1017+
return failure();
1018+
firstBroadcastOrSplat = definingOp;
1019+
break;
1020+
}
1021+
if (!firstBroadcastOrSplat)
10121022
return failure();
1013-
auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType();
1023+
Type firstBroadcastOrSplatType =
1024+
firstBroadcastOrSplat->getOperand(0).getType();
10141025

10151026
// Make sure that all operands are broadcast from identical types:
10161027
// * scalar (`vector.broadcast` + `vector.splat`), or
10171028
// * vector (`vector.broadcast`).
10181029
// Otherwise the re-ordering wouldn't be safe.
1019-
if (!llvm::all_of(op->getOperands(), [&lhsBcastOrSplatType](Value val) {
1020-
auto bcast = val.getDefiningOp<vector::BroadcastOp>();
1021-
if (bcast)
1022-
return (bcast.getOperand().getType() == lhsBcastOrSplatType);
1023-
auto splat = val.getDefiningOp<vector::SplatOp>();
1024-
if (splat)
1025-
return (splat.getOperand().getType() == lhsBcastOrSplatType);
1026-
return false;
1027-
})) {
1030+
if (!llvm::all_of(
1031+
op->getOperands(), [&firstBroadcastOrSplatType](Value val) {
1032+
if (auto bcastOp = val.getDefiningOp<vector::BroadcastOp>())
1033+
return (bcastOp.getOperand().getType() ==
1034+
firstBroadcastOrSplatType);
1035+
if (auto splatOp = val.getDefiningOp<vector::SplatOp>())
1036+
return (splatOp.getOperand().getType() ==
1037+
firstBroadcastOrSplatType);
1038+
SplatElementsAttr splatConst;
1039+
return matchPattern(val, m_Constant(&splatConst));
1040+
})) {
10281041
return failure();
10291042
}
10301043

10311044
// Collect the source values before broadcasting
10321045
SmallVector<Value> srcValues;
10331046
srcValues.reserve(op->getNumOperands());
10341047
for (Value operand : op->getOperands()) {
1035-
srcValues.push_back(operand.getDefiningOp()->getOperand(0));
1048+
SplatElementsAttr splatConst;
1049+
if (matchPattern(operand, m_Constant(&splatConst))) {
1050+
Attribute newConst;
1051+
if (auto shapedTy = dyn_cast<ShapedType>(firstBroadcastOrSplatType)) {
1052+
newConst = splatConst.resizeSplat(shapedTy);
1053+
} else {
1054+
newConst = splatConst.getSplatValue<Attribute>();
1055+
}
1056+
Operation *newConstOp =
1057+
operand.getDefiningOp()->getDialect()->materializeConstant(
1058+
rewriter, newConst, firstBroadcastOrSplatType,
1059+
operand.getLoc());
1060+
srcValues.push_back(newConstOp->getResult(0));
1061+
} else {
1062+
srcValues.push_back(operand.getDefiningOp()->getOperand(0));
1063+
}
10361064
}
10371065

10381066
// Create the "elementwise" Op
10391067
Operation *elementwiseOp =
10401068
rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
1041-
lhsBcastOrSplatType, op->getAttrs());
1069+
firstBroadcastOrSplatType, op->getAttrs());
10421070

10431071
// Replace the original Op with the elementwise Op
10441072
auto vectorType = op->getResultTypes()[0];

mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir

Lines changed: 21 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -230,18 +230,17 @@ func.func @vectorize_nd_tensor_extract_index_from_tensor(%arg0: tensor<3x3xf32>,
230230
// CHECK-SAME: %[[ARG4:.*]]: tensor<4x7x3x2xf32>
231231
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
232232
// CHECK-DAG: %[[PV:.*]] = ub.poison : i32
233-
// CHECK-DAG: %[[CST:.*]] = arith.constant dense<3> : vector<7x2x4x3xindex>
233+
// CHECK-DAG: %[[CST:.*]] = arith.constant dense<3> : vector<4x3xindex>
234234
// CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<true> : vector<4x7x3x2xi1>
235235
// CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<4x7x3x2xf32>
236236
// CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], %[[PV]] {in_bounds = [true, true]} : tensor<4x3xi32>, vector<4x3xi32>
237237
// CHECK: %[[V1:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], %[[PV]] {in_bounds = [true, true]} : tensor<4x3xi32>, vector<4x3xi32>
238238
// CHECK: %[[CAST:.*]] = arith.index_cast %[[V0]] : vector<4x3xi32> to vector<4x3xindex>
239-
// CHECK: %[[B1:.*]] = vector.broadcast %[[CAST]] : vector<4x3xindex> to vector<7x2x4x3xindex>
240239
// CHECK: %[[CAST_1:.*]] = arith.index_cast %[[V1]] : vector<4x3xi32> to vector<4x3xindex>
241-
// CHECK: %[[B2:.*]] = vector.broadcast %[[CAST_1]] : vector<4x3xindex> to vector<7x2x4x3xindex>
242-
// CHECK: %[[MULI:.*]] = arith.muli %[[B1]], %[[CST]] : vector<7x2x4x3xindex>
243-
// CHECK: %[[ADDI:.*]] = arith.addi %[[B2]], %[[MULI]] : vector<7x2x4x3xindex>
244-
// CHECK: %[[T:.*]] = vector.transpose %[[ADDI]], [2, 0, 3, 1] : vector<7x2x4x3xindex> to vector<4x7x3x2xindex>
240+
// CHECK: %[[MULI:.*]] = arith.muli %[[CAST]], %[[CST]] : vector<4x3xindex>
241+
// CHECK: %[[ADDI:.*]] = arith.addi %[[CAST_1]], %[[MULI]] : vector<4x3xindex>
242+
// CHECK: %[[B:.*]] = vector.broadcast %[[ADDI]] : vector<4x3xindex> to vector<7x2x4x3xindex>
243+
// CHECK: %[[T:.*]] = vector.transpose %[[B]], [2, 0, 3, 1] : vector<7x2x4x3xindex> to vector<4x7x3x2xindex>
245244
// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]]] [%[[T]]], %[[CST_1]], %[[PASSTHRU]] : tensor<3x3xf32>, vector<4x7x3x2xindex>, vector<4x7x3x2xi1>, vector<4x7x3x2xf32> into vector<4x7x3x2xf32>
246245
// CHECK: vector.transfer_write %[[GATHER]], %[[ARG4]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true, true]} : vector<4x7x3x2xf32>, tensor<4x7x3x2xf32>
247246

@@ -270,20 +269,16 @@ func.func @vectorize_nd_tensor_extract_load_1d_column_vector_using_gather_load(%
270269
// CHECK-SAME: %[[ARG0:.*]]: tensor<8x128x768xf32>
271270
// CHECK-SAME: %[[ARG1:.*]]: index
272271
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
273-
// CHECK-DAG: %[[CST:.*]] = arith.constant dense<768> : vector<1x8xindex>
274-
// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<128> : vector<1x8xindex>
275272
// CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<8x1xf32>
276-
// CHECK-DAG: %[[CST_2:.*]] = arith.constant dense<true> : vector<8x1xi1>
277-
// CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
273+
// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<true> : vector<8x1xi1>
274+
// CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<[0, 98304, 196608, 294912, 393216, 491520, 589824, 688128]> : vector<8xindex>
278275
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<8x1xf32>
279-
// CHECK: %[[B1:.*]] = vector.broadcast %[[CST_3]] : vector<8xindex> to vector<1x8xindex>
280276
// CHECK: %[[ADDI_ARG1:.*]] = arith.addi %[[ARG1]], %[[ARG1]] : index
281-
// CHECK: %[[MULI_1:.*]] = arith.muli %[[B1]], %[[CST_0]] : vector<1x8xindex>
282-
// CHECK: %[[MULI_2:.*]] = arith.muli %[[MULI_1]], %[[CST]] : vector<1x8xindex>
283-
// CHECK: %[[T:.*]] = vector.transpose %[[MULI_2]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
277+
// CHECK: %[[B1:.*]] = vector.broadcast %[[CST_1]] : vector<8xindex> to vector<1x8xindex>
278+
// CHECK: %[[T:.*]] = vector.transpose %[[B1]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
284279
// CHECK: %[[B2:.*]] = vector.broadcast %[[ADDI_ARG1]] : index to vector<8x1xindex>
285280
// CHECK: %[[ADDI:.*]] = arith.addi %[[B2]], %[[T]] : vector<8x1xindex>
286-
// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[ADDI]]], %[[CST_2]], %[[PASSTHRU]] : tensor<8x128x768xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32>
281+
// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[ADDI]]], %[[CST_0]], %[[PASSTHRU]] : tensor<8x128x768xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32>
287282
// CHECK: vector.transfer_write %[[GATHER]], %[[EMPTY]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32>
288283

289284
// -----
@@ -309,15 +304,13 @@ func.func @index_from_output_column_vector_gather_load(%src: tensor<8x128xf32>)
309304

310305
// CHECK-LABEL: func.func @index_from_output_column_vector_gather_load(
311306
// CHECK-SAME: %[[SRC:.*]]: tensor<8x128xf32>) -> tensor<8x1xf32> {
312-
// CHECK: %[[C128:.*]] = arith.constant dense<128> : vector<1x8xindex>
307+
// CHECK: %[[IDX_VEC:.*]] = arith.constant dense<[0, 128, 256, 384, 512, 640, 768, 896]> : vector<8xindex>
313308
// CHECK: %[[C0:.*]] = arith.constant 0 : index
314309
// CHECK: %[[PASS_THRU:.*]] = arith.constant dense<0.000000e+00> : vector<8x1xf32>
315310
// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<8x1xi1>
316-
// CHECK: %[[IDX_VEC:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
317311
// CHECK: %[[OUT:.*]] = tensor.empty() : tensor<8x1xf32>
318312
// CHECK: %[[B:.*]] = vector.broadcast %[[IDX_VEC]] : vector<8xindex> to vector<1x8xindex>
319-
// CHECK: %[[MUL:.*]] = arith.muli %[[B]], %[[C128]] : vector<1x8xindex>
320-
// CHECK: %[[TR:.*]] = vector.transpose %[[MUL]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
313+
// CHECK: %[[TR:.*]] = vector.transpose %[[B]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
321314
// CHECK: %[[GATHER:.*]] = vector.gather %[[SRC]]{{\[}}%[[C0]], %[[C0]]] {{\[}}%[[TR]]], %[[MASK]], %[[PASS_THRU]] : tensor<8x128xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32>
322315
// CHECK: %[[RES:.*]] = vector.transfer_write %[[GATHER]], %[[OUT]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32>
323316
// CHECK: return %[[RES]] : tensor<8x1xf32>
@@ -420,12 +413,12 @@ func.func @vectorize_nd_tensor_extract_with_affine_apply_gather(%6: tensor<80x16
420413
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant dense<true> : vector<1x4xi1>
421414
// CHECK-DAG: %[[VAL_5:.*]] = arith.constant dense<0.000000e+00> : vector<1x4xf32>
422415
// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index
423-
// CHECK-DAG: %[[VAL_7:.*]] = arith.constant dense<16> : vector<1x4xindex>
416+
// CHECK-DAG: %[[VAL_7:.*]] = arith.constant dense<16> : vector<4xindex>
424417
// CHECK: %[[VAL_8:.*]] = vector.broadcast %[[VAL_1]] : index to vector<4xindex>
425418
// CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_3]] : vector<4xindex>
426-
// CHECK: %[[VAL_10:.*]] = vector.broadcast %[[VAL_9]] : vector<4xindex> to vector<1x4xindex>
427-
// CHECK: %[[VAL_11:.*]] = arith.muli %[[VAL_10]], %[[VAL_7]] : vector<1x4xindex>
428-
// CHECK: %[[VAL_12:.*]] = arith.addi %[[VAL_11]], %[[VAL_7]] : vector<1x4xindex>
419+
// CHECK: %[[VAL_10:.*]] = arith.muli %[[VAL_9]], %[[VAL_7]] : vector<4xindex>
420+
// CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_10]], %[[VAL_7]] : vector<4xindex>
421+
// CHECK: %[[VAL_12:.*]] = vector.broadcast %[[VAL_11]] : vector<4xindex> to vector<1x4xindex>
429422
// CHECK: %[[VAL_13:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {{\[}}%[[VAL_12]]], %[[VAL_4]], %[[VAL_5]] : tensor<80x16xf32>, vector<1x4xindex>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32>
430423
// CHECK: %[[VAL_14:.*]] = vector.transfer_write %[[VAL_13]], %[[VAL_2]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
431424
// CHECK: return %[[VAL_14]] : tensor<1x4xf32>
@@ -450,14 +443,12 @@ func.func @vectorize_nd_tensor_extract_with_maxsi_gather(%arg0: tensor<80x16xf32
450443
// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_with_maxsi_gather(
451444
// CHECK-SAME: %[[VAL_0:.*]]: tensor<80x16xf32>,
452445
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
453-
// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
454-
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant dense<1264> : vector<1x4xindex>
446+
// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<[1264, 1265, 1266, 1267]> : vector<4xindex>
455447
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant dense<true> : vector<1x4xi1>
456448
// CHECK-DAG: %[[VAL_5:.*]] = arith.constant dense<0.000000e+00> : vector<1x4xf32>
457449
// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index
458450
// CHECK: %[[VAL_7:.*]] = vector.broadcast %[[VAL_2]] : vector<4xindex> to vector<1x4xindex>
459-
// CHECK: %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_3]] : vector<1x4xindex>
460-
// CHECK: %[[VAL_9:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {{\[}}%[[VAL_8]]], %[[VAL_4]], %[[VAL_5]] : tensor<80x16xf32>, vector<1x4xindex>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32>
451+
// CHECK: %[[VAL_9:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {{\[}}%[[VAL_7]]], %[[VAL_4]], %[[VAL_5]] : tensor<80x16xf32>, vector<1x4xindex>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32>
461452
// CHECK: %[[VAL_10:.*]] = vector.transfer_write %[[VAL_9]], %[[VAL_1]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
462453
// CHECK: return %[[VAL_10]] : tensor<1x4xf32>
463454
// CHECK: }
@@ -519,13 +510,13 @@ func.func @vectorize_reverse_like_tensor_extract(%arg0: tensor<1x2x3xf32>, %arg1
519510
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]
520511
// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]
521512
// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]*]]
522-
// CHECK-DAG: %[[CST:.+]] = arith.constant dense<3> : vector<1x1x3xindex>
513+
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
523514
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
524515
// CHECK-DAG: %[[MASK:.*]] = arith.constant dense<true> : vector<1x1x3xi1>
525516
// CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<1x1x3xf32>
526517
// CHECK-DAG: %[[INIT_IDX:.+]] = arith.constant dense<[2, 1, 0]> : vector<3xindex>
527-
// CHECK: %[[T0:.+]] = vector.broadcast %[[ARG2]] : index to vector<1x1x3xindex>
528-
// CHECK: %[[T1:.+]] = arith.muli %[[T0]], %[[CST]] : vector<1x1x3xindex>
518+
// CHECK: %[[T0:.+]] = arith.muli %[[ARG2]], %[[C3]] : index
519+
// CHECK: %[[T1:.+]] = vector.broadcast %[[T0]] : index to vector<1x1x3xindex>
529520
// CHECK: %[[T2:.+]] = vector.broadcast %[[INIT_IDX]]
530521
// CHECK: %[[T3:.+]] = arith.addi %[[T2]], %[[T1]]
531522
// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[T3]]], %[[MASK]], %[[PASSTHRU]]

mlir/test/Dialect/Vector/vector-sink.mlir

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,70 @@ func.func @broadcast_scalar_extsi_scalable(%a : i8) -> vector<2x[4]xi32> {
257257
return %r : vector<2x[4]xi32>
258258
}
259259

260+
// -----
261+
262+
// CHECK-LABEL: func.func @broadcast_scalar_and_splat_const(
263+
// CHECK-SAME: %[[ARG_0:.*]]: index) -> vector<1x4xindex> {
264+
// CHECK: %[[NEW_CST:.*]] = arith.constant 2 : index
265+
// CHECK: %[[ADD:.*]] = arith.addi %[[ARG_0]], %[[NEW_CST]] : index
266+
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex>
267+
// CHECK: return %[[BCAST]] : vector<1x4xindex>
268+
269+
func.func @broadcast_scalar_and_splat_const(%arg0: index) -> vector<1x4xindex> {
270+
%0 = vector.broadcast %arg0 : index to vector<1x4xindex>
271+
%cst = arith.constant dense<2> : vector<1x4xindex>
272+
%2 = arith.addi %0, %cst : vector<1x4xindex>
273+
return %2 : vector<1x4xindex>
274+
}
275+
276+
// -----
277+
278+
// CHECK-LABEL: func.func @broadcast_scalar_and_splat_const_const_first(
279+
// CHECK-SAME: %[[ARG_0:.*]]: index) -> vector<1x4xindex> {
280+
// CHECK: %[[NEW_CST:.*]] = arith.constant 2 : index
281+
// CHECK: %[[SUB:.*]] = arith.subi %[[NEW_CST]], %[[ARG_0]] : index
282+
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[SUB]] : index to vector<1x4xindex>
283+
// CHECK: return %[[BCAST]] : vector<1x4xindex>
284+
285+
func.func @broadcast_scalar_and_splat_const_const_first(%arg0: index) -> vector<1x4xindex> {
286+
%0 = vector.broadcast %arg0 : index to vector<1x4xindex>
287+
%cst = arith.constant dense<2> : vector<1x4xindex>
288+
%2 = arith.subi %cst, %0 : vector<1x4xindex>
289+
return %2 : vector<1x4xindex>
290+
}
291+
292+
// -----
293+
294+
// CHECK-LABEL: func.func @broadcast_vector_and_splat_const(
295+
// CHECK-SAME: %[[ARG_0:.*]]: vector<4xf32>) -> vector<3x4xf32> {
296+
// CHECK: %[[NEW_CST:.*]] = arith.constant dense<2.000000e+00> : vector<4xf32>
297+
// CHECK: %[[ADD:.*]] = arith.mulf %[[ARG_0]], %[[NEW_CST]] : vector<4xf32>
298+
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : vector<4xf32> to vector<3x4xf32>
299+
// CHECK: return %[[BCAST]] : vector<3x4xf32>
300+
301+
func.func @broadcast_vector_and_splat_const(%arg0: vector<4xf32>) -> vector<3x4xf32> {
302+
%0 = vector.broadcast %arg0 : vector<4xf32> to vector<3x4xf32>
303+
%cst = arith.constant dense<2.000000e+00> : vector<3x4xf32>
304+
%2 = arith.mulf %0, %cst : vector<3x4xf32>
305+
return %2 : vector<3x4xf32>
306+
}
307+
308+
// -----
309+
310+
// CHECK-LABEL: func.func @negative_broadcast_with_non_splat_const(
311+
// CHECK-SAME: %[[ARG_0:.*]]: index) -> vector<1x4xindex> {
312+
// CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[ARG_0]] : index to vector<1x4xindex>
313+
// CHECK-DAG: %[[CST:.*]] = arith.constant dense<{{\[}}[0, 1, 2, 3]]> : vector<1x4xindex>
314+
// CHECK: %[[ADD:.*]] = arith.addi %[[BCAST]], %[[CST]] : vector<1x4xindex>
315+
// CHECK: return %[[ADD]] : vector<1x4xindex>
316+
317+
func.func @negative_broadcast_with_non_splat_const(%arg0: index) -> vector<1x4xindex> {
318+
%0 = vector.broadcast %arg0 : index to vector<1x4xindex>
319+
%cst = arith.constant dense<[[0, 1, 2, 3]]> : vector<1x4xindex>
320+
%2 = arith.addi %0, %cst : vector<1x4xindex>
321+
return %2 : vector<1x4xindex>
322+
}
323+
260324
//===----------------------------------------------------------------------===//
261325
// [Pattern: ReorderElementwiseOpsOnTranspose]
262326
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)