@@ -724,59 +724,6 @@ struct LiftIllegalVectorTransposeToMemory
724724 }
725725};
726726
727- // / A rewrite to turn unit dim transpose-like vector.shape_casts into
728- // / vector.transposes. The shape_cast has to be from an illegal vector type to a
729- // / legal one (as defined by isLegalVectorType).
730- // /
731- // / The reasoning for this is if we've got to this pass and we still have
732- // / shape_casts of illegal types, then they likely will not cancel out. Turning
733- // / them into transposes gives LiftIllegalVectorTransposeToMemory a chance to
734- // / eliminate them.
735- // /
736- // / Example:
737- // /
738- // / BEFORE:
739- // / ```mlir
740- // / %0 = vector.shape_cast %a : vector<[4]x1xf32> to vector<1x[4]xf32>
741- // / ```
742- // /
743- // / AFTER:
744- // / ```mlir
745- // / %0 = vector.transpose %0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
746- // / ```
747- struct ConvertIllegalShapeCastOpsToTransposes
748- : public OpRewritePattern<vector::ShapeCastOp> {
749- using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
750-
751- LogicalResult matchAndRewrite (vector::ShapeCastOp shapeCastOp,
752- PatternRewriter &rewriter) const override {
753- auto sourceType = shapeCastOp.getSourceVectorType ();
754- auto resultType = shapeCastOp.getResultVectorType ();
755- if (isLegalVectorType (sourceType) || !isLegalVectorType (resultType))
756- return rewriter.notifyMatchFailure (shapeCastOp,
757- kMatchFailureNotIllegalToLegal );
758-
759- // Note: If we know that `sourceType` is an illegal vector type (and 2D)
760- // then dim 0 is scalable and dim 1 is fixed.
761- if (sourceType.getRank () != 2 || sourceType.getDimSize (1 ) != 1 )
762- return rewriter.notifyMatchFailure (
763- shapeCastOp, " expected source to be a 2D scalable vector with a "
764- " trailing unit dim" );
765-
766- auto loc = shapeCastOp.getLoc ();
767- auto transpose = rewriter.create <vector::TransposeOp>(
768- loc, shapeCastOp.getSource (), ArrayRef<int64_t >{1 , 0 });
769-
770- if (resultType.getRank () == 1 )
771- rewriter.replaceOpWithNewOp <vector::ShapeCastOp>(shapeCastOp, resultType,
772- transpose);
773- else
774- rewriter.replaceOp (shapeCastOp, transpose);
775-
776- return success ();
777- }
778- };
779-
780727// / Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead use
781728// / the ZA state. This workaround rewrite to support these transposes when ZA is
782729// / available.
@@ -920,6 +867,116 @@ struct LowerIllegalTransposeStoreViaZA
920867 }
921868};
922869
870+ // / Lower `vector.transfer_read` of a scalable column to `scf::for`
871+ // /
872+ // / Lowers a "read" of a scalable column from a MemRef, for which there is no
873+ // / hardware operation that we could use, to a loop over the rows that loads
874+ // / one element at a time.
875+ // /
876+ // / BEFORE:
877+ // / ```
878+ // / %res = vector.transfer_read %mem[%a, %b] (...)
879+ // / : memref<?x?xf32>, vector<[4]x1xf32>
880+ // / ```
881+ // /
882+ // / AFTER:
883+ // / ```
884+ // / %cst = arith.constant (...) : vector<[4]xf32>
885+ // / %vscale = vector.vscale
886+ // / %c4_vscale = arith.muli %vscale, %c4 : index
887+ // / %scf = scf.for %arg3 = %c0 to %c4_vscale step %c1 iter_args(%arg4 = %cst)
888+ // / -> (vector<[4]xf32>) {
889+ // /
890+ // / %load = memref.load %mem[%arg3 + %a, %b] : memref<?x?xf32>
891+ // / %vec = vector.insert %load, %arg4 [%arg3] : f32 into vector<[4]xf32>
892+ // / scf.yield %vec : vector<[4]xf32>
893+ // / }
894+ // / %res = vector.shape_cast %scf : vector<[4]xf32> to vector<[4]x1xf32>
895+ // / ```
896+ // /
897+ // / TODO: This transformation isn't specific to SME - move it to the SVE
898+ // / dialect.
899+ // / TODO: Check the in_bounds attribute and generate vector.maskedload if
900+ // / required.
901+ struct LowerColumnTransferReadToLoops
902+ : public OpRewritePattern<vector::TransferReadOp> {
903+ using OpRewritePattern::OpRewritePattern;
904+
905+ LogicalResult matchAndRewrite (vector::TransferReadOp readOp,
906+ PatternRewriter &rewriter) const override {
907+ // NOTE: This is a fairly low-level transformation, so we shouldn't be
908+ // adding support for Tensors without good rationale.
909+ if (readOp.hasPureTensorSemantics ())
910+ return rewriter.notifyMatchFailure (
911+ readOp, " Tensor semantics are unsupported (either bufferize or "
912+ " extend this pattern)" );
913+
914+ auto resType = readOp.getVectorType ();
915+
916+ if (resType.getRank () != 2 )
917+ return rewriter.notifyMatchFailure (readOp,
918+ " Only 2D vectors are supported!" );
919+
920+ if (resType.getShape ()[1 ] != 1 )
921+ return rewriter.notifyMatchFailure (
922+ readOp, " The trailing output dim is != 1 (not supported ATM)" );
923+
924+ if (!resType.getScalableDims ()[0 ] || resType.getScalableDims ()[1 ])
925+ return rewriter.notifyMatchFailure (
926+ readOp, " Expected the leading dim to be scalable and the trailing "
927+ " dim to be fixed." );
928+
929+ // Create new result type - similar to the original vector with the
930+ // trailing unit dim collapsed.
931+ int64_t numRows = resType.getShape ()[0 ];
932+ VectorType newResType = VectorType::get (numRows, resType.getElementType (),
933+ /* scalableDims=*/ {true });
934+
935+ // Loop over all rows loading one element at a time; `init` is a typed zero.
936+ auto loc = readOp.getLoc ();
937+ auto lowerBound = rewriter.create <arith::ConstantIndexOp>(loc, 0 );
938+ auto createVscaleMultiple =
939+ vector::makeVscaleConstantBuilder (rewriter, loc);
940+ auto upperBound = createVscaleMultiple (numRows);
941+ auto step = rewriter.create <arith::ConstantIndexOp>(loc, 1 );
942+ Value init = rewriter.create <arith::ConstantOp>(
943+ loc, newResType, rewriter.getZeroAttr (newResType));
944+
945+ scf::ForOp loadLoop;
946+ {
947+ OpBuilder::InsertionGuard g (rewriter);
948+ loadLoop = rewriter.create <scf::ForOp>(loc, lowerBound, upperBound, step,
949+ ValueRange{init});
950+ rewriter.setInsertionPointToStart (loadLoop.getBody ());
951+
952+ auto tileSliceIndex = loadLoop.getInductionVar ();
953+
954+ auto idx0 = rewriter.create <arith::AddIOp>(loc, tileSliceIndex,
955+ readOp.getIndices ()[0 ]);
956+ auto idx1 = readOp.getIndices ()[1 ];
957+
958+ Value scalar = rewriter.create <memref::LoadOp>(
959+ loc, readOp.getBase (), SmallVector<Value>({idx0, idx1}));
960+
961+ Operation *updateInit = rewriter.create <vector::InsertOp>(
962+ loc, scalar, loadLoop.getRegionIterArg (0 ), tileSliceIndex);
963+
964+ rewriter.create <scf::YieldOp>(loc, updateInit->getResult (0 ));
965+ }
966+
967+ // The read operation has been "legalized", but since the original result
968+ // type was a 2D vector, we need to cast before returning the result. This
969+ // ShapeCast should cancel-out with some other ShapeCast (i.e. it's a
970+ // no-op).
971+ auto sc = rewriter.create <vector::ShapeCastOp>(
972+ loc, readOp.getResult ().getType (), loadLoop.getResult (0 ));
973+
974+ rewriter.replaceOp (readOp, sc);
975+
976+ return success ();
977+ }
978+ };
979+
923980struct VectorLegalizationPass
924981 : public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
925982 void runOnOperation () override {
@@ -941,10 +998,10 @@ struct VectorLegalizationPass
941998
942999 // Apply preprocessing patterns.
9431000 RewritePatternSet rewritePatterns (context);
944- rewritePatterns. add <FoldExtractFromVectorOfSMELikeCreateMasks,
945- LiftIllegalVectorTransposeToMemory ,
946- ConvertIllegalShapeCastOpsToTransposes ,
947- LowerIllegalTransposeStoreViaZA>(context);
1001+ rewritePatterns
1002+ . add <FoldExtractFromVectorOfSMELikeCreateMasks ,
1003+ LowerColumnTransferReadToLoops, LiftIllegalVectorTransposeToMemory ,
1004+ LowerIllegalTransposeStoreViaZA>(context);
9481005 if (failed (
9491006 applyPatternsGreedily (getOperation (), std::move (rewritePatterns))))
9501007 return signalPassFailure ();
0 commit comments