@@ -666,14 +666,69 @@ struct VectorPrintToArmSMELowering : public OpRewritePattern<vector::PrintOp> {
666666 }
667667};
668668
669+ // / Folds a MoveTileSliceToVectorOp + TransferWriteOp to a StoreTileSliceOp.
670+ // /
671+ // / BEFORE:
672+ // / ```mlir
673+ // / %slice = arm_sme.move_tile_slice_to_vector %tile[%index]
674+ // / : vector<[4]xf32> from vector<[4]x[4]xf32>
675+ // / vector.transfer_write %slice, %memref[%i, %j], %mask {in_bounds = [true]}
676+ // / : vector<[4]xf32>, memref<?x?xf32>
677+ // / ```
678+ // / AFTER:
679+ // / ```mlir
680+ // / arm_sme.store_tile_slice %tile, %index, %mask, %memref[%i, %j]
681+ // / : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
682+ // / ```
683+ struct FoldTransferWriteOfExtractTileSlice
684+ : public OpRewritePattern<vector::TransferWriteOp> {
685+ using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
686+
687+ LogicalResult matchAndRewrite (vector::TransferWriteOp writeOp,
688+ PatternRewriter &rewriter) const final {
689+ if (!isa<MemRefType>(writeOp.getSource ().getType ()))
690+ return rewriter.notifyMatchFailure (writeOp, " destination not a memref" );
691+
692+ if (writeOp.hasOutOfBoundsDim ())
693+ return rewriter.notifyMatchFailure (writeOp,
694+ " not inbounds transfer write" );
695+
696+ auto moveTileSlice =
697+ writeOp.getVector ().getDefiningOp <arm_sme::MoveTileSliceToVectorOp>();
698+ if (!moveTileSlice)
699+ return rewriter.notifyMatchFailure (
700+ writeOp, " vector to store not from MoveTileSliceToVectorOp" );
701+
702+ AffineMap map = writeOp.getPermutationMap ();
703+ if (!map.isMinorIdentity ())
704+ return rewriter.notifyMatchFailure (writeOp,
705+ " unsupported permutation map" );
706+
707+ Value mask = writeOp.getMask ();
708+ if (!mask) {
709+ auto maskType = writeOp.getVectorType ().clone (rewriter.getI1Type ());
710+ mask = rewriter.create <arith::ConstantOp>(
711+ writeOp.getLoc (), maskType, DenseElementsAttr::get (maskType, true ));
712+ }
713+
714+ rewriter.replaceOpWithNewOp <arm_sme::StoreTileSliceOp>(
715+ writeOp, moveTileSlice.getTile (), moveTileSlice.getTileSliceIndex (),
716+ mask, writeOp.getSource (), writeOp.getIndices (),
717+ moveTileSlice.getLayout ());
718+ return success ();
719+ }
720+ };
721+
669722} // namespace
670723
671724void mlir::populateVectorToArmSMEPatterns (RewritePatternSet &patterns,
672725 MLIRContext &ctx) {
673- patterns.add <BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
674- TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
675- TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
676- VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
677- VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
678- VectorPrintToArmSMELowering>(&ctx);
726+ patterns
727+ .add <BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
728+ TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
729+ TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
730+ VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
731+ VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
732+ VectorPrintToArmSMELowering, FoldTransferWriteOfExtractTileSlice>(
733+ &ctx);
679734}
0 commit comments