Skip to content

Commit 04431c5

Browse files
committed
Copy over UnrollFromElements
1 parent 5c3e7d5 commit 04431c5

File tree

2 files changed

+52
-2
lines changed

2 files changed

+52
-2
lines changed

mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,14 @@ void populateVectorUnrollPatterns(RewritePatternSet &patterns,
328328
void populateVectorToElementsUnrollPatterns(RewritePatternSet &patterns,
329329
PatternBenefit benefit = 1);
330330

331+
/// Populate the pattern set with the following patterns:
332+
///
333+
/// [UnrollFromElements]
334+
/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the
335+
/// outermost dimension.
336+
void populateVectorFromElementsUnrollPatterns(RewritePatternSet &patterns,
337+
PatternBenefit benefit = 1);
338+
331339
/// Collect a set of leading one dimension removal patterns.
332340
///
333341
/// These patterns insert vector.shape_cast to remove leading one dimensions

mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -835,13 +835,50 @@ struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
835835
}
836836
};
837837

838+
/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the
839+
/// outermost dimension. For example:
840+
/// ```
841+
/// %v = vector.from_elements %e0, %e1, %e2, %e3, %e4, %e5 : vector<2x3xf32>
842+
///
843+
/// ==>
844+
///
845+
/// %0 = ub.poison : vector<2x3xf32>
846+
/// %v0 = vector.from_elements %e0, %e1, %e2 : vector<3xf32>
847+
/// %1 = vector.insert %v0, %0 [0] : vector<3xf32> into vector<2x3xf32>
848+
/// %v1 = vector.from_elements %e3, %e4, %e5 : vector<3xf32>
849+
/// %v = vector.insert %v1, %1 [1] : vector<3xf32> into vector<2x3xf32>
850+
/// ```
851+
///
852+
/// When applied exhaustively, this will produce a sequence of 1-d from_elements
853+
/// ops.
854+
struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
855+
using OpRewritePattern::OpRewritePattern;
856+
857+
LogicalResult matchAndRewrite(vector::FromElementsOp op,
858+
PatternRewriter &rewriter) const override {
859+
ValueRange allElements = op.getElements();
860+
861+
auto unrollFromElementsFn = [&](PatternRewriter &rewriter, Location loc,
862+
VectorType subTy, int64_t index) {
863+
size_t subTyNumElements = subTy.getNumElements();
864+
assert((index + 1) * subTyNumElements <= allElements.size() &&
865+
"out of bounds");
866+
ValueRange subElements =
867+
allElements.slice(index * subTyNumElements, subTyNumElements);
868+
return vector::FromElementsOp::create(rewriter, loc, subTy, subElements);
869+
};
870+
871+
return unrollVectorOp(op, rewriter, unrollFromElementsFn);
872+
}
873+
};
874+
838875
} // namespace
839876

840877
void mlir::vector::populateVectorUnrollPatterns(
841878
RewritePatternSet &patterns, const UnrollVectorOptions &options,
842879
PatternBenefit benefit) {
843-
populateVectorFromElementsLoweringPatterns(patterns);
844-
patterns.add<UnrollToElements>(patterns.getContext(), benefit);
880+
patterns.add<UnrollFromElements, UnrollToElements>(patterns.getContext(),
881+
benefit);
845882
patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
846883
UnrollContractionPattern, UnrollElementwisePattern,
847884
UnrollReductionPattern, UnrollMultiReductionPattern,
@@ -854,3 +891,8 @@ void mlir::vector::populateVectorToElementsUnrollPatterns(
854891
RewritePatternSet &patterns, PatternBenefit benefit) {
855892
patterns.add<UnrollToElements>(patterns.getContext(), benefit);
856893
}
894+
895+
void mlir::vector::populateVectorFromElementsUnrollPatterns(
896+
RewritePatternSet &patterns, PatternBenefit benefit) {
897+
patterns.add<UnrollFromElements>(patterns.getContext(), benefit);
898+
}

0 commit comments

Comments
 (0)