Skip to content

Conversation

@newling
Copy link
Contributor

@newling newling commented Jul 10, 2025

Part of deprecation of vector.splat

RFC: https://discourse.llvm.org/t/rfc-mlir-vector-deprecate-then-remove-vector-splat/87143/4
More complete deprecation: #147818

@llvmbot
Copy link
Member

llvmbot commented Jul 10, 2025

@llvm/pr-subscribers-mlir

Author: James Newling (newling)

Changes

Part of deprecation of vector.splat

RFC: https://discourse.llvm.org/t/rfc-mlir-vector-deprecate-then-remove-vector-splat/87143/4
More complete deprecation: #147818


Full diff: https://github.com/llvm/llvm-project/pull/148024.diff

4 Files Affected:

  • (modified) mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp (+1-1)
  • (modified) mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp (+2-1)
  • (modified) mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp (+15-61)
  • (modified) mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir (+1-1)
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 21ea444e31821..ccef71b45e4c8 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -604,7 +604,7 @@ struct InsertTileSliceConversion
         rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
     auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(),
                                   /*scalableDims=*/{true});
-    auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
+    auto allActiveMask = rewriter.create<vector::BroadcastOp>(loc, predTy, one);
 
     // Create 'arm_sme.intr.write.(horiz|vert)' to write vector to tile slice.
     switch (insertTileSliceOp.getLayout()) {
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 458628c29c6ac..d0eb7091cd279 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -324,7 +324,8 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
 
     // Splat pad into 1-D vector matching type of tile slice.
     VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
-    auto pad1DOp = rewriter.create<vector::SplatOp>(loc, tileSliceType, padOp);
+    auto pad1DOp =
+        rewriter.create<vector::BroadcastOp>(loc, tileSliceType, padOp);
 
     auto loadSlice = rewriter.create<vector::MaskedLoadOp>(
         loc, tileSliceType, tileLoadOp.getBase(), memrefIndices, maskOp1D,
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index d6f9495b2567c..7c8f95c0c194f 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -255,66 +255,6 @@ struct BroadcastOpToArmSMELowering
   }
 };
 
-/// Conversion pattern for vector.splat.
-///
-/// Example:
-///
-///   %splat_to_tile = vector.splat %src : i32 to vector<[4]x[4]xi32>
-///
-/// is converted to:
-///
-///   %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
-///   %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices
-///       step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>)
-///   {
-///     %tile_update = arm_sme.insert_tile_slice
-///        %broadcast_to_1d, %iter_tile[%tile_slice_index] :
-///        vector<[4]xi32> into vector<[4]x[4]xi32>
-///     scf.yield %tile_update : vector<[4]x[4]xi32>
-///   }
-///
-/// This is identical to vector.broadcast of a scalar.
-struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
-  using OpRewritePattern<vector::SplatOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::SplatOp splatOp,
-                                PatternRewriter &rewriter) const final {
-    auto tileType = splatOp.getResult().getType();
-    if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
-      return failure();
-
-    auto loc = splatOp.getLoc();
-    auto srcType = splatOp.getOperand().getType();
-
-    assert(srcType.isIntOrFloat() && "Invalid source type for vector.splat");
-    // Avoid unused-variable warning when building without assertions.
-    (void)srcType;
-
-    // First, broadcast the scalar to a 1-d vector.
-    VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
-    Value broadcastOp1D = rewriter.create<vector::BroadcastOp>(
-        loc, tileSliceType, splatOp.getInput());
-
-    auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
-
-    auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
-                            Value currentTile) {
-      auto nextTile = b.create<arm_sme::InsertTileSliceOp>(
-          loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
-      return nextTile.getResult();
-    };
-
-    // Next, create a loop over ZA tile slices and "move" the generated 1-d
-    // vector to each slice.
-    auto forOp =
-        createLoopOverTileSlices(rewriter, loc, initTile, makeLoopBody);
-
-    rewriter.replaceOp(splatOp, forOp.getResult(0));
-
-    return success();
-  }
-};
-
 /// Conversion pattern for vector.transpose.
 ///
 /// Stores the input tile to memory and reloads vertically.
@@ -790,11 +730,25 @@ struct ExtractFromCreateMaskToPselLowering
   }
 };
 
+// Convert all `vector.splat` to `vector.broadcast`. There is a path from
+// `vector.broadcast` to ArmSME via another pattern.
+struct ConvertSplatToBroadcast : public OpRewritePattern<vector::SplatOp> {
+  using OpRewritePattern<vector::SplatOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::SplatOp splatOp,
+                                PatternRewriter &rewriter) const final {
+
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(),
+                                                     splatOp.getInput());
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
                                           MLIRContext &ctx) {
-  patterns.add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
+  patterns.add<BroadcastOpToArmSMELowering, ConvertSplatToBroadcast,
                TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
                TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
                VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index 4ae710aa29113..6f2766ddc6e6e 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -87,7 +87,7 @@ func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(%src : memref<?x?xi32>)
 // CHECK-NEXT:        %[[MASK_INDEX:.*]] = arith.index_cast %[[MASK]] : i32 to index
 // CHECK-NEXT:        %[[MASK_1D:.*]] = vector.create_mask %[[MASK_INDEX]] : vector<[4]xi1>
 // CHECK-NEXT:        %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
-// CHECK:             %[[PAD_1D:.*]] = vector.splat %[[PAD]] : vector<[4]xi32>
+// CHECK:             %[[PAD_1D:.*]] = vector.broadcast %[[PAD]] : i32 to vector<[4]xi32>
 // CHECK:             %[[LOAD_SLICE:.*]] = vector.maskedload %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[MASK_1D]], %[[PAD_1D]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
 // CHECK:             %[[TILE_UPDATE:.*]] = arm_sme.insert_tile_slice %[[LOAD_SLICE]], %[[CURRENT_TILE]][%[[TILE_SLICE_INDEX]]] : vector<[4]xi32> into vector<[4]x[4]xi32>
 // CHECK-NEXT:        scf.yield %[[TILE_UPDATE]] : vector<[4]x[4]xi32>

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to go through this conversion if we are planning to remove vector.splat?

Copy link
Contributor Author

@newling newling Jul 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My goal here is to remove all splat-specific logic. This new pattern will be removed when we eventually remove vector.splat.

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks for pushing on this and apologies for missing this one earlier!

@github-actions
Copy link

github-actions bot commented Jul 22, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@newling newling force-pushed the deprecate-vector-splat-armsme branch 2 times, most recently from 21f2890 to 35db575 Compare July 23, 2025 16:45
@newling
Copy link
Contributor Author

newling commented Jul 23, 2025

@dcaballe please let me know if you'd like further changes, or want to dive deeper into the deprecation plan (there is no rush)

@banach-space thx and nw!

@newling newling merged commit e67f323 into llvm:main Jul 23, 2025
8 checks passed
mahesh-attarde pushed a commit to mahesh-attarde/llvm-project that referenced this pull request Jul 28, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants