@@ -623,6 +623,73 @@ struct LinearizeVectorCreateMask final
623623 }
624624};
625625
626+ // / This pattern linearizes vector.load from vector<1xN> to vector<N>.
627+ // / It currently supports only lineariztion of <1XN> to <N>
628+ // / Following,
629+ // / vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32>
630+ // / is converted to:
631+ // / vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<4xf32>
632+ // / vector.shape_cast %load_result : vector<4xf32> to vector<1x4xf32>
633+ struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> {
634+ using OpConversionPattern::OpConversionPattern;
635+ LinearizeVectorLoad (const TypeConverter &typeConverter, MLIRContext *context,
636+ PatternBenefit benefit = 1 )
637+ : OpConversionPattern(typeConverter, context, benefit) {}
638+
639+ LogicalResult
640+ matchAndRewrite (vector::LoadOp loadOp, OpAdaptor adaptor,
641+ ConversionPatternRewriter &rewriter) const override {
642+ VectorType vecTy = loadOp.getType ();
643+ if (!vecTy || vecTy.getRank () != 2 || vecTy.getShape ()[0 ] != 1 )
644+ return rewriter.notifyMatchFailure (loadOp, " only vector<1xN> supported" );
645+ auto linearTy = VectorType::get (vecTy.getShape ()[1 ], vecTy.getElementType (),
646+ vecTy.isScalable ());
647+ auto newLoad = rewriter.create <vector::LoadOp>(
648+ loadOp.getLoc (), linearTy, adaptor.getBase (), adaptor.getIndices ());
649+ auto shapeCast = rewriter.create <vector::ShapeCastOp>(
650+ loadOp.getLoc (), vecTy, newLoad.getResult ());
651+ rewriter.replaceOp (loadOp, shapeCast.getResult ());
652+ return success ();
653+ }
654+ };
655+
656+ // / This pattern linearizes vector.store from vector<1xN> to vector<N>.
657+ // / It currently supports only lineariztion of <1XN> to <N>
658+ // / Following,
659+ // / vector.store %arg0, %arg1[%c0, %c0]
660+ // / : vector<1x4xf32>, memref<1x4xf32>
661+ // / is converted to:
662+ // / vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
663+ // / vector.store %arg0, %arg1[%c0, %%c0]
664+ // / : vector<4xf32>, memref<1x4xf32>
665+ struct LinearizeVectorStore final
666+ : public OpConversionPattern<vector::StoreOp> {
667+ using OpConversionPattern::OpConversionPattern;
668+ LinearizeVectorStore (const TypeConverter &typeConverter, MLIRContext *context,
669+ PatternBenefit benefit = 1 )
670+ : OpConversionPattern(typeConverter, context, benefit) {}
671+
672+ LogicalResult
673+ matchAndRewrite (vector::StoreOp storeOp, OpAdaptor adaptor,
674+ ConversionPatternRewriter &rewriter) const override {
675+ VectorType vecTy = storeOp.getValueToStore ().getType ();
676+ if (!vecTy || vecTy.getRank () != 2 || vecTy.getShape ()[0 ] != 1 )
677+ return rewriter.notifyMatchFailure (storeOp, " only vector<1xN> supported" );
678+ auto linearTy = VectorType::get (vecTy.getShape ()[1 ], vecTy.getElementType (),
679+ vecTy.isScalable ());
680+
681+ Value valueToStore = adaptor.getValueToStore ();
682+ if (valueToStore.getType () != linearTy) {
683+ valueToStore = rewriter.create <vector::ShapeCastOp>(
684+ storeOp.getLoc (), linearTy, valueToStore);
685+ }
686+
687+ rewriter.replaceOpWithNewOp <vector::StoreOp>(
688+ storeOp, valueToStore, adaptor.getBase (), adaptor.getIndices ());
689+ return success ();
690+ }
691+ };
692+
626693} // namespace
627694
628695// / This method defines the set of operations that are linearizable, and hence
@@ -714,8 +781,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
714781 RewritePatternSet &patterns) {
715782 patterns
716783 .add <LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
717- LinearizeVectorSplat, LinearizeVectorCreateMask>(
718- typeConverter, patterns.getContext ());
784+ LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad,
785+ LinearizeVectorStore>( typeConverter, patterns.getContext ());
719786}
720787
721788void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns (
0 commit comments