@@ -255,66 +255,6 @@ struct BroadcastOpToArmSMELowering
255255 }
256256};
257257
258- // / Conversion pattern for vector.splat.
259- // /
260- // / Example:
261- // /
262- // / %splat_to_tile = vector.splat %src : i32 to vector<[4]x[4]xi32>
263- // /
264- // / is converted to:
265- // /
266- // / %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
267- // / %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices
268- // / step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>)
269- // / {
270- // / %tile_update = arm_sme.insert_tile_slice
271- // / %broadcast_to_1d, %iter_tile[%tile_slice_index] :
272- // / vector<[4]xi32> into vector<[4]x[4]xi32>
273- // / scf.yield %tile_update : vector<[4]x[4]xi32>
274- // / }
275- // /
276- // / This is identical to vector.broadcast of a scalar.
277- struct SplatOpToArmSMELowering : public OpRewritePattern <vector::SplatOp> {
278- using OpRewritePattern<vector::SplatOp>::OpRewritePattern;
279-
280- LogicalResult matchAndRewrite (vector::SplatOp splatOp,
281- PatternRewriter &rewriter) const final {
282- auto tileType = splatOp.getResult ().getType ();
283- if (!tileType || !arm_sme::isValidSMETileVectorType (tileType))
284- return failure ();
285-
286- auto loc = splatOp.getLoc ();
287- auto srcType = splatOp.getOperand ().getType ();
288-
289- assert (srcType.isIntOrFloat () && " Invalid source type for vector.splat" );
290- // Avoid unused-variable warning when building without assertions.
291- (void )srcType;
292-
293- // First, broadcast the scalar to a 1-d vector.
294- VectorType tileSliceType = VectorType::Builder (tileType).dropDim (0 );
295- Value broadcastOp1D = vector::BroadcastOp::create (
296- rewriter, loc, tileSliceType, splatOp.getInput ());
297-
298- auto initTile = arm_sme::GetTileOp::create (rewriter, loc, tileType);
299-
300- auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
301- Value currentTile) {
302- auto nextTile = arm_sme::InsertTileSliceOp::create (
303- b, loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
304- return nextTile.getResult ();
305- };
306-
307- // Next, create a loop over ZA tile slices and "move" the generated 1-d
308- // vector to each slice.
309- auto forOp =
310- createLoopOverTileSlices (rewriter, loc, initTile, makeLoopBody);
311-
312- rewriter.replaceOp (splatOp, forOp.getResult (0 ));
313-
314- return success ();
315- }
316- };
317-
318258// / Conversion pattern for vector.transpose.
319259// /
320260// / Stores the input tile to memory and reloads vertically.
@@ -791,11 +731,25 @@ struct ExtractFromCreateMaskToPselLowering
791731 }
792732};
793733
734+ // Convert all `vector.splat` to `vector.broadcast`. There is a path from
735+ // `vector.broadcast` to ArmSME via another pattern.
736+ struct ConvertSplatToBroadcast : public OpRewritePattern <vector::SplatOp> {
737+ using OpRewritePattern<vector::SplatOp>::OpRewritePattern;
738+
739+ LogicalResult matchAndRewrite (vector::SplatOp splatOp,
740+ PatternRewriter &rewriter) const final {
741+
742+ rewriter.replaceOpWithNewOp <vector::BroadcastOp>(splatOp, splatOp.getType (),
743+ splatOp.getInput ());
744+ return success ();
745+ }
746+ };
747+
794748} // namespace
795749
796750void mlir::populateVectorToArmSMEPatterns (RewritePatternSet &patterns,
797751 MLIRContext &ctx) {
798- patterns.add <BroadcastOpToArmSMELowering, SplatOpToArmSMELowering ,
752+ patterns.add <BroadcastOpToArmSMELowering, ConvertSplatToBroadcast ,
799753 TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
800754 TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
801755 VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
0 commit comments