Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 44 additions & 16 deletions mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1005,40 +1005,68 @@ struct ReorderElementwiseOpsOnBroadcast final
"might be a scalar");
}

// Get the type of the lhs operand
auto *lhsBcastOrSplat = op->getOperand(0).getDefiningOp();
if (!lhsBcastOrSplat ||
!isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat))
// Get the type of the first non-constant operand
Operation *firstBroadcastOrSplat = nullptr;
for (Value operand : op->getOperands()) {
Operation *definingOp = operand.getDefiningOp();
if (!definingOp)
return failure();
if (definingOp->hasTrait<OpTrait::ConstantLike>())
continue;
if (!isa<vector::BroadcastOp, vector::SplatOp>(*definingOp))
return failure();
firstBroadcastOrSplat = definingOp;
break;
}
if (!firstBroadcastOrSplat)
return failure();
auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType();
Type firstBroadcastOrSplatType =
firstBroadcastOrSplat->getOperand(0).getType();

// Make sure that all operands are broadcast from identical types:
// * scalar (`vector.broadcast` + `vector.splat`), or
// * vector (`vector.broadcast`).
// Otherwise the re-ordering wouldn't be safe.
if (!llvm::all_of(op->getOperands(), [&lhsBcastOrSplatType](Value val) {
auto bcast = val.getDefiningOp<vector::BroadcastOp>();
if (bcast)
return (bcast.getOperand().getType() == lhsBcastOrSplatType);
auto splat = val.getDefiningOp<vector::SplatOp>();
if (splat)
return (splat.getOperand().getType() == lhsBcastOrSplatType);
return false;
})) {
if (!llvm::all_of(
op->getOperands(), [&firstBroadcastOrSplatType](Value val) {
if (auto bcastOp = val.getDefiningOp<vector::BroadcastOp>())
return (bcastOp.getOperand().getType() ==
firstBroadcastOrSplatType);
if (auto splatOp = val.getDefiningOp<vector::SplatOp>())
return (splatOp.getOperand().getType() ==
firstBroadcastOrSplatType);
SplatElementsAttr splatConst;
return matchPattern(val, m_Constant(&splatConst));
})) {
return failure();
}

// Collect the source values before broadcasting
SmallVector<Value> srcValues;
srcValues.reserve(op->getNumOperands());
for (Value operand : op->getOperands()) {
srcValues.push_back(operand.getDefiningOp()->getOperand(0));
SplatElementsAttr splatConst;
if (matchPattern(operand, m_Constant(&splatConst))) {
Attribute newConst;
if (auto shapedTy = dyn_cast<ShapedType>(firstBroadcastOrSplatType)) {
newConst = splatConst.resizeSplat(shapedTy);
} else {
newConst = splatConst.getSplatValue<Attribute>();
}
Operation *newConstOp =
operand.getDefiningOp()->getDialect()->materializeConstant(
rewriter, newConst, firstBroadcastOrSplatType,
operand.getLoc());
srcValues.push_back(newConstOp->getResult(0));
} else {
srcValues.push_back(operand.getDefiningOp()->getOperand(0));
}
}

// Create the "elementwise" Op
Operation *elementwiseOp =
rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
lhsBcastOrSplatType, op->getAttrs());
firstBroadcastOrSplatType, op->getAttrs());

// Replace the original Op with the elementwise Op
auto vectorType = op->getResultTypes()[0];
Expand Down
51 changes: 21 additions & 30 deletions mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -230,18 +230,17 @@ func.func @vectorize_nd_tensor_extract_index_from_tensor(%arg0: tensor<3x3xf32>,
// CHECK-SAME: %[[ARG4:.*]]: tensor<4x7x3x2xf32>
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[PV:.*]] = ub.poison : i32
// CHECK-DAG: %[[CST:.*]] = arith.constant dense<3> : vector<7x2x4x3xindex>
// CHECK-DAG: %[[CST:.*]] = arith.constant dense<3> : vector<4x3xindex>
// CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<true> : vector<4x7x3x2xi1>
// CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<4x7x3x2xf32>
// CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], %[[PV]] {in_bounds = [true, true]} : tensor<4x3xi32>, vector<4x3xi32>
// CHECK: %[[V1:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], %[[PV]] {in_bounds = [true, true]} : tensor<4x3xi32>, vector<4x3xi32>
// CHECK: %[[CAST:.*]] = arith.index_cast %[[V0]] : vector<4x3xi32> to vector<4x3xindex>
// CHECK: %[[B1:.*]] = vector.broadcast %[[CAST]] : vector<4x3xindex> to vector<7x2x4x3xindex>
// CHECK: %[[CAST_1:.*]] = arith.index_cast %[[V1]] : vector<4x3xi32> to vector<4x3xindex>
// CHECK: %[[B2:.*]] = vector.broadcast %[[CAST_1]] : vector<4x3xindex> to vector<7x2x4x3xindex>
// CHECK: %[[MULI:.*]] = arith.muli %[[B1]], %[[CST]] : vector<7x2x4x3xindex>
// CHECK: %[[ADDI:.*]] = arith.addi %[[B2]], %[[MULI]] : vector<7x2x4x3xindex>
// CHECK: %[[T:.*]] = vector.transpose %[[ADDI]], [2, 0, 3, 1] : vector<7x2x4x3xindex> to vector<4x7x3x2xindex>
// CHECK: %[[MULI:.*]] = arith.muli %[[CAST]], %[[CST]] : vector<4x3xindex>
// CHECK: %[[ADDI:.*]] = arith.addi %[[CAST_1]], %[[MULI]] : vector<4x3xindex>
// CHECK: %[[B:.*]] = vector.broadcast %[[ADDI]] : vector<4x3xindex> to vector<7x2x4x3xindex>
// CHECK: %[[T:.*]] = vector.transpose %[[B]], [2, 0, 3, 1] : vector<7x2x4x3xindex> to vector<4x7x3x2xindex>
// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]]] [%[[T]]], %[[CST_1]], %[[PASSTHRU]] : tensor<3x3xf32>, vector<4x7x3x2xindex>, vector<4x7x3x2xi1>, vector<4x7x3x2xf32> into vector<4x7x3x2xf32>
// CHECK: vector.transfer_write %[[GATHER]], %[[ARG4]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true, true]} : vector<4x7x3x2xf32>, tensor<4x7x3x2xf32>

Expand Down Expand Up @@ -270,20 +269,16 @@ func.func @vectorize_nd_tensor_extract_load_1d_column_vector_using_gather_load(%
// CHECK-SAME: %[[ARG0:.*]]: tensor<8x128x768xf32>
// CHECK-SAME: %[[ARG1:.*]]: index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[CST:.*]] = arith.constant dense<768> : vector<1x8xindex>
// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<128> : vector<1x8xindex>
// CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<8x1xf32>
// CHECK-DAG: %[[CST_2:.*]] = arith.constant dense<true> : vector<8x1xi1>
// CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<true> : vector<8x1xi1>
// CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<[0, 98304, 196608, 294912, 393216, 491520, 589824, 688128]> : vector<8xindex>
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<8x1xf32>
// CHECK: %[[B1:.*]] = vector.broadcast %[[CST_3]] : vector<8xindex> to vector<1x8xindex>
// CHECK: %[[ADDI_ARG1:.*]] = arith.addi %[[ARG1]], %[[ARG1]] : index
// CHECK: %[[MULI_1:.*]] = arith.muli %[[B1]], %[[CST_0]] : vector<1x8xindex>
// CHECK: %[[MULI_2:.*]] = arith.muli %[[MULI_1]], %[[CST]] : vector<1x8xindex>
// CHECK: %[[T:.*]] = vector.transpose %[[MULI_2]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
// CHECK: %[[B1:.*]] = vector.broadcast %[[CST_1]] : vector<8xindex> to vector<1x8xindex>
// CHECK: %[[T:.*]] = vector.transpose %[[B1]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
// CHECK: %[[B2:.*]] = vector.broadcast %[[ADDI_ARG1]] : index to vector<8x1xindex>
// CHECK: %[[ADDI:.*]] = arith.addi %[[B2]], %[[T]] : vector<8x1xindex>
// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[ADDI]]], %[[CST_2]], %[[PASSTHRU]] : tensor<8x128x768xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32>
// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[ADDI]]], %[[CST_0]], %[[PASSTHRU]] : tensor<8x128x768xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32>
// CHECK: vector.transfer_write %[[GATHER]], %[[EMPTY]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32>

// -----
Expand All @@ -309,15 +304,13 @@ func.func @index_from_output_column_vector_gather_load(%src: tensor<8x128xf32>)

// CHECK-LABEL: func.func @index_from_output_column_vector_gather_load(
// CHECK-SAME: %[[SRC:.*]]: tensor<8x128xf32>) -> tensor<8x1xf32> {
// CHECK: %[[C128:.*]] = arith.constant dense<128> : vector<1x8xindex>
// CHECK: %[[IDX_VEC:.*]] = arith.constant dense<[0, 128, 256, 384, 512, 640, 768, 896]> : vector<8xindex>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[PASS_THRU:.*]] = arith.constant dense<0.000000e+00> : vector<8x1xf32>
// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<8x1xi1>
// CHECK: %[[IDX_VEC:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
// CHECK: %[[OUT:.*]] = tensor.empty() : tensor<8x1xf32>
// CHECK: %[[B:.*]] = vector.broadcast %[[IDX_VEC]] : vector<8xindex> to vector<1x8xindex>
// CHECK: %[[MUL:.*]] = arith.muli %[[B]], %[[C128]] : vector<1x8xindex>
// CHECK: %[[TR:.*]] = vector.transpose %[[MUL]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
// CHECK: %[[TR:.*]] = vector.transpose %[[B]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
// CHECK: %[[GATHER:.*]] = vector.gather %[[SRC]]{{\[}}%[[C0]], %[[C0]]] {{\[}}%[[TR]]], %[[MASK]], %[[PASS_THRU]] : tensor<8x128xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32>
// CHECK: %[[RES:.*]] = vector.transfer_write %[[GATHER]], %[[OUT]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32>
// CHECK: return %[[RES]] : tensor<8x1xf32>
Expand Down Expand Up @@ -420,12 +413,12 @@ func.func @vectorize_nd_tensor_extract_with_affine_apply_gather(%6: tensor<80x16
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant dense<true> : vector<1x4xi1>
// CHECK-DAG: %[[VAL_5:.*]] = arith.constant dense<0.000000e+00> : vector<1x4xf32>
// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[VAL_7:.*]] = arith.constant dense<16> : vector<1x4xindex>
// CHECK-DAG: %[[VAL_7:.*]] = arith.constant dense<16> : vector<4xindex>
// CHECK: %[[VAL_8:.*]] = vector.broadcast %[[VAL_1]] : index to vector<4xindex>
// CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_3]] : vector<4xindex>
// CHECK: %[[VAL_10:.*]] = vector.broadcast %[[VAL_9]] : vector<4xindex> to vector<1x4xindex>
// CHECK: %[[VAL_11:.*]] = arith.muli %[[VAL_10]], %[[VAL_7]] : vector<1x4xindex>
// CHECK: %[[VAL_12:.*]] = arith.addi %[[VAL_11]], %[[VAL_7]] : vector<1x4xindex>
// CHECK: %[[VAL_10:.*]] = arith.muli %[[VAL_9]], %[[VAL_7]] : vector<4xindex>
// CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_10]], %[[VAL_7]] : vector<4xindex>
// CHECK: %[[VAL_12:.*]] = vector.broadcast %[[VAL_11]] : vector<4xindex> to vector<1x4xindex>
// CHECK: %[[VAL_13:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {{\[}}%[[VAL_12]]], %[[VAL_4]], %[[VAL_5]] : tensor<80x16xf32>, vector<1x4xindex>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32>
// CHECK: %[[VAL_14:.*]] = vector.transfer_write %[[VAL_13]], %[[VAL_2]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
// CHECK: return %[[VAL_14]] : tensor<1x4xf32>
Expand All @@ -450,14 +443,12 @@ func.func @vectorize_nd_tensor_extract_with_maxsi_gather(%arg0: tensor<80x16xf32
// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_with_maxsi_gather(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<80x16xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant dense<1264> : vector<1x4xindex>
// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<[1264, 1265, 1266, 1267]> : vector<4xindex>
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant dense<true> : vector<1x4xi1>
// CHECK-DAG: %[[VAL_5:.*]] = arith.constant dense<0.000000e+00> : vector<1x4xf32>
// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_7:.*]] = vector.broadcast %[[VAL_2]] : vector<4xindex> to vector<1x4xindex>
// CHECK: %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_3]] : vector<1x4xindex>
// CHECK: %[[VAL_9:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {{\[}}%[[VAL_8]]], %[[VAL_4]], %[[VAL_5]] : tensor<80x16xf32>, vector<1x4xindex>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32>
// CHECK: %[[VAL_9:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {{\[}}%[[VAL_7]]], %[[VAL_4]], %[[VAL_5]] : tensor<80x16xf32>, vector<1x4xindex>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32>
// CHECK: %[[VAL_10:.*]] = vector.transfer_write %[[VAL_9]], %[[VAL_1]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
// CHECK: return %[[VAL_10]] : tensor<1x4xf32>
// CHECK: }
Expand Down Expand Up @@ -519,13 +510,13 @@ func.func @vectorize_reverse_like_tensor_extract(%arg0: tensor<1x2x3xf32>, %arg1
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]
// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]
// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]*]]
// CHECK-DAG: %[[CST:.+]] = arith.constant dense<3> : vector<1x1x3xindex>
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[MASK:.*]] = arith.constant dense<true> : vector<1x1x3xi1>
// CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<1x1x3xf32>
// CHECK-DAG: %[[INIT_IDX:.+]] = arith.constant dense<[2, 1, 0]> : vector<3xindex>
// CHECK: %[[T0:.+]] = vector.broadcast %[[ARG2]] : index to vector<1x1x3xindex>
// CHECK: %[[T1:.+]] = arith.muli %[[T0]], %[[CST]] : vector<1x1x3xindex>
// CHECK: %[[T0:.+]] = arith.muli %[[ARG2]], %[[C3]] : index
// CHECK: %[[T1:.+]] = vector.broadcast %[[T0]] : index to vector<1x1x3xindex>
// CHECK: %[[T2:.+]] = vector.broadcast %[[INIT_IDX]]
// CHECK: %[[T3:.+]] = arith.addi %[[T2]], %[[T1]]
// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[T3]]], %[[MASK]], %[[PASSTHRU]]
Expand Down
64 changes: 64 additions & 0 deletions mlir/test/Dialect/Vector/vector-sink.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,70 @@ func.func @broadcast_scalar_extsi_scalable(%a : i8) -> vector<2x[4]xi32> {
return %r : vector<2x[4]xi32>
}

// -----

// CHECK-LABEL: func.func @broadcast_scalar_and_splat_const(
// CHECK-SAME: %[[ARG_0:.*]]: index) -> vector<1x4xindex> {
// CHECK: %[[NEW_CST:.*]] = arith.constant 2 : index
// CHECK: %[[ADD:.*]] = arith.addi %[[ARG_0]], %[[NEW_CST]] : index
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex>
// CHECK: return %[[BCAST]] : vector<1x4xindex>

func.func @broadcast_scalar_and_splat_const(%arg0: index) -> vector<1x4xindex> {
%0 = vector.broadcast %arg0 : index to vector<1x4xindex>
%cst = arith.constant dense<2> : vector<1x4xindex>
%2 = arith.addi %0, %cst : vector<1x4xindex>
return %2 : vector<1x4xindex>
}

// -----

// CHECK-LABEL: func.func @broadcast_scalar_and_splat_const_const_first(
// CHECK-SAME: %[[ARG_0:.*]]: index) -> vector<1x4xindex> {
// CHECK: %[[NEW_CST:.*]] = arith.constant 2 : index
// CHECK: %[[SUB:.*]] = arith.subi %[[NEW_CST]], %[[ARG_0]] : index
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[SUB]] : index to vector<1x4xindex>
// CHECK: return %[[BCAST]] : vector<1x4xindex>

func.func @broadcast_scalar_and_splat_const_const_first(%arg0: index) -> vector<1x4xindex> {
%0 = vector.broadcast %arg0 : index to vector<1x4xindex>
%cst = arith.constant dense<2> : vector<1x4xindex>
%2 = arith.subi %cst, %0 : vector<1x4xindex>
return %2 : vector<1x4xindex>
}

// -----

// CHECK-LABEL: func.func @broadcast_vector_and_splat_const(
// CHECK-SAME: %[[ARG_0:.*]]: vector<4xf32>) -> vector<3x4xf32> {
// CHECK: %[[NEW_CST:.*]] = arith.constant dense<2.000000e+00> : vector<4xf32>
// CHECK: %[[ADD:.*]] = arith.mulf %[[ARG_0]], %[[NEW_CST]] : vector<4xf32>
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : vector<4xf32> to vector<3x4xf32>
// CHECK: return %[[BCAST]] : vector<3x4xf32>

func.func @broadcast_vector_and_splat_const(%arg0: vector<4xf32>) -> vector<3x4xf32> {
%0 = vector.broadcast %arg0 : vector<4xf32> to vector<3x4xf32>
%cst = arith.constant dense<2.000000e+00> : vector<3x4xf32>
%2 = arith.mulf %0, %cst : vector<3x4xf32>
return %2 : vector<3x4xf32>
}

// -----

// CHECK-LABEL: func.func @negative_broadcast_with_non_splat_const(
// CHECK-SAME: %[[ARG_0:.*]]: index) -> vector<1x4xindex> {
// CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[ARG_0]] : index to vector<1x4xindex>
// CHECK-DAG: %[[CST:.*]] = arith.constant dense<{{\[}}[0, 1, 2, 3]]> : vector<1x4xindex>
// CHECK: %[[ADD:.*]] = arith.addi %[[BCAST]], %[[CST]] : vector<1x4xindex>
// CHECK: return %[[ADD]] : vector<1x4xindex>

func.func @negative_broadcast_with_non_splat_const(%arg0: index) -> vector<1x4xindex> {
%0 = vector.broadcast %arg0 : index to vector<1x4xindex>
%cst = arith.constant dense<[[0, 1, 2, 3]]> : vector<1x4xindex>
%2 = arith.addi %0, %cst : vector<1x4xindex>
return %2 : vector<1x4xindex>
Comment on lines +260 to +321
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit] Could you move these tests near other tests for ReorderElementwiseOpsOnBroadcast, i.e. here:

//-----------------------------------------------------------------------------
// [Pattern: ReorderElementwiseOpsOnBroadcast]
//-----------------------------------------------------------------------------
// CHECK-LABEL: func.func @broadcast_scalar_with_bcast(
// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index) -> vector<1x4xindex> {
// CHECK: %[[ADD:.*]] = arith.addi %[[ARG_0]], %[[ARG_1]] : index
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex>
// CHECK: return %[[BCAST]] : vector<1x4xindex>
func.func @broadcast_scalar_with_bcast(%arg1: index, %arg2: index) -> vector<1x4xindex> {
%0 = vector.broadcast %arg1 : index to vector<1x4xindex>
%1 = vector.broadcast %arg2 : index to vector<1x4xindex>
%2 = arith.addi %0, %1 : vector<1x4xindex>
return %2 : vector<1x4xindex>
}
// CHECK-LABEL: func.func @broadcast_scalar_with_bcast_scalable(
// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index) -> vector<1x[4]xindex> {
// CHECK: %[[ADD:.*]] = arith.addi %[[ARG_0]], %[[ARG_1]] : index
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x[4]xindex>
// CHECK: return %[[BCAST]] : vector<1x[4]xindex>
func.func @broadcast_scalar_with_bcast_scalable(%arg1: index, %arg2: index) -> vector<1x[4]xindex> {
%0 = vector.broadcast %arg1 : index to vector<1x[4]xindex>
%1 = vector.broadcast %arg2 : index to vector<1x[4]xindex>
%2 = arith.addi %0, %1 : vector<1x[4]xindex>
return %2 : vector<1x[4]xindex>
}
// -----
// CHECK-LABEL: func.func @broadcast_scalar_with_bcast_and_splat(
// CHECK-SAME: %[[ARG1:.*]]: index,
// CHECK-SAME: %[[ARG2:.*]]: index) -> vector<1x4xindex> {
// CHECK: %[[ADD:.*]] = arith.addi %[[ARG1]], %[[ARG2]] : index
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex>
// CHECK: return %[[BCAST]] : vector<1x4xindex>
func.func @broadcast_scalar_with_bcast_and_splat(%arg1: index, %arg2: index) -> vector<1x4xindex> {
%0 = vector.splat %arg1 : vector<1x4xindex>
%1 = vector.broadcast %arg2 : index to vector<1x4xindex>
%2 = arith.addi %0, %1 : vector<1x4xindex>
return %2 : vector<1x4xindex>
}
// CHECK-LABEL: func.func @broadcast_scalar_with_bcast_and_splat_scalable(
// CHECK-SAME: %[[ARG1:.*]]: index,
// CHECK-SAME: %[[ARG2:.*]]: index) -> vector<1x[4]xindex> {
// CHECK: %[[ADD:.*]] = arith.addi %[[ARG1]], %[[ARG2]] : index
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x[4]xindex>
// CHECK: return %[[BCAST]] : vector<1x[4]xindex>
func.func @broadcast_scalar_with_bcast_and_splat_scalable(%arg1: index, %arg2: index) -> vector<1x[4]xindex> {
%0 = vector.splat %arg1 : vector<1x[4]xindex>
%1 = vector.broadcast %arg2 : index to vector<1x[4]xindex>
%2 = arith.addi %0, %1 : vector<1x[4]xindex>
return %2 : vector<1x[4]xindex>
}
// -----
// CHECK-LABEL: func.func @broadcast_vector(
// CHECK-SAME: %[[ARG_0:.*]]: vector<4xf32>,
// CHECK-SAME: %[[ARG_1:.*]]: vector<4xf32>) -> vector<3x4xf32> {
// CHECK: %[[ADDF:.*]] = arith.addf %[[ARG_0]], %[[ARG_1]] : vector<4xf32>
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADDF]] : vector<4xf32> to vector<3x4xf32>
// CHECK: return %[[BCAST]] : vector<3x4xf32>
func.func @broadcast_vector(%arg1: vector<4xf32>, %arg2: vector<4xf32>) -> vector<3x4xf32> {
%arg1_bcast = vector.broadcast %arg1 : vector<4xf32> to vector<3x4xf32>
%arg2_bcast = vector.broadcast %arg2 : vector<4xf32> to vector<3x4xf32>
%2 = arith.addf %arg1_bcast, %arg2_bcast : vector<3x4xf32>
return %2 : vector<3x4xf32>
}
// CHECK-LABEL: func.func @broadcast_vector_scalable(
// CHECK-SAME: %[[ARG_0:.*]]: vector<[4]xf32>,
// CHECK-SAME: %[[ARG_1:.*]]: vector<[4]xf32>) -> vector<3x[4]xf32> {
// CHECK: %[[ADDF:.*]] = arith.addf %[[ARG_0]], %[[ARG_1]] : vector<[4]xf32>
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADDF]] : vector<[4]xf32> to vector<3x[4]xf32>
// CHECK: return %[[BCAST]] : vector<3x[4]xf32>
func.func @broadcast_vector_scalable(%arg1: vector<[4]xf32>, %arg2: vector<[4]xf32>) -> vector<3x[4]xf32> {
%arg1_bcast = vector.broadcast %arg1 : vector<[4]xf32> to vector<3x[4]xf32>
%arg2_bcast = vector.broadcast %arg2 : vector<[4]xf32> to vector<3x[4]xf32>
%2 = arith.addf %arg1_bcast, %arg2_bcast : vector<3x[4]xf32>
return %2 : vector<3x[4]xf32>
}
// -----
// CHECK-LABEL: func.func @broadcast_scalar_and_vec(
// CHECK-SAME: %[[ARG1:.*]]: index,
// CHECK-SAME: %[[ARG2:.*]]: vector<4xindex>) -> vector<1x4xindex> {
// CHECK: %[[SPLAT:.*]] = vector.splat %[[ARG1]] : vector<1x4xindex>
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ARG2]] : vector<4xindex> to vector<1x4xindex>
// CHECK: %[[ADD:.*]] = arith.addi %[[SPLAT]], %[[BCAST]] : vector<1x4xindex>
// CHECK: return %[[ADD]] : vector<1x4xindex>
func.func @broadcast_scalar_and_vec(%arg1: index, %arg2: vector<4xindex>) -> vector<1x4xindex> {
%0 = vector.splat %arg1 : vector<1x4xindex>
%1 = vector.broadcast %arg2 : vector<4xindex> to vector<1x4xindex>
%2 = arith.addi %0, %1 : vector<1x4xindex>
return %2 : vector<1x4xindex>
}
// CHECK-LABEL: func.func @broadcast_scalar_and_vec_scalable(
// CHECK-SAME: %[[ARG1:.*]]: index,
// CHECK-SAME: %[[ARG2:.*]]: vector<[4]xindex>) -> vector<1x[4]xindex> {
// CHECK: %[[SPLAT:.*]] = vector.splat %[[ARG1]] : vector<1x[4]xindex>
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ARG2]] : vector<[4]xindex> to vector<1x[4]xindex>
// CHECK: %[[ADD:.*]] = arith.addi %[[SPLAT]], %[[BCAST]] : vector<1x[4]xindex>
// CHECK: return %[[ADD]] : vector<1x[4]xindex>
func.func @broadcast_scalar_and_vec_scalable(%arg1: index, %arg2: vector<[4]xindex>) -> vector<1x[4]xindex> {
%0 = vector.splat %arg1 : vector<1x[4]xindex>
%1 = vector.broadcast %arg2 : vector<[4]xindex> to vector<1x[4]xindex>
%2 = arith.addi %0, %1 : vector<1x[4]xindex>
return %2 : vector<1x[4]xindex>
}
// -----
// CHECK-LABEL: func.func @broadcast_vector_and_scalar(
// CHECK-SAME: %[[ARG_0:.*]]: i32,
// CHECK-SAME: %[[ARG_1:.*]]: vector<4xi32>) -> vector<4xi32> {
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ARG_0]] : i32 to vector<4xi32>
// CHECK: %[[ADD:.*]] = arith.addi %[[BCAST]], %[[ARG_1]] : vector<4xi32>
// CHECK: return %[[ADD]] : vector<4xi32>
func.func @broadcast_vector_and_scalar(%arg1: i32, %arg2: vector<4xi32>) -> vector<4xi32> {
%arg1_bcast = vector.broadcast %arg1 : i32 to vector<4xi32>
%2 = arith.addi %arg1_bcast, %arg2 : vector<4xi32>
return %2 : vector<4xi32>
}
// CHECK-LABEL: func.func @broadcast_vector_and_scalar_scalable(
// CHECK-SAME: %[[ARG_0:.*]]: i32,
// CHECK-SAME: %[[ARG_1:.*]]: vector<[4]xi32>) -> vector<[4]xi32> {
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ARG_0]] : i32 to vector<[4]xi32>
// CHECK: %[[ADD:.*]] = arith.addi %[[BCAST]], %[[ARG_1]] : vector<[4]xi32>
// CHECK: return %[[ADD]] : vector<[4]xi32>
func.func @broadcast_vector_and_scalar_scalable(%arg1: i32, %arg2: vector<[4]xi32>) -> vector<[4]xi32> {
%arg1_bcast = vector.broadcast %arg1 : i32 to vector<[4]xi32>
%2 = arith.addi %arg1_bcast, %arg2 : vector<[4]xi32>
return %2 : vector<[4]xi32>
}
// -----
#matmat_accesses = [
affine_map<(i, j, k) -> (i, k)>,
affine_map<(i, j, k) -> (k, j)>,
affine_map<(i, j, k) -> (i, j)>
]
#matmat_trait = {
indexing_maps = #matmat_accesses,
iterator_types = ["parallel", "parallel", "reduction"]
}
// CHECK-LABEL: func.func @negative_not_elementwise
// CHECK-DAG: %[[F1:.*]] = arith.constant dense<1.000000e+00> : vector<2x2xf32>
// CHECK-DAG: %[[F2:.*]] = arith.constant dense<2.000000e+00> : vector<2x2xf32>
// CHECK-DAG: %[[F3:.*]] = arith.constant dense<3.000000e+00> : vector<2x2xf32>
// CHECK: %[[RES:.*]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[F1]], %[[F2]], %[[F3]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
func.func @negative_not_elementwise() -> vector<2x2xf32> {
%f1 = arith.constant 1.0: f32
%f2 = arith.constant 2.0: f32
%f3 = arith.constant 3.0: f32
%A = vector.broadcast %f1 : f32 to vector<2x2xf32>
%B = vector.broadcast %f2 : f32 to vector<2x2xf32>
%C = vector.broadcast %f3 : f32 to vector<2x2xf32>
%res = vector.contract #matmat_trait %A, %B, %C
: vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
return %res : vector<2x2xf32>
}
// -----
// The source and the result for arith.cmp have different types - not supported
// CHECK-LABEL: func.func @negative_source_and_result_mismatch
// CHECK: %[[BROADCAST:.+]] = vector.broadcast
// CHECK: %[[RETURN:.+]] = arith.cmpf uno, %[[BROADCAST]], %[[BROADCAST]]
// CHECK: return %[[RETURN]]
func.func @negative_source_and_result_mismatch(%arg0 : f32, %arg1 : vector<1xf32>) -> vector<1xi1> {
%0 = vector.broadcast %arg0 : f32 to vector<1xf32>
%1 = arith.cmpf uno, %0, %0 : vector<1xf32>
return %1 : vector<1xi1>
}
// -----
// vector.fma only supports vectors - currently it's not possible to replace this with e.g.:
// %scalar_res = vector.fma %scalar_1, %scalar2
// %vec_res = vector.broadcast %scalar_res
//
// TODO: It should be possible to support this case
// CHECK-LABEL: func.func @negative_op_only_supports_vectors
// CHECK: %[[BROADCAST:.+]] = vector.broadcast
// CHECK: %[[RESULT:.+]] = vector.fma %[[BROADCAST]]
// CHECK: return %[[RESULT]]
func.func @negative_op_only_supports_vectors(%arg0 : f32) -> vector<1xf32> {
%0 = vector.broadcast %arg0 : f32 to vector<1xf32>
%1 = vector.fma %0, %0, %0 : vector<1xf32>
return %1 : vector<1xf32>
}

Thanks and sorry for not taking a look earlier!

}

//===----------------------------------------------------------------------===//
// [Pattern: ReorderElementwiseOpsOnTranspose]
//===----------------------------------------------------------------------===//
Expand Down