12
12
13
13
#include " mlir/Dialect/Affine/IR/AffineOps.h"
14
14
#include " mlir/Dialect/Utils/IndexingUtils.h"
15
+ #include " mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
15
16
#include " mlir/Dialect/Vector/Transforms/VectorTransforms.h"
16
17
#include " mlir/Interfaces/VectorInterfaces.h"
17
18
#include " llvm/ADT/MapVector.h"
@@ -809,6 +810,55 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
809
810
vector::UnrollVectorOptions options;
810
811
};
811
812
813
+ // / Unrolls 2 or more dimensional `vector.to_elements` ops by unrolling the
814
+ // / outermost dimension of the operand. For example:
815
+ // /
816
+ // / ```
817
+ // / %0:4 = vector.to_elements %v : vector<2x2xf32>
818
+ // /
819
+ // / ==>
820
+ // /
821
+ // / %v0 = vector.extract %v[0] : vector<2x2xf32> from vector<2x2x2xf32>
822
+ // / %v1 = vector.extract %v[1] : vector<2x2xf32> from vector<2x2x2xf32>
823
+ // / %0:4 = vector.to_elements %v0 : vector<2x2xf32>
824
+ // / %1:4 = vector.to_elements %v1 : vector<2x2xf32>
825
+ // / ```
826
+ // /
827
+ // / When this pattern is applied until a fixed-point is reached,
828
+ // / this will produce a sequence of 1-d from_elements
829
+ // / ops.
830
+ struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
831
+ UnrollToElements (MLIRContext *context,
832
+ const vector::UnrollVectorOptions &options,
833
+ PatternBenefit benefit = 1 )
834
+ : OpRewritePattern<vector::ToElementsOp>(context, benefit),
835
+ options (options) {}
836
+
837
+ LogicalResult matchAndRewrite (vector::ToElementsOp op,
838
+ PatternRewriter &rewriter) const override {
839
+
840
+ TypedValue<VectorType> source = op.getSource ();
841
+ FailureOr<SmallVector<Value>> result =
842
+ vector::unrollVectorValue (source, rewriter);
843
+ if (failed (result)) {
844
+ return failure ();
845
+ }
846
+ SmallVector<Value> vectors = *result;
847
+
848
+ SmallVector<Value> results;
849
+ for (Value vector : vectors) {
850
+ auto subElements =
851
+ vector::ToElementsOp::create (rewriter, op.getLoc (), vector);
852
+ llvm::append_range (results, subElements.getResults ());
853
+ }
854
+ rewriter.replaceOp (op, results);
855
+ return success ();
856
+ }
857
+
858
+ private:
859
+ vector::UnrollVectorOptions options;
860
+ };
861
+
812
862
// / This pattern unrolls `vector.step` operations according to the provided
813
863
// / target unroll shape. It decomposes a large step vector into smaller step
814
864
// / vectors (segments) and assembles the result by inserting each computed
@@ -884,6 +934,51 @@ struct UnrollStepPattern : public OpRewritePattern<vector::StepOp> {
884
934
vector::UnrollVectorOptions options;
885
935
};
886
936
937
+ // / Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the
938
+ // / outermost dimension. For example:
939
+ // / ```
940
+ // / %v = vector.from_elements %e0, %e1, %e2, %e3, %e4, %e5 : vector<2x3xf32>
941
+ // /
942
+ // / ==>
943
+ // /
944
+ // / %0 = ub.poison : vector<2x3xf32>
945
+ // / %v0 = vector.from_elements %e0, %e1, %e2 : vector<3xf32>
946
+ // / %1 = vector.insert %v0, %0 [0] : vector<3xf32> into vector<2x3xf32>
947
+ // / %v1 = vector.from_elements %e3, %e4, %e5 : vector<3xf32>
948
+ // / %v = vector.insert %v1, %1 [1] : vector<3xf32> into vector<2x3xf32>
949
+ // / ```
950
+ // /
951
+ // / When this pattern is applied until a fixed-point is reached,
952
+ // / this will produce a sequence of 1-d from_elements
953
+ // / ops.
954
+ struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
955
+ UnrollFromElements (MLIRContext *context,
956
+ const vector::UnrollVectorOptions &options,
957
+ PatternBenefit benefit = 1 )
958
+ : OpRewritePattern<vector::FromElementsOp>(context, benefit),
959
+ options (options) {}
960
+
961
+ LogicalResult matchAndRewrite (vector::FromElementsOp op,
962
+ PatternRewriter &rewriter) const override {
963
+ ValueRange allElements = op.getElements ();
964
+
965
+ auto unrollFromElementsFn = [&](PatternRewriter &rewriter, Location loc,
966
+ VectorType subTy, int64_t index) {
967
+ size_t subTyNumElements = subTy.getNumElements ();
968
+ assert ((index + 1 ) * subTyNumElements <= allElements.size () &&
969
+ " out of bounds" );
970
+ ValueRange subElements =
971
+ allElements.slice (index * subTyNumElements, subTyNumElements);
972
+ return vector::FromElementsOp::create (rewriter, loc, subTy, subElements);
973
+ };
974
+
975
+ return unrollVectorOp (op, rewriter, unrollFromElementsFn);
976
+ }
977
+
978
+ private:
979
+ vector::UnrollVectorOptions options;
980
+ };
981
+
887
982
} // namespace
888
983
889
984
void mlir::vector::populateVectorUnrollPatterns (
@@ -893,6 +988,19 @@ void mlir::vector::populateVectorUnrollPatterns(
893
988
UnrollContractionPattern, UnrollElementwisePattern,
894
989
UnrollReductionPattern, UnrollMultiReductionPattern,
895
990
UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
896
- UnrollStorePattern, UnrollBroadcastPattern, UnrollStepPattern>(
897
- patterns.getContext (), options, benefit);
991
+ UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
992
+ UnrollToElements, UnrollStepPattern>(patterns.getContext (),
993
+ options, benefit);
994
+ }
995
+
996
+ void mlir::vector::populateVectorToElementsUnrollPatterns (
997
+ RewritePatternSet &patterns, PatternBenefit benefit) {
998
+ patterns.add <UnrollToElements>(patterns.getContext (), UnrollVectorOptions (),
999
+ benefit);
1000
+ }
1001
+
1002
+ void mlir::vector::populateVectorFromElementsUnrollPatterns (
1003
+ RewritePatternSet &patterns, PatternBenefit benefit) {
1004
+ patterns.add <UnrollFromElements>(patterns.getContext (), UnrollVectorOptions (),
1005
+ benefit);
898
1006
}
0 commit comments