@@ -674,6 +674,93 @@ struct LinearizeVectorCreateMask final
674674 }
675675};
676676
677+ // / This pattern linearizes vector.load from vector<1x1x...xN> to vector<N>
678+ // / It currently supports linearization where all but the last dimension are 1
679+ // / The following,
680+ // / vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32>
681+ // / is converted to:
682+ // / vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<4xf32>
683+ // / vector.shape_cast %load_result : vector<4xf32> to vector<1x4xf32>
684+ // / For generic cases, the vector unroll pass should be used to unroll the load
685+ // / to vector<1x1x...xN> form and then linearized
686+ struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> {
687+ using OpConversionPattern::OpConversionPattern;
688+ LinearizeVectorLoad (const TypeConverter &typeConverter, MLIRContext *context,
689+ PatternBenefit benefit = 1 )
690+ : OpConversionPattern(typeConverter, context, benefit) {}
691+
692+ LogicalResult
693+ matchAndRewrite (vector::LoadOp loadOp, OpAdaptor adaptor,
694+ ConversionPatternRewriter &rewriter) const override {
695+ VectorType vecTy = loadOp.getType ();
696+ if (!vecTy)
697+ return rewriter.notifyMatchFailure (loadOp, " expected vector type" );
698+
699+ auto shape = vecTy.getShape ();
700+ auto scalableDims = vecTy.getScalableDims ();
701+ // All but the last dim must be 1, and only the last dim may be scalable (if
702+ // any).
703+ if (!llvm::all_of (shape.drop_back (1 ), [](auto d) { return d == 1 ; }))
704+ return rewriter.notifyMatchFailure (loadOp,
705+ " only vector<1x1x...xN> supported" );
706+
707+ if (llvm::any_of (scalableDims.drop_back (1 ), [](bool s) { return s; }))
708+ return rewriter.notifyMatchFailure (loadOp,
709+ " only innermost dim may be scalable" );
710+
711+ auto linearTy = typeConverter->convertType <VectorType>(vecTy);
712+
713+ auto newLoad = rewriter.create <vector::LoadOp>(
714+ loadOp.getLoc (), linearTy, adaptor.getBase (), adaptor.getIndices ());
715+ rewriter.replaceOp (loadOp, newLoad.getResult ());
716+ return success ();
717+ }
718+ };
719+
720+ // / This pattern linearizes vector.store from vector<1x1x...xN> to vector<N>
721+ // / It currently supports linearization where all but the last dimension are 1
722+ // / The following,
723+ // / vector.store %arg0, %arg1[%c0, %c0]s
724+ // / : vector<1x4xf32>, memref<1x4xf32>
725+ // / is converted to:
726+ // / vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
727+ // / vector.store %arg0, %arg1[%c0, %c0]
728+ // / : vector<4xf32>, memref<1x4xf32>
729+ // / For generic cases, the vector unroll pass should be used to unroll the store
730+ // / to vector<1x1x...xN> form and then linearized
731+ struct LinearizeVectorStore final
732+ : public OpConversionPattern<vector::StoreOp> {
733+ using OpConversionPattern::OpConversionPattern;
734+ LinearizeVectorStore (const TypeConverter &typeConverter, MLIRContext *context,
735+ PatternBenefit benefit = 1 )
736+ : OpConversionPattern(typeConverter, context, benefit) {}
737+
738+ LogicalResult
739+ matchAndRewrite (vector::StoreOp storeOp, OpAdaptor adaptor,
740+ ConversionPatternRewriter &rewriter) const override {
741+ VectorType vecTy = storeOp.getValueToStore ().getType ();
742+ if (!vecTy)
743+ return rewriter.notifyMatchFailure (storeOp, " expected vector type" );
744+
745+ auto shape = vecTy.getShape ();
746+ auto scalableDims = vecTy.getScalableDims ();
747+ // All but the last dim must be 1, and only the last dim may be scalable (if
748+ // any).
749+ if (!llvm::all_of (shape.drop_back (1 ), [](auto d) { return d == 1 ; }))
750+ return rewriter.notifyMatchFailure (storeOp,
751+ " only vector<1x1x...xN> supported" );
752+
753+ if (llvm::any_of (scalableDims.drop_back (1 ), [](bool s) { return s; }))
754+ return rewriter.notifyMatchFailure (storeOp,
755+ " only innermost dim may be scalable" );
756+
757+ rewriter.replaceOpWithNewOp <vector::StoreOp>(
758+ storeOp, adaptor.getValueToStore (), adaptor.getBase (),
759+ adaptor.getIndices ());
760+ return success ();
761+ }
762+ };
763+
677764} // namespace
678765
679766// / This method defines the set of operations that are linearizable, and hence
@@ -765,8 +852,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
765852 RewritePatternSet &patterns) {
766853 patterns
767854 .add <LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
768- LinearizeVectorSplat, LinearizeVectorCreateMask>(
769- typeConverter, patterns.getContext ());
855+ LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad,
856+ LinearizeVectorStore>( typeConverter, patterns.getContext ());
770857}
771858
772859void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns (
0 commit comments