@@ -798,6 +798,51 @@ struct LinearizeVectorFromElements final
798798 }
799799};
800800
801+ // / This pattern linearizes the operand in `vector.to_elements` operations
802+ // / by converting the source type to a 1-D vector while preserving all element
803+ // / values. The transformation creates a linearized `vector.shape_cast`
804+ // / followed by a `vector.to_elements`.
805+ // /
806+ // / Example:
807+ // /
808+ // / %0:4 = vector.to_elements %v : vector<2x2xf32>
809+ // /
810+ // / is converted to:
811+ // /
812+ // / %vector_cast = vector.shape_cast %v : vector<2x2xf32> to vector<4xf32>
813+ // / %0:4 = vector.to_elements %vector_cast : vector<4xf32>
814+ // /
815+ struct LinearizeVectorToElements final
816+ : public OpConversionPattern<vector::ToElementsOp> {
817+ using OpConversionPattern::OpConversionPattern;
818+
819+ LinearizeVectorToElements (const TypeConverter &typeConverter,
820+ MLIRContext *context, PatternBenefit benefit = 1 )
821+ : OpConversionPattern(typeConverter, context, benefit) {}
822+
823+ LogicalResult
824+ matchAndRewrite (vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
825+ ConversionPatternRewriter &rewriter) const override {
826+
827+ VectorType vecType = toElementsOp.getSource ().getType ();
828+ if (vecType.getRank () <= 1 )
829+ return rewriter.notifyMatchFailure (
830+ toElementsOp, " the rank is already less than or equal to 1" );
831+
832+ assert (vecType.getNumScalableDims () == 0 &&
833+ " to_elements does not support scalable vectors" );
834+ auto vec1DType =
835+ VectorType::get ({vecType.getNumElements ()}, vecType.getElementType ());
836+ Value shapeCast = vector::ShapeCastOp::create (
837+ rewriter, toElementsOp.getLoc (), vec1DType, toElementsOp.getSource ());
838+ auto newToElementsOp =
839+ vector::ToElementsOp::create (rewriter, toElementsOp.getLoc (),
840+ toElementsOp.getResultTypes (), shapeCast);
841+ rewriter.replaceOp (toElementsOp, newToElementsOp);
842+ return success ();
843+ }
844+ };
845+
801846} // namespace
802847
803848// / This method defines the set of operations that are linearizable, and hence
@@ -890,8 +935,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
890935 patterns
891936 .add <LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
892937 LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad,
893- LinearizeVectorStore, LinearizeVectorFromElements>(
894- typeConverter, patterns.getContext ());
938+ LinearizeVectorStore, LinearizeVectorFromElements,
939+ LinearizeVectorToElements>( typeConverter, patterns.getContext ());
895940}
896941
897942void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns (
0 commit comments