1111//
1212// ===----------------------------------------------------------------------===//
1313
14- #include " mlir/Dialect/Affine/IR/AffineOps.h"
1514#include " mlir/Dialect/Arith/IR/Arith.h"
16- #include " mlir/Dialect/Arith/Utils/Utils.h"
17- #include " mlir/Dialect/Linalg/IR/Linalg.h"
1815#include " mlir/Dialect/MemRef/IR/MemRef.h"
19- #include " mlir/Dialect/SCF/IR/SCF.h"
20- #include " mlir/Dialect/Tensor/IR/Tensor.h"
21- #include " mlir/Dialect/Utils/IndexingUtils.h"
22- #include " mlir/Dialect/Utils/StructuredOpsUtils.h"
2316#include " mlir/Dialect/Vector/IR/VectorOps.h"
2417#include " mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
2518#include " mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
2619#include " mlir/Dialect/Vector/Utils/VectorUtils.h"
27- #include " mlir/IR/BuiltinAttributeInterfaces.h"
2820#include " mlir/IR/BuiltinTypes.h"
29- #include " mlir/IR/ImplicitLocOpBuilder.h"
3021#include " mlir/IR/Location.h"
31- #include " mlir/IR/Matchers.h"
3222#include " mlir/IR/PatternMatch.h"
3323#include " mlir/IR/TypeUtilities.h"
34- #include " mlir/Interfaces/VectorInterfaces.h"
3524
3625#define DEBUG_TYPE " vector-shape-cast-lowering"
3726
3827using namespace mlir ;
3928using namespace mlir ::vector;
4029
30+ // / Increments n-D `indices` by `step` starting from the innermost dimension.
31+ static void incIdx (SmallVectorImpl<int64_t > &indices, VectorType vecType,
32+ int step = 1 ) {
33+ for (int dim : llvm::reverse (llvm::seq<int >(0 , indices.size ()))) {
34+ indices[dim] += step;
35+ if (indices[dim] < vecType.getDimSize (dim))
36+ break ;
37+
38+ indices[dim] = 0 ;
39+ step = 1 ;
40+ }
41+ }
42+
4143namespace {
42- // / ShapeOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
43- // / vectors progressively on the way to target llvm.matrix intrinsics.
44- // / This iterates over the most major dimension of the 2-D vector and performs
45- // / rewrites into:
46- // / vector.extract from 2-D + vector.insert_strided_slice offset into 1-D
47- class ShapeCastOp2DDownCastRewritePattern
44+ // / ShapeOp n-D -> 1-D downcast serves the purpose of flattening N-D to 1-D
45+ // / vectors progressively. This iterates over the n-1 major dimensions of the
46+ // / n-D vector and performs rewrites into:
47+ // / vector.extract from n-D + vector.insert_strided_slice offset into 1-D
48+ class ShapeCastOpNDDownCastRewritePattern
4849 : public OpRewritePattern<vector::ShapeCastOp> {
4950public:
5051 using OpRewritePattern::OpRewritePattern;
@@ -53,35 +54,52 @@ class ShapeCastOp2DDownCastRewritePattern
5354 PatternRewriter &rewriter) const override {
5455 auto sourceVectorType = op.getSourceVectorType ();
5556 auto resultVectorType = op.getResultVectorType ();
56-
5757 if (sourceVectorType.isScalable () || resultVectorType.isScalable ())
5858 return failure ();
5959
60- if (sourceVectorType.getRank () != 2 || resultVectorType.getRank () != 1 )
60+ int64_t srcRank = sourceVectorType.getRank ();
61+ int64_t resRank = resultVectorType.getRank ();
62+ if (srcRank < 2 || resRank != 1 )
6163 return failure ();
6264
65+ // Compute the number of 1-D vector elements involved in the reshape.
66+ int64_t numElts = 1 ;
67+ for (int64_t dim = 0 ; dim < srcRank - 1 ; ++dim)
68+ numElts *= sourceVectorType.getDimSize (dim);
69+
6370 auto loc = op.getLoc ();
64- Value desc = rewriter.create <arith::ConstantOp>(
71+ SmallVector<int64_t > srcIdx (srcRank - 1 );
72+ SmallVector<int64_t > resIdx (resRank);
73+ int64_t extractSize = sourceVectorType.getShape ().back ();
74+ Value result = rewriter.create <arith::ConstantOp>(
6575 loc, resultVectorType, rewriter.getZeroAttr (resultVectorType));
66- unsigned mostMinorVectorSize = sourceVectorType.getShape ()[1 ];
67- for (int64_t i = 0 , e = sourceVectorType.getShape ().front (); i != e; ++i) {
68- Value vec = rewriter.create <vector::ExtractOp>(loc, op.getSource (), i);
69- desc = rewriter.create <vector::InsertStridedSliceOp>(
70- loc, vec, desc,
71- /* offsets=*/ i * mostMinorVectorSize, /* strides=*/ 1 );
76+
77+ // Compute the indices of each 1-D vector element of the source extraction
78+ // and destination slice insertion and generate such instructions.
79+ for (int64_t i = 0 ; i < numElts; ++i) {
80+ if (i != 0 ) {
81+ incIdx (srcIdx, sourceVectorType, /* step=*/ 1 );
82+ incIdx (resIdx, resultVectorType, /* step=*/ extractSize);
83+ }
84+
85+ Value extract =
86+ rewriter.create <vector::ExtractOp>(loc, op.getSource (), srcIdx);
87+ result = rewriter.create <vector::InsertStridedSliceOp>(
88+ loc, extract, result,
89+ /* offsets=*/ resIdx, /* strides=*/ 1 );
7290 }
73- rewriter.replaceOp (op, desc);
91+
92+ rewriter.replaceOp (op, result);
7493 return success ();
7594 }
7695};
7796
78- // / ShapeOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
79- // / vectors progressively.
80- // / This iterates over the most major dimension of the 2-D vector and performs
81- // / rewrites into:
82- // / vector.extract_strided_slice from 1-D + vector.insert into 2-D
97+ // / ShapeOp 1-D -> n-D upcast serves the purpose of unflattening n-D from 1-D
98+ // / vectors progressively. This iterates over the n-1 major dimension of the n-D
99+ // / vector and performs rewrites into:
100+ // / vector.extract_strided_slice from 1-D + vector.insert into n-D
83101// / Note that 1-D extract_strided_slice are lowered to efficient vector.shuffle.
84- class ShapeCastOp2DUpCastRewritePattern
102+ class ShapeCastOpNDUpCastRewritePattern
85103 : public OpRewritePattern<vector::ShapeCastOp> {
86104public:
87105 using OpRewritePattern::OpRewritePattern;
@@ -90,43 +108,43 @@ class ShapeCastOp2DUpCastRewritePattern
90108 PatternRewriter &rewriter) const override {
91109 auto sourceVectorType = op.getSourceVectorType ();
92110 auto resultVectorType = op.getResultVectorType ();
93-
94111 if (sourceVectorType.isScalable () || resultVectorType.isScalable ())
95112 return failure ();
96113
97- if (sourceVectorType.getRank () != 1 || resultVectorType.getRank () != 2 )
114+ int64_t srcRank = sourceVectorType.getRank ();
115+ int64_t resRank = resultVectorType.getRank ();
116+ if (srcRank != 1 || resRank < 2 )
98117 return failure ();
99118
119+ // Compute the number of 1-D vector elements involved in the reshape.
120+ int64_t numElts = 1 ;
121+ for (int64_t dim = 0 ; dim < resRank - 1 ; ++dim)
122+ numElts *= resultVectorType.getDimSize (dim);
123+
124+ // Compute the indices of each 1-D vector element of the source slice
125+ // extraction and destination insertion and generate such instructions.
100126 auto loc = op.getLoc ();
101- Value desc = rewriter.create <arith::ConstantOp>(
127+ SmallVector<int64_t > srcIdx (srcRank);
128+ SmallVector<int64_t > resIdx (resRank - 1 );
129+ int64_t extractSize = resultVectorType.getShape ().back ();
130+ Value result = rewriter.create <arith::ConstantOp>(
102131 loc, resultVectorType, rewriter.getZeroAttr (resultVectorType));
103- unsigned mostMinorVectorSize = resultVectorType.getShape ()[1 ];
104- for (int64_t i = 0 , e = resultVectorType.getShape ().front (); i != e; ++i) {
105- Value vec = rewriter.create <vector::ExtractStridedSliceOp>(
106- loc, op.getSource (), /* offsets=*/ i * mostMinorVectorSize,
107- /* sizes=*/ mostMinorVectorSize,
132+ for (int64_t i = 0 ; i < numElts; ++i) {
133+ if (i != 0 ) {
134+ incIdx (srcIdx, sourceVectorType, /* step=*/ extractSize);
135+ incIdx (resIdx, resultVectorType, /* step=*/ 1 );
136+ }
137+
138+ Value extract = rewriter.create <vector::ExtractStridedSliceOp>(
139+ loc, op.getSource (), /* offsets=*/ srcIdx, /* sizes=*/ extractSize,
108140 /* strides=*/ 1 );
109- desc = rewriter.create <vector::InsertOp>(loc, vec, desc, i );
141+ result = rewriter.create <vector::InsertOp>(loc, extract, result, resIdx );
110142 }
111- rewriter.replaceOp (op, desc );
143+ rewriter.replaceOp (op, result );
112144 return success ();
113145 }
114146};
115147
116- static void incIdx (llvm::MutableArrayRef<int64_t > idx, VectorType tp,
117- int dimIdx, int initialStep = 1 ) {
118- int step = initialStep;
119- for (int d = dimIdx; d >= 0 ; d--) {
120- idx[d] += step;
121- if (idx[d] >= tp.getDimSize (d)) {
122- idx[d] = 0 ;
123- step = 1 ;
124- } else {
125- break ;
126- }
127- }
128- }
129-
130148// We typically should not lower general shape cast operations into data
131149// movement instructions, since the assumption is that these casts are
132150// optimized away during progressive lowering. For completeness, however,
@@ -145,18 +163,14 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
145163 if (sourceVectorType.isScalable () || resultVectorType.isScalable ())
146164 return failure ();
147165
148- // Special case 2D / 1D lowerings with better implementations.
149- // TODO: make is ND / 1D to allow generic ND -> 1D -> MD.
166+ // Special case for n-D / 1-D lowerings with better implementations.
150167 int64_t srcRank = sourceVectorType.getRank ();
151168 int64_t resRank = resultVectorType.getRank ();
152- if ((srcRank == 2 && resRank == 1 ) || (srcRank == 1 && resRank == 2 ))
169+ if ((srcRank > 1 && resRank == 1 ) || (srcRank == 1 && resRank > 1 ))
153170 return failure ();
154171
155172 // Generic ShapeCast lowering path goes all the way down to unrolled scalar
156173 // extract/insert chains.
157- // TODO: consider evolving the semantics to only allow 1D source or dest and
158- // drop this potentially very expensive lowering.
159- // Compute number of elements involved in the reshape.
160174 int64_t numElts = 1 ;
161175 for (int64_t r = 0 ; r < srcRank; r++)
162176 numElts *= sourceVectorType.getDimSize (r);
@@ -172,8 +186,8 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
172186 loc, resultVectorType, rewriter.getZeroAttr (resultVectorType));
173187 for (int64_t i = 0 ; i < numElts; i++) {
174188 if (i != 0 ) {
175- incIdx (srcIdx, sourceVectorType, srcRank - 1 );
176- incIdx (resIdx, resultVectorType, resRank - 1 );
189+ incIdx (srcIdx, sourceVectorType);
190+ incIdx (resIdx, resultVectorType);
177191 }
178192
179193 Value extract;
@@ -252,7 +266,7 @@ class ScalableShapeCastOpRewritePattern
252266 // have a single trailing scalable dimension. This is because there are no
253267 // legal representation of other scalable types in LLVM (and likely won't be
254268 // soon). There are also (currently) no operations that can index or extract
255- // from >= 2D scalable vectors or scalable vectors of fixed vectors.
269+ // from >= 2-D scalable vectors or scalable vectors of fixed vectors.
256270 if (!isTrailingDimScalable (sourceVectorType) ||
257271 !isTrailingDimScalable (resultVectorType)) {
258272 return failure ();
@@ -334,8 +348,8 @@ class ScalableShapeCastOpRewritePattern
334348
335349 // 4. Increment the insert/extract indices, stepping by minExtractionSize
336350 // for the trailing dimensions.
337- incIdx (srcIdx, sourceVectorType, srcRank - 1 , minExtractionSize);
338- incIdx (resIdx, resultVectorType, resRank - 1 , minExtractionSize);
351+ incIdx (srcIdx, sourceVectorType, /* step= */ minExtractionSize);
352+ incIdx (resIdx, resultVectorType, /* step= */ minExtractionSize);
339353 }
340354
341355 rewriter.replaceOp (op, result);
@@ -352,8 +366,8 @@ class ScalableShapeCastOpRewritePattern
352366
353367void mlir::vector::populateVectorShapeCastLoweringPatterns (
354368 RewritePatternSet &patterns, PatternBenefit benefit) {
355- patterns.add <ShapeCastOp2DDownCastRewritePattern ,
356- ShapeCastOp2DUpCastRewritePattern , ShapeCastOpRewritePattern,
369+ patterns.add <ShapeCastOpNDDownCastRewritePattern ,
370+ ShapeCastOpNDUpCastRewritePattern , ShapeCastOpRewritePattern,
357371 ScalableShapeCastOpRewritePattern>(patterns.getContext (),
358372 benefit);
359373}
0 commit comments