Skip to content

Conversation

@krzysz00
Copy link
Contributor

This patch extends the operation that rewrites elementwise operations whose inputs are all broadcast from the same shape to handle mixed-types, such as when the result and input types don't match, or when the inputs have multiple types.

PR #150867 failed to check for the possibility of type mismatches when rewriting splat constants. In order to fix that issue, we add support for mixed-type operations more generally.

This patch extends the operation that rewrites elementwise operations
whose inputs are all broadcast from the same shape to handle
mixed-types, such as when the result and input types don't match, or
when the inputs have multiple types.

PR llvm#150867 failed to check for the possibility of type mismatches when
rewriting splat constants. In order to fix that issue, we add support
for mixed-type operations more generally.
@llvmbot
Copy link
Member

llvmbot commented Jul 30, 2025

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Krzysztof Drewniak (krzysz00)

Changes

This patch extends the operation that rewrites elementwise operations whose inputs are all broadcast from the same shape to handle mixed-types, such as when the result and input types don't match, or when the inputs have multiple types.

PR #150867 failed to check for the possibility of type mismatches when rewriting splat constants. In order to fix that issue, we add support for mixed-type operations more generally.


Full diff: https://github.com/llvm/llvm-project/pull/151274.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+45-25)
  • (modified) mlir/test/Dialect/Vector/vector-sink.mlir (+68-7)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index c51c7b7270fae..5ade4d6c22a39 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -965,6 +965,28 @@ struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
   std::function<bool(BitCastOp)> controlFn;
 };
 
+static bool haveSameShapeAndScaling(Type t, Type u) {
+  auto tVec = dyn_cast<VectorType>(t);
+  auto uVec = dyn_cast<VectorType>(u);
+  if (!tVec) {
+    return !uVec;
+  }
+  if (!uVec) {
+    return false;
+  }
+  return tVec.getShape() == uVec.getShape() &&
+         tVec.getScalableDims() == uVec.getScalableDims();
+}
+
+/// If `type` is shaped, clone it with `newElementType`. Otherwise,
+/// return `newElementType`.
+static Type cloneOrReplace(Type type, Type newElementType) {
+  if (auto shapedType = dyn_cast<ShapedType>(type)) {
+    return shapedType.clone(newElementType);
+  }
+  return newElementType;
+}
+
 /// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex:
 ///
 /// Example:
@@ -988,16 +1010,14 @@ struct ReorderElementwiseOpsOnBroadcast final
                                 PatternRewriter &rewriter) const override {
     if (op->getNumResults() != 1)
       return failure();
-    if (!llvm::isa<ShapedType>(op->getResults()[0].getType()))
+    auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
+    if (!resultType)
       return failure();
     if (!OpTrait::hasElementwiseMappableTraits(op))
       return rewriter.notifyMatchFailure(
           op, "Op doesn't have ElementwiseMappableTraits");
     if (op->getNumOperands() == 0)
       return failure();
-    if (op->getResults()[0].getType() != op->getOperand(0).getType())
-      return rewriter.notifyMatchFailure(op,
-                                         "result and operand type mismatch");
     if (isa<vector::FMAOp>(op)) {
       return rewriter.notifyMatchFailure(
           op,
@@ -1005,6 +1025,7 @@ struct ReorderElementwiseOpsOnBroadcast final
           "might be a scalar");
     }
 
+    Type resultElemType = resultType.getElementType();
     // Get the type of the first non-constant operand
     Operation *firstBroadcastOrSplat = nullptr;
     for (Value operand : op->getOperands()) {
@@ -1020,24 +1041,23 @@ struct ReorderElementwiseOpsOnBroadcast final
     }
     if (!firstBroadcastOrSplat)
       return failure();
-    Type firstBroadcastOrSplatType =
-        firstBroadcastOrSplat->getOperand(0).getType();
+    Type unbroadcastResultType = cloneOrReplace(
+        firstBroadcastOrSplat->getOperand(0).getType(), resultElemType);
 
-    // Make sure that all operands are broadcast from identical types:
+    // Make sure that all operands are broadcast from identically-shaped types:
     //  * scalar (`vector.broadcast` + `vector.splat`), or
     //  * vector (`vector.broadcast`).
     // Otherwise the re-ordering wouldn't be safe.
-    if (!llvm::all_of(
-            op->getOperands(), [&firstBroadcastOrSplatType](Value val) {
-              if (auto bcastOp = val.getDefiningOp<vector::BroadcastOp>())
-                return (bcastOp.getOperand().getType() ==
-                        firstBroadcastOrSplatType);
-              if (auto splatOp = val.getDefiningOp<vector::SplatOp>())
-                return (splatOp.getOperand().getType() ==
-                        firstBroadcastOrSplatType);
-              SplatElementsAttr splatConst;
-              return matchPattern(val, m_Constant(&splatConst));
-            })) {
+    if (!llvm::all_of(op->getOperands(), [&unbroadcastResultType](Value val) {
+          if (auto bcastOp = val.getDefiningOp<vector::BroadcastOp>())
+            return haveSameShapeAndScaling(bcastOp.getOperand().getType(),
+                                           unbroadcastResultType);
+          if (auto splatOp = val.getDefiningOp<vector::SplatOp>())
+            return haveSameShapeAndScaling(splatOp.getOperand().getType(),
+                                           unbroadcastResultType);
+          SplatElementsAttr splatConst;
+          return matchPattern(val, m_Constant(&splatConst));
+        })) {
       return failure();
     }
 
@@ -1048,15 +1068,16 @@ struct ReorderElementwiseOpsOnBroadcast final
       SplatElementsAttr splatConst;
       if (matchPattern(operand, m_Constant(&splatConst))) {
         Attribute newConst;
-        if (auto shapedTy = dyn_cast<ShapedType>(firstBroadcastOrSplatType)) {
-          newConst = splatConst.resizeSplat(shapedTy);
+        Type elementType = getElementTypeOrSelf(operand.getType());
+        Type newType = cloneOrReplace(unbroadcastResultType, elementType);
+        if (auto shapedTy = dyn_cast<ShapedType>(unbroadcastResultType)) {
+          newConst = splatConst.resizeSplat(cast<ShapedType>(newType));
         } else {
           newConst = splatConst.getSplatValue<Attribute>();
         }
         Operation *newConstOp =
             operand.getDefiningOp()->getDialect()->materializeConstant(
-                rewriter, newConst, firstBroadcastOrSplatType,
-                operand.getLoc());
+                rewriter, newConst, newType, operand.getLoc());
         srcValues.push_back(newConstOp->getResult(0));
       } else {
         srcValues.push_back(operand.getDefiningOp()->getOperand(0));
@@ -1066,12 +1087,11 @@ struct ReorderElementwiseOpsOnBroadcast final
     // Create the "elementwise" Op
     Operation *elementwiseOp =
         rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
-                        firstBroadcastOrSplatType, op->getAttrs());
+                        unbroadcastResultType, op->getAttrs());
 
     // Replace the original Op with the elementwise Op
-    auto vectorType = op->getResultTypes()[0];
     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
-        op, vectorType, elementwiseOp->getResults());
+        op, resultType, elementwiseOp->getResults());
 
     return success();
   }
