@@ -798,6 +798,51 @@ struct LinearizeVectorFromElements final
798
798
}
799
799
};
800
800
801
+ // / This pattern linearizes the operand in `vector.to_elements` operations
802
+ // / by converting the source type to a 1-D vector while preserving all element
803
+ // / values. The transformation creates a linearized `vector.shape_cast`
804
+ // / followed by a `vector.to_elements`.
805
+ // /
806
+ // / Example:
807
+ // /
808
+ // / %0:4 = vector.to_elements %v : vector<2x2xf32>
809
+ // /
810
+ // / is converted to:
811
+ // /
812
+ // / %vector_cast = vector.shape_cast %v : vector<2x2xf32> to vector<4xf32>
813
+ // / %0:4 = vector.to_elements %vector_cast : vector<4xf32>
814
+ // /
815
+ struct LinearizeVectorToElements final
816
+ : public OpConversionPattern<vector::ToElementsOp> {
817
+ using OpConversionPattern::OpConversionPattern;
818
+
819
+ LinearizeVectorToElements (const TypeConverter &typeConverter,
820
+ MLIRContext *context, PatternBenefit benefit = 1 )
821
+ : OpConversionPattern(typeConverter, context, benefit) {}
822
+
823
+ LogicalResult
824
+ matchAndRewrite (vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
825
+ ConversionPatternRewriter &rewriter) const override {
826
+
827
+ VectorType vecType = toElementsOp.getSource ().getType ();
828
+ if (vecType.getRank () <= 1 )
829
+ return rewriter.notifyMatchFailure (
830
+ toElementsOp, " the rank is already less than or equal to 1" );
831
+
832
+ assert (vecType.getNumScalableDims () == 0 &&
833
+ " to_elements does not support scalable vectors" );
834
+ auto vec1DType =
835
+ VectorType::get ({vecType.getNumElements ()}, vecType.getElementType ());
836
+ Value shapeCast = vector::ShapeCastOp::create (
837
+ rewriter, toElementsOp.getLoc (), vec1DType, toElementsOp.getSource ());
838
+ auto newToElementsOp =
839
+ vector::ToElementsOp::create (rewriter, toElementsOp.getLoc (),
840
+ toElementsOp.getResultTypes (), shapeCast);
841
+ rewriter.replaceOp (toElementsOp, newToElementsOp);
842
+ return success ();
843
+ }
844
+ };
845
+
801
846
} // namespace
802
847
803
848
// / This method defines the set of operations that are linearizable, and hence
@@ -890,8 +935,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
890
935
patterns
891
936
.add <LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
892
937
LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad,
893
- LinearizeVectorStore, LinearizeVectorFromElements>(
894
- typeConverter, patterns.getContext ());
938
+ LinearizeVectorStore, LinearizeVectorFromElements,
939
+ LinearizeVectorToElements>( typeConverter, patterns.getContext ());
895
940
}
896
941
897
942
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns (
0 commit comments