@@ -835,13 +835,50 @@ struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
835835 }
836836};
837837
838+ // / Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the
839+ // / outermost dimension. For example:
840+ // / ```
841+ // / %v = vector.from_elements %e0, %e1, %e2, %e3, %e4, %e5 : vector<2x3xf32>
842+ // /
843+ // / ==>
844+ // /
845+ // / %0 = ub.poison : vector<2x3xf32>
846+ // / %v0 = vector.from_elements %e0, %e1, %e2 : vector<3xf32>
847+ // / %1 = vector.insert %v0, %0 [0] : vector<3xf32> into vector<2x3xf32>
848+ // / %v1 = vector.from_elements %e3, %e4, %e5 : vector<3xf32>
849+ // / %v = vector.insert %v1, %1 [1] : vector<3xf32> into vector<2x3xf32>
850+ // / ```
851+ // /
852+ // / When applied exhaustively, this will produce a sequence of 1-d from_elements
853+ // / ops.
854+ struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
855+ using OpRewritePattern::OpRewritePattern;
856+
857+ LogicalResult matchAndRewrite (vector::FromElementsOp op,
858+ PatternRewriter &rewriter) const override {
859+ ValueRange allElements = op.getElements ();
860+
861+ auto unrollFromElementsFn = [&](PatternRewriter &rewriter, Location loc,
862+ VectorType subTy, int64_t index) {
863+ size_t subTyNumElements = subTy.getNumElements ();
864+ assert ((index + 1 ) * subTyNumElements <= allElements.size () &&
865+ " out of bounds" );
866+ ValueRange subElements =
867+ allElements.slice (index * subTyNumElements, subTyNumElements);
868+ return vector::FromElementsOp::create (rewriter, loc, subTy, subElements);
869+ };
870+
871+ return unrollVectorOp (op, rewriter, unrollFromElementsFn);
872+ }
873+ };
874+
838875} // namespace
839876
840877void mlir::vector::populateVectorUnrollPatterns (
841878 RewritePatternSet &patterns, const UnrollVectorOptions &options,
842879 PatternBenefit benefit) {
843- populateVectorFromElementsLoweringPatterns (patterns);
844- patterns. add <UnrollToElements>(patterns. getContext (), benefit);
880+ patterns. add <UnrollFromElements, UnrollToElements> (patterns. getContext (),
881+ benefit);
845882 patterns.add <UnrollTransferReadPattern, UnrollTransferWritePattern,
846883 UnrollContractionPattern, UnrollElementwisePattern,
847884 UnrollReductionPattern, UnrollMultiReductionPattern,
@@ -854,3 +891,8 @@ void mlir::vector::populateVectorToElementsUnrollPatterns(
854891 RewritePatternSet &patterns, PatternBenefit benefit) {
855892 patterns.add <UnrollToElements>(patterns.getContext (), benefit);
856893}
894+
895+ void mlir::vector::populateVectorFromElementsUnrollPatterns (
896+ RewritePatternSet &patterns, PatternBenefit benefit) {
897+ patterns.add <UnrollFromElements>(patterns.getContext (), benefit);
898+ }
0 commit comments