Skip to content

Commit 8b9c70d

Browse files
authored
[mlir] Move vector.{to_elements,from_elements} unrolling to VectorUnroll.cpp (llvm#159118)
This PR moves the patterns that unroll vector.to_elements and vector.from_elements into the file with other vector unrolling operations. This PR also adds these unrolling patterns into the `populateVectorUnrollPatterns`. And renames `populateVectorToElementsLoweringPatterns` `populateVectorFromElementsLoweringPatterns` to `populateVectorToElementsUnrollPatterns` `populateVectorFromElementsUnrollPatterns`.
1 parent 24504c3 commit 8b9c70d

File tree

11 files changed

+170
-142
lines changed

11 files changed

+170
-142
lines changed

mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -306,20 +306,6 @@ void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns);
306306
void populateVectorToFromElementsToShuffleTreePatterns(
307307
RewritePatternSet &patterns, PatternBenefit benefit = 1);
308308

309-
/// Populate the pattern set with the following patterns:
310-
///
311-
/// [UnrollFromElements]
312-
/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the
313-
/// outermost dimension.
314-
void populateVectorFromElementsLoweringPatterns(RewritePatternSet &patterns,
315-
PatternBenefit benefit = 1);
316-
317-
/// Populate the pattern set with the following patterns:
318-
///
319-
/// [UnrollToElements]
320-
void populateVectorToElementsLoweringPatterns(RewritePatternSet &patterns,
321-
PatternBenefit benefit = 1);
322-
323309
/// Populate the pattern set with the following patterns:
324310
///
325311
/// [ContractionOpToMatmulOpLowering]

mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,16 @@ void populateVectorUnrollPatterns(RewritePatternSet &patterns,
322322
const UnrollVectorOptions &options,
323323
PatternBenefit benefit = 1);
324324

325+
/// Unrolls 2 or more dimensional `vector.to_elements` ops by unrolling the
326+
/// outermost dimension of the operand.
327+
void populateVectorToElementsUnrollPatterns(RewritePatternSet &patterns,
328+
PatternBenefit benefit = 1);
329+
330+
/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the
331+
/// outermost dimension.
332+
void populateVectorFromElementsUnrollPatterns(RewritePatternSet &patterns,
333+
PatternBenefit benefit = 1);
334+
325335
/// Collect a set of leading one dimension removal patterns.
326336
///
327337
/// These patterns insert vector.shape_cast to remove leading one dimensions

mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -534,7 +534,7 @@ void GpuToLLVMConversionPass::runOnOperation() {
534534
/*maxTransferRank=*/1);
535535
// Transform N-D vector.from_elements to 1-D vector.from_elements before
536536
// conversion.
537-
vector::populateVectorFromElementsLoweringPatterns(patterns);
537+
vector::populateVectorFromElementsUnrollPatterns(patterns);
538538
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
539539
return signalPassFailure();
540540
}

mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -372,7 +372,7 @@ struct LowerGpuOpsToNVVMOpsPass final
372372
populateGpuRewritePatterns(patterns);
373373
// Transform N-D vector.from_elements to 1-D vector.from_elements before
374374
// conversion.
375-
vector::populateVectorFromElementsLoweringPatterns(patterns);
375+
vector::populateVectorFromElementsUnrollPatterns(patterns);
376376
if (failed(applyPatternsGreedily(m, std::move(patterns))))
377377
return signalPassFailure();
378378
}

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,8 @@ void ConvertVectorToLLVMPass::runOnOperation() {
9494
populateVectorStepLoweringPatterns(patterns);
9595
populateVectorRankReducingFMAPattern(patterns);
9696
populateVectorGatherLoweringPatterns(patterns);
97-
populateVectorFromElementsLoweringPatterns(patterns);
98-
populateVectorToElementsLoweringPatterns(patterns);
97+
populateVectorFromElementsUnrollPatterns(patterns);
98+
populateVectorToElementsUnrollPatterns(patterns);
9999
if (armI8MM) {
100100
if (armNeon)
101101
arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns);

mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,12 +146,12 @@ void transform::ApplyLowerGatherPatternsOp::populatePatterns(
146146

147147
void transform::ApplyUnrollFromElementsPatternsOp::populatePatterns(
148148
RewritePatternSet &patterns) {
149-
vector::populateVectorFromElementsLoweringPatterns(patterns);
149+
vector::populateVectorFromElementsUnrollPatterns(patterns);
150150
}
151151

152152
void transform::ApplyUnrollToElementsPatternsOp::populatePatterns(
153153
RewritePatternSet &patterns) {
154-
vector::populateVectorToElementsLoweringPatterns(patterns);
154+
vector::populateVectorToElementsUnrollPatterns(patterns);
155155
}
156156

157157
void transform::ApplyLowerScanPatternsOp::populatePatterns(

mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ add_mlir_dialect_library(MLIRVectorTransforms
33
LowerVectorBitCast.cpp
44
LowerVectorBroadcast.cpp
55
LowerVectorContract.cpp
6-
LowerVectorFromElements.cpp
76
LowerVectorGather.cpp
87
LowerVectorInterleave.cpp
98
LowerVectorMask.cpp
@@ -12,7 +11,6 @@ add_mlir_dialect_library(MLIRVectorTransforms
1211
LowerVectorShapeCast.cpp
1312
LowerVectorShuffle.cpp
1413
LowerVectorStep.cpp
15-
LowerVectorToElements.cpp
1614
LowerVectorToFromElementsToShuffleTree.cpp
1715
LowerVectorTransfer.cpp
1816
LowerVectorTranspose.cpp

mlir/lib/Dialect/Vector/Transforms/LowerVectorFromElements.cpp

Lines changed: 0 additions & 65 deletions
This file was deleted.

mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp

Lines changed: 0 additions & 53 deletions
This file was deleted.

mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp

Lines changed: 110 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1414
#include "mlir/Dialect/Utils/IndexingUtils.h"
15+
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
1516
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
1617
#include "mlir/Interfaces/VectorInterfaces.h"
1718
#include "llvm/ADT/MapVector.h"
@@ -809,6 +810,55 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
809810
vector::UnrollVectorOptions options;
810811
};
811812

813+
/// Unrolls 2 or more dimensional `vector.to_elements` ops by unrolling the
814+
/// outermost dimension of the operand. For example:
815+
///
816+
/// ```
817+
/// %0:4 = vector.to_elements %v : vector<2x2xf32>
818+
///
819+
/// ==>
820+
///
821+
/// %v0 = vector.extract %v[0] : vector<2x2xf32> from vector<2x2x2xf32>
822+
/// %v1 = vector.extract %v[1] : vector<2x2xf32> from vector<2x2x2xf32>
823+
/// %0:4 = vector.to_elements %v0 : vector<2x2xf32>
824+
/// %1:4 = vector.to_elements %v1 : vector<2x2xf32>
825+
/// ```
826+
///
827+
/// When this pattern is applied until a fixed-point is reached,
828+
/// this will produce a sequence of 1-d from_elements
829+
/// ops.
830+
struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
831+
UnrollToElements(MLIRContext *context,
832+
const vector::UnrollVectorOptions &options,
833+
PatternBenefit benefit = 1)
834+
: OpRewritePattern<vector::ToElementsOp>(context, benefit),
835+
options(options) {}
836+
837+
LogicalResult matchAndRewrite(vector::ToElementsOp op,
838+
PatternRewriter &rewriter) const override {
839+
840+
TypedValue<VectorType> source = op.getSource();
841+
FailureOr<SmallVector<Value>> result =
842+
vector::unrollVectorValue(source, rewriter);
843+
if (failed(result)) {
844+
return failure();
845+
}
846+
SmallVector<Value> vectors = *result;
847+
848+
SmallVector<Value> results;
849+
for (Value vector : vectors) {
850+
auto subElements =
851+
vector::ToElementsOp::create(rewriter, op.getLoc(), vector);
852+
llvm::append_range(results, subElements.getResults());
853+
}
854+
rewriter.replaceOp(op, results);
855+
return success();
856+
}
857+
858+
private:
859+
vector::UnrollVectorOptions options;
860+
};
861+
812862
/// This pattern unrolls `vector.step` operations according to the provided
813863
/// target unroll shape. It decomposes a large step vector into smaller step
814864
/// vectors (segments) and assembles the result by inserting each computed
@@ -884,6 +934,51 @@ struct UnrollStepPattern : public OpRewritePattern<vector::StepOp> {
884934
vector::UnrollVectorOptions options;
885935
};
886936

937+
/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the
938+
/// outermost dimension. For example:
939+
/// ```
940+
/// %v = vector.from_elements %e0, %e1, %e2, %e3, %e4, %e5 : vector<2x3xf32>
941+
///
942+
/// ==>
943+
///
944+
/// %0 = ub.poison : vector<2x3xf32>
945+
/// %v0 = vector.from_elements %e0, %e1, %e2 : vector<3xf32>
946+
/// %1 = vector.insert %v0, %0 [0] : vector<3xf32> into vector<2x3xf32>
947+
/// %v1 = vector.from_elements %e3, %e4, %e5 : vector<3xf32>
948+
/// %v = vector.insert %v1, %1 [1] : vector<3xf32> into vector<2x3xf32>
949+
/// ```
950+
///
951+
/// When this pattern is applied until a fixed-point is reached,
952+
/// this will produce a sequence of 1-d from_elements
953+
/// ops.
954+
struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
955+
UnrollFromElements(MLIRContext *context,
956+
const vector::UnrollVectorOptions &options,
957+
PatternBenefit benefit = 1)
958+
: OpRewritePattern<vector::FromElementsOp>(context, benefit),
959+
options(options) {}
960+
961+
LogicalResult matchAndRewrite(vector::FromElementsOp op,
962+
PatternRewriter &rewriter) const override {
963+
ValueRange allElements = op.getElements();
964+
965+
auto unrollFromElementsFn = [&](PatternRewriter &rewriter, Location loc,
966+
VectorType subTy, int64_t index) {
967+
size_t subTyNumElements = subTy.getNumElements();
968+
assert((index + 1) * subTyNumElements <= allElements.size() &&
969+
"out of bounds");
970+
ValueRange subElements =
971+
allElements.slice(index * subTyNumElements, subTyNumElements);
972+
return vector::FromElementsOp::create(rewriter, loc, subTy, subElements);
973+
};
974+
975+
return unrollVectorOp(op, rewriter, unrollFromElementsFn);
976+
}
977+
978+
private:
979+
vector::UnrollVectorOptions options;
980+
};
981+
887982
} // namespace
888983

889984
void mlir::vector::populateVectorUnrollPatterns(
@@ -893,6 +988,19 @@ void mlir::vector::populateVectorUnrollPatterns(
893988
UnrollContractionPattern, UnrollElementwisePattern,
894989
UnrollReductionPattern, UnrollMultiReductionPattern,
895990
UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
896-
UnrollStorePattern, UnrollBroadcastPattern, UnrollStepPattern>(
897-
patterns.getContext(), options, benefit);
991+
UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
992+
UnrollToElements, UnrollStepPattern>(patterns.getContext(),
993+
options, benefit);
994+
}
995+
996+
void mlir::vector::populateVectorToElementsUnrollPatterns(
997+
RewritePatternSet &patterns, PatternBenefit benefit) {
998+
patterns.add<UnrollToElements>(patterns.getContext(), UnrollVectorOptions(),
999+
benefit);
1000+
}
1001+
1002+
void mlir::vector::populateVectorFromElementsUnrollPatterns(
1003+
RewritePatternSet &patterns, PatternBenefit benefit) {
1004+
patterns.add<UnrollFromElements>(patterns.getContext(), UnrollVectorOptions(),
1005+
benefit);
8981006
}

0 commit comments

Comments
 (0)