diff --git a/mlir/test/Dialect/Vector/vector-sink.mlir b/mlir/test/Dialect/Vector/vector-sink.mlir
index f8638ab843ecb..d161197e4bfe4 100644
--- a/mlir/test/Dialect/Vector/vector-sink.mlir
+++ b/mlir/test/Dialect/Vector/vector-sink.mlir
@@ -180,13 +180,14 @@ func.func @negative_not_elementwise() -> vector<2x2xf32> {
 
 // -----
 
-// The source and the result for arith.cmp have different types - not supported
-
-// CHECK-LABEL: func.func @negative_source_and_result_mismatch
-//       CHECK:   %[[BROADCAST:.+]] = vector.broadcast
-//       CHECK:   %[[RETURN:.+]] = arith.cmpf uno, %[[BROADCAST]], %[[BROADCAST]]
-//       CHECK:   return %[[RETURN]]
-func.func @negative_source_and_result_mismatch(%arg0 : f32, %arg1 : vector<1xf32>) -> vector<1xi1> {
+// The source and the result for arith.cmp have different types
+
+// CHECK-LABEL: func.func @source_and_result_mismatch(
+//  CHECK-SAME: %[[ARG0:.+]]: f32)
+//       CHECK:   %[[COMPARE:.+]] = arith.cmpf uno, %[[ARG0]], %[[ARG0]]
+//       CHECK:   %[[BROADCAST:.+]] = vector.broadcast %[[COMPARE]] : i1 to vector<1xi1>
+//       CHECK:   return %[[BROADCAST]]
+func.func @source_and_result_mismatch(%arg0 : f32) -> vector<1xi1> {
   %0 = vector.broadcast %arg0 : f32 to vector<1xf32>
   %1 = arith.cmpf uno, %0, %0 : vector<1xf32>
   return %1 : vector<1xi1>
@@ -321,6 +322,66 @@ func.func @negative_broadcast_with_non_splat_const(%arg0: index) -> vector<1x4xi
   return %2 : vector<1x4xindex>
 }
 
+// -----
+
+// CHECK-LABEL:   func.func @broadcast_scalar_mixed_type(
+// CHECK-SAME:     %[[ARG_0:.*]]: f16) -> vector<1x4xf32> {
+// CHECK:           %[[EXTF:.*]] = arith.extf %[[ARG_0]] : f16 to f32
+// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[EXTF]] : f32 to vector<1x4xf32>
+// CHECK:           return %[[BCAST]] : vector<1x4xf32>
+
+func.func @broadcast_scalar_mixed_type(%arg0: f16) -> vector<1x4xf32> {
+  %0 = vector.broadcast %arg0 : f16 to vector<1x4xf16>
+  %1 = arith.extf %0 : vector<1x4xf16> to vector<1x4xf32>
+  return %1 : vector<1x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @broadcast_vector_mixed_type(
+// CHECK-SAME:     %[[ARG_0:.*]]: vector<4xf16>) -> vector<3x4xf32> {
+// CHECK:           %[[EXTF:.*]] = arith.extf %[[ARG_0]] : vector<4xf16> to vector<4xf32>
+// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[EXTF]] : vector<4xf32> to vector<3x4xf32>
+// CHECK:           return %[[BCAST]] : vector<3x4xf32>
+
+func.func @broadcast_vector_mixed_type(%arg0: vector<4xf16>) -> vector<3x4xf32> {
+  %0 = vector.broadcast %arg0 : vector<4xf16> to vector<3x4xf16>
+  %1 = arith.extf %0 : vector<3x4xf16> to vector<3x4xf32>
+  return %1 : vector<3x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @broadcast_scalar_and_splat_const_mixed_type(
+// CHECK-SAME:     %[[ARG_0:.*]]: f32) -> vector<1x4xf32> {
+// CHECK:           %[[NEW_CST:.*]] = arith.constant 3 : i32
+// CHECK:           %[[POW:.*]] = math.fpowi %[[ARG_0]], %[[NEW_CST]] : f32, i32
+// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[POW]] : f32 to vector<1x4xf32>
+// CHECK:           return %[[BCAST]] : vector<1x4xf32>
+
+func.func @broadcast_scalar_and_splat_const_mixed_type(%arg0: f32) -> vector<1x4xf32> {
+  %0 = vector.broadcast %arg0 : f32 to vector<1x4xf32>
+  %cst = arith.constant dense<3> : vector<1x4xi32>
+  %2 = math.fpowi %0, %cst : vector<1x4xf32>, vector<1x4xi32>
+  return %2 : vector<1x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @broadcast_vector_and_splat_const_mixed_type(
+// CHECK-SAME:     %[[ARG_0:.*]]: vector<4xf32>) -> vector<3x4xf32> {
+// CHECK:           %[[NEW_CST:.*]] = arith.constant dense<3> : vector<4xi32>
+// CHECK:           %[[POW:.*]] = math.fpowi %[[ARG_0]], %[[NEW_CST]] : vector<4xf32>, vector<4xi32>
+// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[POW]] : vector<4xf32> to vector<3x4xf32>
+// CHECK:           return %[[BCAST]] : vector<3x4xf32>
+
+func.func @broadcast_vector_and_splat_const_mixed_type(%arg0: vector<4xf32>) -> vector<3x4xf32> {
+  %0 = vector.broadcast %arg0 : vector<4xf32> to vector<3x4xf32>
+  %cst = arith.constant dense<3> : vector<3x4xi32>
+  %2 = math.fpowi %0, %cst : vector<3x4xf32>, vector<3x4xi32>
+  return %2 : vector<3x4xf32>
+}
+
 //===----------------------------------------------------------------------===//
 // [Pattern: ReorderElementwiseOpsOnTranspose]
 //===----------------------------------------------------------------------===//

@banach-space
Copy link
Contributor

Please also address #150867 (comment)

@krzysz00
Copy link
Contributor Author

Please also address #150867 (comment)

Have moved the tests to the other side of the header.

Also, I think (this is probably a future, less urgent, PR) that the cast/broadcast pattern is entirely redundant now

@krzysz00 krzysz00 merged commit 9a46091 into llvm:main Jul 30, 2025
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants