Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 45 additions & 25 deletions mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -965,6 +965,28 @@ struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
std::function<bool(BitCastOp)> controlFn;
};

static bool haveSameShapeAndScaling(Type t, Type u) {
auto tVec = dyn_cast<VectorType>(t);
auto uVec = dyn_cast<VectorType>(u);
if (!tVec) {
return !uVec;
}
if (!uVec) {
return false;
}
return tVec.getShape() == uVec.getShape() &&
tVec.getScalableDims() == uVec.getScalableDims();
}

/// If `type` is shaped, clone it with `newElementType`. Otherwise,
/// return `newElementType`.
static Type cloneOrReplace(Type type, Type newElementType) {
if (auto shapedType = dyn_cast<ShapedType>(type)) {
return shapedType.clone(newElementType);
}
return newElementType;
}

/// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex:
///
/// Example:
Expand All @@ -988,23 +1010,22 @@ struct ReorderElementwiseOpsOnBroadcast final
PatternRewriter &rewriter) const override {
if (op->getNumResults() != 1)
return failure();
if (!llvm::isa<ShapedType>(op->getResults()[0].getType()))
auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
if (!resultType)
return failure();
if (!OpTrait::hasElementwiseMappableTraits(op))
return rewriter.notifyMatchFailure(
op, "Op doesn't have ElementwiseMappableTraits");
if (op->getNumOperands() == 0)
return failure();
if (op->getResults()[0].getType() != op->getOperand(0).getType())
return rewriter.notifyMatchFailure(op,
"result and operand type mismatch");
if (isa<vector::FMAOp>(op)) {
return rewriter.notifyMatchFailure(
op,
"Op only accepts vector types - not supported as broadcast source "
"might be a scalar");
}

Type resultElemType = resultType.getElementType();
// Get the type of the first non-constant operand
Operation *firstBroadcastOrSplat = nullptr;
for (Value operand : op->getOperands()) {
Expand All @@ -1020,24 +1041,23 @@ struct ReorderElementwiseOpsOnBroadcast final
}
if (!firstBroadcastOrSplat)
return failure();
Type firstBroadcastOrSplatType =
firstBroadcastOrSplat->getOperand(0).getType();
Type unbroadcastResultType = cloneOrReplace(
firstBroadcastOrSplat->getOperand(0).getType(), resultElemType);

// Make sure that all operands are broadcast from identical types:
// Make sure that all operands are broadcast from identically-shaped types:
// * scalar (`vector.broadcast` + `vector.splat`), or
// * vector (`vector.broadcast`).
// Otherwise the re-ordering wouldn't be safe.
if (!llvm::all_of(
op->getOperands(), [&firstBroadcastOrSplatType](Value val) {
if (auto bcastOp = val.getDefiningOp<vector::BroadcastOp>())
return (bcastOp.getOperand().getType() ==
firstBroadcastOrSplatType);
if (auto splatOp = val.getDefiningOp<vector::SplatOp>())
return (splatOp.getOperand().getType() ==
firstBroadcastOrSplatType);
SplatElementsAttr splatConst;
return matchPattern(val, m_Constant(&splatConst));
})) {
if (!llvm::all_of(op->getOperands(), [&unbroadcastResultType](Value val) {
if (auto bcastOp = val.getDefiningOp<vector::BroadcastOp>())
return haveSameShapeAndScaling(bcastOp.getOperand().getType(),
unbroadcastResultType);
if (auto splatOp = val.getDefiningOp<vector::SplatOp>())
return haveSameShapeAndScaling(splatOp.getOperand().getType(),
unbroadcastResultType);
SplatElementsAttr splatConst;
return matchPattern(val, m_Constant(&splatConst));
})) {
return failure();
}

Expand All @@ -1048,15 +1068,16 @@ struct ReorderElementwiseOpsOnBroadcast final
SplatElementsAttr splatConst;
if (matchPattern(operand, m_Constant(&splatConst))) {
Attribute newConst;
if (auto shapedTy = dyn_cast<ShapedType>(firstBroadcastOrSplatType)) {
newConst = splatConst.resizeSplat(shapedTy);
Type elementType = getElementTypeOrSelf(operand.getType());
Type newType = cloneOrReplace(unbroadcastResultType, elementType);
if (auto shapedTy = dyn_cast<ShapedType>(unbroadcastResultType)) {
newConst = splatConst.resizeSplat(cast<ShapedType>(newType));
} else {
newConst = splatConst.getSplatValue<Attribute>();
}
Operation *newConstOp =
operand.getDefiningOp()->getDialect()->materializeConstant(
rewriter, newConst, firstBroadcastOrSplatType,
operand.getLoc());
rewriter, newConst, newType, operand.getLoc());
srcValues.push_back(newConstOp->getResult(0));
} else {
srcValues.push_back(operand.getDefiningOp()->getOperand(0));
Expand All @@ -1066,12 +1087,11 @@ struct ReorderElementwiseOpsOnBroadcast final
// Create the "elementwise" Op
Operation *elementwiseOp =
rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
firstBroadcastOrSplatType, op->getAttrs());
unbroadcastResultType, op->getAttrs());

// Replace the original Op with the elementwise Op
auto vectorType = op->getResultTypes()[0];
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
op, vectorType, elementwiseOp->getResults());
op, resultType, elementwiseOp->getResults());

return success();
}
Expand Down
75 changes: 68 additions & 7 deletions mlir/test/Dialect/Vector/vector-sink.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -180,13 +180,14 @@ func.func @negative_not_elementwise() -> vector<2x2xf32> {

// -----

// The source and the result for arith.cmp have different types - not supported

// CHECK-LABEL: func.func @negative_source_and_result_mismatch
// CHECK: %[[BROADCAST:.+]] = vector.broadcast
// CHECK: %[[RETURN:.+]] = arith.cmpf uno, %[[BROADCAST]], %[[BROADCAST]]
// CHECK: return %[[RETURN]]
func.func @negative_source_and_result_mismatch(%arg0 : f32, %arg1 : vector<1xf32>) -> vector<1xi1> {
// The source and the result for arith.cmp have different types

// CHECK-LABEL: func.func @source_and_result_mismatch(
// CHECK-SAME: %[[ARG0:.+]]: f32)
// CHECK: %[[COMPARE:.+]] = arith.cmpf uno, %[[ARG0]], %[[ARG0]]
// CHECK: %[[BROADCAST:.+]] = vector.broadcast %[[COMPARE]] : i1 to vector<1xi1>
// CHECK: return %[[BROADCAST]]
func.func @source_and_result_mismatch(%arg0 : f32) -> vector<1xi1> {
%0 = vector.broadcast %arg0 : f32 to vector<1xf32>
%1 = arith.cmpf uno, %0, %0 : vector<1xf32>
return %1 : vector<1xi1>
Expand Down Expand Up @@ -321,6 +322,66 @@ func.func @negative_broadcast_with_non_splat_const(%arg0: index) -> vector<1x4xi
return %2 : vector<1x4xindex>
}

// -----

// CHECK-LABEL: func.func @broadcast_scalar_mixed_type(
// CHECK-SAME: %[[ARG_0:.*]]: f16) -> vector<1x4xf32> {
// CHECK: %[[EXTF:.*]] = arith.extf %[[ARG_0]] : f16 to f32
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[EXTF]] : f32 to vector<1x4xf32>
// CHECK: return %[[BCAST]] : vector<1x4xf32>

func.func @broadcast_scalar_mixed_type(%arg0: f16) -> vector<1x4xf32> {
%0 = vector.broadcast %arg0 : f16 to vector<1x4xf16>
%1 = arith.extf %0 : vector<1x4xf16> to vector<1x4xf32>
return %1 : vector<1x4xf32>
}

// -----

// CHECK-LABEL: func.func @broadcast_vector_mixed_type(
// CHECK-SAME: %[[ARG_0:.*]]: vector<4xf16>) -> vector<3x4xf32> {
// CHECK: %[[EXTF:.*]] = arith.extf %[[ARG_0]] : vector<4xf16> to vector<4xf32>
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[EXTF]] : vector<4xf32> to vector<3x4xf32>
// CHECK: return %[[BCAST]] : vector<3x4xf32>

func.func @broadcast_vector_mixed_type(%arg0: vector<4xf16>) -> vector<3x4xf32> {
%0 = vector.broadcast %arg0 : vector<4xf16> to vector<3x4xf16>
%1 = arith.extf %0 : vector<3x4xf16> to vector<3x4xf32>
return %1 : vector<3x4xf32>
}

// -----

// CHECK-LABEL: func.func @broadcast_scalar_and_splat_const_mixed_type(
// CHECK-SAME: %[[ARG_0:.*]]: f32) -> vector<1x4xf32> {
// CHECK: %[[NEW_CST:.*]] = arith.constant 3 : i32
// CHECK: %[[POW:.*]] = math.fpowi %[[ARG_0]], %[[NEW_CST]] : f32, i32
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[POW]] : f32 to vector<1x4xf32>
// CHECK: return %[[BCAST]] : vector<1x4xf32>

func.func @broadcast_scalar_and_splat_const_mixed_type(%arg0: f32) -> vector<1x4xf32> {
%0 = vector.broadcast %arg0 : f32 to vector<1x4xf32>
%cst = arith.constant dense<3> : vector<1x4xi32>
%2 = math.fpowi %0, %cst : vector<1x4xf32>, vector<1x4xi32>
return %2 : vector<1x4xf32>
}

// -----

// CHECK-LABEL: func.func @broadcast_vector_and_splat_const_mixed_type(
// CHECK-SAME: %[[ARG_0:.*]]: vector<4xf32>) -> vector<3x4xf32> {
// CHECK: %[[NEW_CST:.*]] = arith.constant dense<3> : vector<4xi32>
// CHECK: %[[POW:.*]] = math.fpowi %[[ARG_0]], %[[NEW_CST]] : vector<4xf32>, vector<4xi32>
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[POW]] : vector<4xf32> to vector<3x4xf32>
// CHECK: return %[[BCAST]] : vector<3x4xf32>

func.func @broadcast_vector_and_splat_const_mixed_type(%arg0: vector<4xf32>) -> vector<3x4xf32> {
%0 = vector.broadcast %arg0 : vector<4xf32> to vector<3x4xf32>
%cst = arith.constant dense<3> : vector<3x4xi32>
%2 = math.fpowi %0, %cst : vector<3x4xf32>, vector<3x4xi32>
return %2 : vector<3x4xf32>
}

//===----------------------------------------------------------------------===//
// [Pattern: ReorderElementwiseOpsOnTranspose]
//===----------------------------------------------------------------------===//
Expand Down