-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[mlir][armsme][vector] Replace splat with broadcast #148024
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-mlir Author: James Newling (newling) Changes: Part of the deprecation of vector.splat. RFC: https://discourse.llvm.org/t/rfc-mlir-vector-deprecate-then-remove-vector-splat/87143/4 Full diff: https://github.com/llvm/llvm-project/pull/148024.diff 4 Files Affected:
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 21ea444e31821..ccef71b45e4c8 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -604,7 +604,7 @@ struct InsertTileSliceConversion
rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(),
/*scalableDims=*/{true});
- auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
+ auto allActiveMask = rewriter.create<vector::BroadcastOp>(loc, predTy, one);
// Create 'arm_sme.intr.write.(horiz|vert)' to write vector to tile slice.
switch (insertTileSliceOp.getLayout()) {
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 458628c29c6ac..d0eb7091cd279 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -324,7 +324,8 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
// Splat pad into 1-D vector matching type of tile slice.
VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
- auto pad1DOp = rewriter.create<vector::SplatOp>(loc, tileSliceType, padOp);
+ auto pad1DOp =
+ rewriter.create<vector::BroadcastOp>(loc, tileSliceType, padOp);
auto loadSlice = rewriter.create<vector::MaskedLoadOp>(
loc, tileSliceType, tileLoadOp.getBase(), memrefIndices, maskOp1D,
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index d6f9495b2567c..7c8f95c0c194f 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -255,66 +255,6 @@ struct BroadcastOpToArmSMELowering
}
};
-/// Conversion pattern for vector.splat.
-///
-/// Example:
-///
-/// %splat_to_tile = vector.splat %src : i32 to vector<[4]x[4]xi32>
-///
-/// is converted to:
-///
-/// %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
-/// %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices
-/// step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>)
-/// {
-/// %tile_update = arm_sme.insert_tile_slice
-/// %broadcast_to_1d, %iter_tile[%tile_slice_index] :
-/// vector<[4]xi32> into vector<[4]x[4]xi32>
-/// scf.yield %tile_update : vector<[4]x[4]xi32>
-/// }
-///
-/// This is identical to vector.broadcast of a scalar.
-struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
- using OpRewritePattern<vector::SplatOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(vector::SplatOp splatOp,
- PatternRewriter &rewriter) const final {
- auto tileType = splatOp.getResult().getType();
- if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
- return failure();
-
- auto loc = splatOp.getLoc();
- auto srcType = splatOp.getOperand().getType();
-
- assert(srcType.isIntOrFloat() && "Invalid source type for vector.splat");
- // Avoid unused-variable warning when building without assertions.
- (void)srcType;
-
- // First, broadcast the scalar to a 1-d vector.
- VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
- Value broadcastOp1D = rewriter.create<vector::BroadcastOp>(
- loc, tileSliceType, splatOp.getInput());
-
- auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
-
- auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
- Value currentTile) {
- auto nextTile = b.create<arm_sme::InsertTileSliceOp>(
- loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
- return nextTile.getResult();
- };
-
- // Next, create a loop over ZA tile slices and "move" the generated 1-d
- // vector to each slice.
- auto forOp =
- createLoopOverTileSlices(rewriter, loc, initTile, makeLoopBody);
-
- rewriter.replaceOp(splatOp, forOp.getResult(0));
-
- return success();
- }
-};
-
/// Conversion pattern for vector.transpose.
///
/// Stores the input tile to memory and reloads vertically.
@@ -790,11 +730,25 @@ struct ExtractFromCreateMaskToPselLowering
}
};
+// Convert all `vector.splat` to `vector.broadcast`. There is a path from
+// `vector.broadcast` to ArmSME via another pattern.
+struct ConvertSplatToBroadcast : public OpRewritePattern<vector::SplatOp> {
+ using OpRewritePattern<vector::SplatOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::SplatOp splatOp,
+ PatternRewriter &rewriter) const final {
+
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(),
+ splatOp.getInput());
+ return success();
+ }
+};
+
} // namespace
void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
MLIRContext &ctx) {
- patterns.add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
+ patterns.add<BroadcastOpToArmSMELowering, ConvertSplatToBroadcast,
TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index 4ae710aa29113..6f2766ddc6e6e 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -87,7 +87,7 @@ func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(%src : memref<?x?xi32>)
// CHECK-NEXT: %[[MASK_INDEX:.*]] = arith.index_cast %[[MASK]] : i32 to index
// CHECK-NEXT: %[[MASK_1D:.*]] = vector.create_mask %[[MASK_INDEX]] : vector<[4]xi1>
// CHECK-NEXT: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
-// CHECK: %[[PAD_1D:.*]] = vector.splat %[[PAD]] : vector<[4]xi32>
+// CHECK: %[[PAD_1D:.*]] = vector.broadcast %[[PAD]] : i32 to vector<[4]xi32>
// CHECK: %[[LOAD_SLICE:.*]] = vector.maskedload %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[MASK_1D]], %[[PAD_1D]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
// CHECK: %[[TILE_UPDATE:.*]] = arm_sme.insert_tile_slice %[[LOAD_SLICE]], %[[CURRENT_TILE]][%[[TILE_SLICE_INDEX]]] : vector<[4]xi32> into vector<[4]x[4]xi32>
// CHECK-NEXT: scf.yield %[[TILE_UPDATE]] : vector<[4]x[4]xi32>
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need to go through this conversion if we are planning to remove vector.splat?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My goal here is to remove all splat-specific logic. This new pattern will be removed when we eventually remove vector.splat.
banach-space
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks for pushing on this and apologies for missing this one earlier!
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
21f2890 to
35db575
Compare
|
@dcaballe please let me know if you'd like further changes, or want to dive deeper into the deprecation plan (there is no rush). @banach-space thx and nw! |
Part of deprecation of vector.splat. RFC: https://discourse.llvm.org/t/rfc-mlir-vector-deprecate-then-remove-vector-splat/87143/4
Part of deprecation of vector.splat
RFC: https://discourse.llvm.org/t/rfc-mlir-vector-deprecate-then-remove-vector-splat/87143/4
More complete deprecation: #147818