[mlir][Vector] Allow elementwise/broadcast swap to handle mixed types #151274
Conversation
This patch extends the pattern that rewrites elementwise operations whose inputs are all broadcast from the same shape so that it handles mixed types, such as when the result and operand element types don't match, or when the operands have different element types. PR llvm#150867 failed to check for the possibility of type mismatches when rewriting splat constants. To fix that issue, we add support for mixed-type operations more generally.
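For illustration, here is a minimal before/after sketch of the mixed-type case the updated pattern now handles (it mirrors the `broadcast_scalar_mixed_type` test added below; the function name is just illustrative):

```mlir
// Before: the elementwise op consumes a broadcast whose element type (f16)
// differs from the result element type (f32).
func.func @example(%arg0: f16) -> vector<1x4xf32> {
  %0 = vector.broadcast %arg0 : f16 to vector<1x4xf16>
  %1 = arith.extf %0 : vector<1x4xf16> to vector<1x4xf32>
  return %1 : vector<1x4xf32>
}

// After: the elementwise op is applied to the unbroadcast operand first,
// and the result is then broadcast.
func.func @example(%arg0: f16) -> vector<1x4xf32> {
  %0 = arith.extf %arg0 : f16 to f32
  %1 = vector.broadcast %0 : f32 to vector<1x4xf32>
  return %1 : vector<1x4xf32>
}
```

For splat-constant operands, the rewrite now resizes the constant using the constant's own element type at the unbroadcast shape (rather than assuming it matches the result type), which is the mismatch that #150867 missed.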
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir

Author: Krzysztof Drewniak (krzysz00)

Changes

This patch extends the pattern that rewrites elementwise operations whose inputs are all broadcast from the same shape so that it handles mixed types, such as when the result and operand element types don't match, or when the operands have different element types.

PR #150867 failed to check for the possibility of type mismatches when rewriting splat constants. To fix that issue, we add support for mixed-type operations more generally.

Full diff: https://github.com/llvm/llvm-project/pull/151274.diff

2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index c51c7b7270fae..5ade4d6c22a39 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -965,6 +965,28 @@ struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
std::function<bool(BitCastOp)> controlFn;
};
+static bool haveSameShapeAndScaling(Type t, Type u) {
+ auto tVec = dyn_cast<VectorType>(t);
+ auto uVec = dyn_cast<VectorType>(u);
+ if (!tVec) {
+ return !uVec;
+ }
+ if (!uVec) {
+ return false;
+ }
+ return tVec.getShape() == uVec.getShape() &&
+ tVec.getScalableDims() == uVec.getScalableDims();
+}
+
+/// If `type` is shaped, clone it with `newElementType`. Otherwise,
+/// return `newElementType`.
+static Type cloneOrReplace(Type type, Type newElementType) {
+ if (auto shapedType = dyn_cast<ShapedType>(type)) {
+ return shapedType.clone(newElementType);
+ }
+ return newElementType;
+}
+
/// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex:
///
/// Example:
@@ -988,16 +1010,14 @@ struct ReorderElementwiseOpsOnBroadcast final
PatternRewriter &rewriter) const override {
if (op->getNumResults() != 1)
return failure();
- if (!llvm::isa<ShapedType>(op->getResults()[0].getType()))
+ auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
+ if (!resultType)
return failure();
if (!OpTrait::hasElementwiseMappableTraits(op))
return rewriter.notifyMatchFailure(
op, "Op doesn't have ElementwiseMappableTraits");
if (op->getNumOperands() == 0)
return failure();
- if (op->getResults()[0].getType() != op->getOperand(0).getType())
- return rewriter.notifyMatchFailure(op,
- "result and operand type mismatch");
if (isa<vector::FMAOp>(op)) {
return rewriter.notifyMatchFailure(
op,
@@ -1005,6 +1025,7 @@ struct ReorderElementwiseOpsOnBroadcast final
"might be a scalar");
}
+ Type resultElemType = resultType.getElementType();
// Get the type of the first non-constant operand
Operation *firstBroadcastOrSplat = nullptr;
for (Value operand : op->getOperands()) {
@@ -1020,24 +1041,23 @@ struct ReorderElementwiseOpsOnBroadcast final
}
if (!firstBroadcastOrSplat)
return failure();
- Type firstBroadcastOrSplatType =
- firstBroadcastOrSplat->getOperand(0).getType();
+ Type unbroadcastResultType = cloneOrReplace(
+ firstBroadcastOrSplat->getOperand(0).getType(), resultElemType);
- // Make sure that all operands are broadcast from identical types:
+ // Make sure that all operands are broadcast from identically-shaped types:
// * scalar (`vector.broadcast` + `vector.splat`), or
// * vector (`vector.broadcast`).
// Otherwise the re-ordering wouldn't be safe.
- if (!llvm::all_of(
- op->getOperands(), [&firstBroadcastOrSplatType](Value val) {
- if (auto bcastOp = val.getDefiningOp<vector::BroadcastOp>())
- return (bcastOp.getOperand().getType() ==
- firstBroadcastOrSplatType);
- if (auto splatOp = val.getDefiningOp<vector::SplatOp>())
- return (splatOp.getOperand().getType() ==
- firstBroadcastOrSplatType);
- SplatElementsAttr splatConst;
- return matchPattern(val, m_Constant(&splatConst));
- })) {
+ if (!llvm::all_of(op->getOperands(), [&unbroadcastResultType](Value val) {
+ if (auto bcastOp = val.getDefiningOp<vector::BroadcastOp>())
+ return haveSameShapeAndScaling(bcastOp.getOperand().getType(),
+ unbroadcastResultType);
+ if (auto splatOp = val.getDefiningOp<vector::SplatOp>())
+ return haveSameShapeAndScaling(splatOp.getOperand().getType(),
+ unbroadcastResultType);
+ SplatElementsAttr splatConst;
+ return matchPattern(val, m_Constant(&splatConst));
+ })) {
return failure();
}
@@ -1048,15 +1068,16 @@ struct ReorderElementwiseOpsOnBroadcast final
SplatElementsAttr splatConst;
if (matchPattern(operand, m_Constant(&splatConst))) {
Attribute newConst;
- if (auto shapedTy = dyn_cast<ShapedType>(firstBroadcastOrSplatType)) {
- newConst = splatConst.resizeSplat(shapedTy);
+ Type elementType = getElementTypeOrSelf(operand.getType());
+ Type newType = cloneOrReplace(unbroadcastResultType, elementType);
+ if (auto shapedTy = dyn_cast<ShapedType>(unbroadcastResultType)) {
+ newConst = splatConst.resizeSplat(cast<ShapedType>(newType));
} else {
newConst = splatConst.getSplatValue<Attribute>();
}
Operation *newConstOp =
operand.getDefiningOp()->getDialect()->materializeConstant(
- rewriter, newConst, firstBroadcastOrSplatType,
- operand.getLoc());
+ rewriter, newConst, newType, operand.getLoc());
srcValues.push_back(newConstOp->getResult(0));
} else {
srcValues.push_back(operand.getDefiningOp()->getOperand(0));
@@ -1066,12 +1087,11 @@ struct ReorderElementwiseOpsOnBroadcast final
// Create the "elementwise" Op
Operation *elementwiseOp =
rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
- firstBroadcastOrSplatType, op->getAttrs());
+ unbroadcastResultType, op->getAttrs());
// Replace the original Op with the elementwise Op
- auto vectorType = op->getResultTypes()[0];
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
- op, vectorType, elementwiseOp->getResults());
+ op, resultType, elementwiseOp->getResults());
return success();
}
diff --git a/mlir/test/Dialect/Vector/vector-sink.mlir b/mlir/test/Dialect/Vector/vector-sink.mlir
index f8638ab843ecb..d161197e4bfe4 100644
--- a/mlir/test/Dialect/Vector/vector-sink.mlir
+++ b/mlir/test/Dialect/Vector/vector-sink.mlir
@@ -180,13 +180,14 @@ func.func @negative_not_elementwise() -> vector<2x2xf32> {
// -----
-// The source and the result for arith.cmp have different types - not supported
-
-// CHECK-LABEL: func.func @negative_source_and_result_mismatch
-// CHECK: %[[BROADCAST:.+]] = vector.broadcast
-// CHECK: %[[RETURN:.+]] = arith.cmpf uno, %[[BROADCAST]], %[[BROADCAST]]
-// CHECK: return %[[RETURN]]
-func.func @negative_source_and_result_mismatch(%arg0 : f32, %arg1 : vector<1xf32>) -> vector<1xi1> {
+// The source and the result for arith.cmp have different types
+
+// CHECK-LABEL: func.func @source_and_result_mismatch(
+// CHECK-SAME: %[[ARG0:.+]]: f32)
+// CHECK: %[[COMPARE:.+]] = arith.cmpf uno, %[[ARG0]], %[[ARG0]]
+// CHECK: %[[BROADCAST:.+]] = vector.broadcast %[[COMPARE]] : i1 to vector<1xi1>
+// CHECK: return %[[BROADCAST]]
+func.func @source_and_result_mismatch(%arg0 : f32) -> vector<1xi1> {
%0 = vector.broadcast %arg0 : f32 to vector<1xf32>
%1 = arith.cmpf uno, %0, %0 : vector<1xf32>
return %1 : vector<1xi1>
@@ -321,6 +322,66 @@ func.func @negative_broadcast_with_non_splat_const(%arg0: index) -> vector<1x4xi
return %2 : vector<1x4xindex>
}
+// -----
+
+// CHECK-LABEL: func.func @broadcast_scalar_mixed_type(
+// CHECK-SAME: %[[ARG_0:.*]]: f16) -> vector<1x4xf32> {
+// CHECK: %[[EXTF:.*]] = arith.extf %[[ARG_0]] : f16 to f32
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[EXTF]] : f32 to vector<1x4xf32>
+// CHECK: return %[[BCAST]] : vector<1x4xf32>
+
+func.func @broadcast_scalar_mixed_type(%arg0: f16) -> vector<1x4xf32> {
+ %0 = vector.broadcast %arg0 : f16 to vector<1x4xf16>
+ %1 = arith.extf %0 : vector<1x4xf16> to vector<1x4xf32>
+ return %1 : vector<1x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @broadcast_vector_mixed_type(
+// CHECK-SAME: %[[ARG_0:.*]]: vector<4xf16>) -> vector<3x4xf32> {
+// CHECK: %[[EXTF:.*]] = arith.extf %[[ARG_0]] : vector<4xf16> to vector<4xf32>
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[EXTF]] : vector<4xf32> to vector<3x4xf32>
+// CHECK: return %[[BCAST]] : vector<3x4xf32>
+
+func.func @broadcast_vector_mixed_type(%arg0: vector<4xf16>) -> vector<3x4xf32> {
+ %0 = vector.broadcast %arg0 : vector<4xf16> to vector<3x4xf16>
+ %1 = arith.extf %0 : vector<3x4xf16> to vector<3x4xf32>
+ return %1 : vector<3x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @broadcast_scalar_and_splat_const_mixed_type(
+// CHECK-SAME: %[[ARG_0:.*]]: f32) -> vector<1x4xf32> {
+// CHECK: %[[NEW_CST:.*]] = arith.constant 3 : i32
+// CHECK: %[[POW:.*]] = math.fpowi %[[ARG_0]], %[[NEW_CST]] : f32, i32
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[POW]] : f32 to vector<1x4xf32>
+// CHECK: return %[[BCAST]] : vector<1x4xf32>
+
+func.func @broadcast_scalar_and_splat_const_mixed_type(%arg0: f32) -> vector<1x4xf32> {
+ %0 = vector.broadcast %arg0 : f32 to vector<1x4xf32>
+ %cst = arith.constant dense<3> : vector<1x4xi32>
+ %2 = math.fpowi %0, %cst : vector<1x4xf32>, vector<1x4xi32>
+ return %2 : vector<1x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @broadcast_vector_and_splat_const_mixed_type(
+// CHECK-SAME: %[[ARG_0:.*]]: vector<4xf32>) -> vector<3x4xf32> {
+// CHECK: %[[NEW_CST:.*]] = arith.constant dense<3> : vector<4xi32>
+// CHECK: %[[POW:.*]] = math.fpowi %[[ARG_0]], %[[NEW_CST]] : vector<4xf32>, vector<4xi32>
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[POW]] : vector<4xf32> to vector<3x4xf32>
+// CHECK: return %[[BCAST]] : vector<3x4xf32>
+
+func.func @broadcast_vector_and_splat_const_mixed_type(%arg0: vector<4xf32>) -> vector<3x4xf32> {
+ %0 = vector.broadcast %arg0 : vector<4xf32> to vector<3x4xf32>
+ %cst = arith.constant dense<3> : vector<3x4xi32>
+ %2 = math.fpowi %0, %cst : vector<3x4xf32>, vector<3x4xi32>
+ return %2 : vector<3x4xf32>
+}
+
//===----------------------------------------------------------------------===//
// [Pattern: ReorderElementwiseOpsOnTranspose]
//===----------------------------------------------------------------------===//
Please also address #150867 (comment)
I have moved the tests to the other side of the header. Also, I think the cast/broadcast pattern is now entirely redundant (though removing it is probably a future, less urgent, PR).