@@ -11,40 +11,43 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Arith/Utils/Utils.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Utils/IndexingUtils.h"
-#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
-#include "mlir/IR/BuiltinAttributeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/Location.h"
-#include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
-#include "mlir/Interfaces/VectorInterfaces.h"
 
 #define DEBUG_TYPE "vector-shape-cast-lowering"
 
 using namespace mlir;
 using namespace mlir::vector;
 
+/// Increments n-D `indices` by `step` starting from the innermost dimension.
+static void incIdx(SmallVectorImpl<int64_t> &indices, VectorType vecType,
+                   int step = 1) {
+  for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
+    assert(indices[dim] < vecType.getDimSize(dim) &&
+           "Indices are out of bound");
+    indices[dim] += step;
+    if (indices[dim] < vecType.getDimSize(dim))
+      break;
+
+    indices[dim] = 0;
+    step = 1;
+  }
+}
+
 namespace {
-/// ShapeOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
-/// vectors progressively on the way to target llvm.matrix intrinsics.
-/// This iterates over the most major dimension of the 2-D vector and performs
-/// rewrites into:
-///   vector.extract from 2-D + vector.insert_strided_slice offset into 1-D
-class ShapeCastOp2DDownCastRewritePattern
+/// ShapeOp n-D -> 1-D downcast serves the purpose of flattening n-D to 1-D
+/// vectors progressively. This iterates over the n-1 major dimensions of the
+/// n-D vector and performs rewrites into:
+///   vector.extract from n-D + vector.insert_strided_slice offset into 1-D
+class ShapeCastOpNDDownCastRewritePattern
     : public OpRewritePattern<vector::ShapeCastOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
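Note: the carry logic of the new `incIdx` helper above can be exercised in isolation. A minimal standalone sketch, using a plain shape vector instead of an MLIR `VectorType` (the name `incIdxSketch` and the 2x3 shape are illustrative, not part of the patch):

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

// Row-major increment: bump the innermost index by `step` and propagate
// carries outward, like walking the elements of an n-D vector in order.
static void incIdxSketch(std::vector<int64_t> &indices,
                         const std::vector<int64_t> &shape, int64_t step = 1) {
  for (int dim = static_cast<int>(indices.size()) - 1; dim >= 0; --dim) {
    assert(indices[dim] < shape[dim] && "indices are out of bounds");
    indices[dim] += step;
    if (indices[dim] < shape[dim])
      break;
    indices[dim] = 0;
    step = 1;
  }
}

int main() {
  std::vector<int64_t> shape = {2, 3};
  std::vector<int64_t> idx = {0, 0};
  for (int i = 0; i < 6; ++i) {
    if (i != 0)
      incIdxSketch(idx, shape); // step defaults to 1
    std::printf("[%lld, %lld]\n", (long long)idx[0], (long long)idx[1]);
  }
  return 0;
}
```

Each call advances the indices row-major: [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2], mirroring how the patterns below step through the leading dimensions.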
@@ -53,35 +56,52 @@ class ShapeCastOp2DDownCastRewritePattern
                                 PatternRewriter &rewriter) const override {
     auto sourceVectorType = op.getSourceVectorType();
     auto resultVectorType = op.getResultVectorType();
-
     if (sourceVectorType.isScalable() || resultVectorType.isScalable())
       return failure();
 
-    if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
+    int64_t srcRank = sourceVectorType.getRank();
+    int64_t resRank = resultVectorType.getRank();
+    if (srcRank < 2 || resRank != 1)
       return failure();
 
+    // Compute the number of 1-D vector elements involved in the reshape.
+    int64_t numElts = 1;
+    for (int64_t dim = 0; dim < srcRank - 1; ++dim)
+      numElts *= sourceVectorType.getDimSize(dim);
+
     auto loc = op.getLoc();
-    Value desc = rewriter.create<arith::ConstantOp>(
+    SmallVector<int64_t> srcIdx(srcRank - 1, 0);
+    SmallVector<int64_t> resIdx(resRank, 0);
+    int64_t extractSize = sourceVectorType.getShape().back();
+    Value result = rewriter.create<arith::ConstantOp>(
         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
-    unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
-    for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
-      Value vec = rewriter.create<vector::ExtractOp>(loc, op.getSource(), i);
-      desc = rewriter.create<vector::InsertStridedSliceOp>(
-          loc, vec, desc,
-          /*offsets=*/i * mostMinorVectorSize, /*strides=*/1);
+
+    // Compute the indices of each 1-D vector element of the source extraction
+    // and destination slice insertion and generate such instructions.
+    for (int64_t i = 0; i < numElts; ++i) {
+      if (i != 0) {
+        incIdx(srcIdx, sourceVectorType, /*step=*/1);
+        incIdx(resIdx, resultVectorType, /*step=*/extractSize);
+      }
+
+      Value extract =
+          rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
+      result = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, extract, result,
+          /*offsets=*/resIdx, /*strides=*/1);
     }
-    rewriter.replaceOp(op, desc);
+
+    rewriter.replaceOp(op, result);
     return success();
   }
 };
 
-/// ShapeOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
-/// vectors progressively.
-/// This iterates over the most major dimension of the 2-D vector and performs
-/// rewrites into:
-///   vector.extract_strided_slice from 1-D + vector.insert into 2-D
+/// ShapeOp 1-D -> n-D upcast serves the purpose of unflattening n-D from 1-D
+/// vectors progressively. This iterates over the n-1 major dimensions of the
+/// n-D vector and performs rewrites into:
+///   vector.extract_strided_slice from 1-D + vector.insert into n-D
 /// Note that 1-D extract_strided_slice are lowered to efficient vector.shuffle.
-class ShapeCastOp2DUpCastRewritePattern
+class ShapeCastOpNDUpCastRewritePattern
     : public OpRewritePattern<vector::ShapeCastOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
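For a concrete picture of what `ShapeCastOpNDDownCastRewritePattern` now emits, consider an assumed `vector<2x3x4xf32>` to `vector<24xf32>` cast (the shape is an example chosen here, not taken from the patch): every iteration extracts one innermost 1-D slice and inserts it at a strided offset in the flat result. A small standalone sketch of that index walk:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Assumed example: vector<2x3x4xf32> -> vector<24xf32>. The two leading
  // dims are iterated; the trailing dim (4) is the length of each 1-D slice.
  const std::vector<int64_t> leadingDims = {2, 3};
  const int64_t extractSize = 4;

  int64_t numElts = 1;
  for (int64_t d : leadingDims)
    numElts *= d;

  for (int64_t i = 0; i < numElts; ++i) {
    // srcIdx is stepped by 1 per iteration, i.e. row-major over {2, 3}.
    int64_t outer = i / leadingDims[1];
    int64_t inner = i % leadingDims[1];
    // resIdx is stepped by extractSize per iteration, i.e. i * 4.
    std::printf("vector.extract [%lld, %lld] -> insert_strided_slice offset "
                "%lld\n",
                (long long)outer, (long long)inner,
                (long long)(i * extractSize));
  }
  return 0;
}
```

This prints extract positions [0,0] through [1,2] paired with insert offsets 0, 4, 8, 12, 16, 20, which is exactly the sequence the pattern's `incIdx(srcIdx, ..., 1)` / `incIdx(resIdx, ..., extractSize)` stepping produces.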
@@ -90,43 +110,43 @@ class ShapeCastOp2DUpCastRewritePattern
                                 PatternRewriter &rewriter) const override {
     auto sourceVectorType = op.getSourceVectorType();
     auto resultVectorType = op.getResultVectorType();
-
     if (sourceVectorType.isScalable() || resultVectorType.isScalable())
       return failure();
 
-    if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
+    int64_t srcRank = sourceVectorType.getRank();
+    int64_t resRank = resultVectorType.getRank();
+    if (srcRank != 1 || resRank < 2)
       return failure();
 
+    // Compute the number of 1-D vector elements involved in the reshape.
+    int64_t numElts = 1;
+    for (int64_t dim = 0; dim < resRank - 1; ++dim)
+      numElts *= resultVectorType.getDimSize(dim);
+
+    // Compute the indices of each 1-D vector element of the source slice
+    // extraction and destination insertion and generate such instructions.
     auto loc = op.getLoc();
-    Value desc = rewriter.create<arith::ConstantOp>(
+    SmallVector<int64_t> srcIdx(srcRank, 0);
+    SmallVector<int64_t> resIdx(resRank - 1, 0);
+    int64_t extractSize = resultVectorType.getShape().back();
+    Value result = rewriter.create<arith::ConstantOp>(
         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
-    unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
-    for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
-      Value vec = rewriter.create<vector::ExtractStridedSliceOp>(
-          loc, op.getSource(), /*offsets=*/i * mostMinorVectorSize,
-          /*sizes=*/mostMinorVectorSize,
+    for (int64_t i = 0; i < numElts; ++i) {
+      if (i != 0) {
+        incIdx(srcIdx, sourceVectorType, /*step=*/extractSize);
+        incIdx(resIdx, resultVectorType, /*step=*/1);
+      }
+
+      Value extract = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, op.getSource(), /*offsets=*/srcIdx, /*sizes=*/extractSize,
          /*strides=*/1);
-      desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
+      result = rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
     }
-    rewriter.replaceOp(op, desc);
+    rewriter.replaceOp(op, result);
     return success();
   }
 };
 
-static void incIdx(llvm::MutableArrayRef<int64_t> idx, VectorType tp,
-                   int dimIdx, int initialStep = 1) {
-  int step = initialStep;
-  for (int d = dimIdx; d >= 0; d--) {
-    idx[d] += step;
-    if (idx[d] >= tp.getDimSize(d)) {
-      idx[d] = 0;
-      step = 1;
-    } else {
-      break;
-    }
-  }
-}
-
 // We typically should not lower general shape cast operations into data
 // movement instructions, since the assumption is that these casts are
 // optimized away during progressive lowering. For completeness, however,
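The up-cast pattern is the mirror image of the down-cast: the source offset advances by `extractSize` while the destination position advances by one. A brief sketch for an assumed `vector<16xf32>` to `vector<4x4xf32>` cast (again an example shape, not from the patch):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Assumed example: vector<16xf32> -> vector<4x4xf32>. Source offsets step
  // by extractSize (the trailing result dim); the destination row steps by 1.
  const int64_t numElts = 4;     // product of the leading result dims
  const int64_t extractSize = 4; // trailing result dim
  for (int64_t i = 0; i < numElts; ++i)
    std::printf("extract_strided_slice offset %lld -> vector.insert [%lld]\n",
                (long long)(i * extractSize), (long long)i);
  return 0;
}
```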
@@ -145,18 +165,14 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
     if (sourceVectorType.isScalable() || resultVectorType.isScalable())
       return failure();
 
-    // Special case 2D / 1D lowerings with better implementations.
-    // TODO: make is ND / 1D to allow generic ND -> 1D -> MD.
+    // Special case for n-D / 1-D lowerings with better implementations.
     int64_t srcRank = sourceVectorType.getRank();
     int64_t resRank = resultVectorType.getRank();
-    if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
+    if ((srcRank > 1 && resRank == 1) || (srcRank == 1 && resRank > 1))
       return failure();
 
     // Generic ShapeCast lowering path goes all the way down to unrolled scalar
     // extract/insert chains.
-    // TODO: consider evolving the semantics to only allow 1D source or dest and
-    // drop this potentially very expensive lowering.
-    // Compute number of elements involved in the reshape.
     int64_t numElts = 1;
     for (int64_t r = 0; r < srcRank; r++)
       numElts *= sourceVectorType.getDimSize(r);
@@ -166,14 +182,14 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
     //   x[0,1,0] = y[0,2]
     // etc., incrementing the two index vectors "row-major"
     // within the source and result shape.
-    SmallVector<int64_t> srcIdx(srcRank);
-    SmallVector<int64_t> resIdx(resRank);
+    SmallVector<int64_t> srcIdx(srcRank, 0);
+    SmallVector<int64_t> resIdx(resRank, 0);
     Value result = rewriter.create<arith::ConstantOp>(
         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
     for (int64_t i = 0; i < numElts; i++) {
       if (i != 0) {
-        incIdx(srcIdx, sourceVectorType, srcRank - 1);
-        incIdx(resIdx, resultVectorType, resRank - 1);
+        incIdx(srcIdx, sourceVectorType);
+        incIdx(resIdx, resultVectorType);
       }
 
       Value extract;
@@ -252,7 +268,7 @@ class ScalableShapeCastOpRewritePattern
     // have a single trailing scalable dimension. This is because there are no
     // legal representation of other scalable types in LLVM (and likely won't be
     // soon). There are also (currently) no operations that can index or extract
-    // from >= 2D scalable vectors or scalable vectors of fixed vectors.
+    // from >= 2-D scalable vectors or scalable vectors of fixed vectors.
     if (!isTrailingDimScalable(sourceVectorType) ||
         !isTrailingDimScalable(resultVectorType)) {
       return failure();
@@ -278,8 +294,8 @@ class ScalableShapeCastOpRewritePattern
     Value result = rewriter.create<arith::ConstantOp>(
         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
 
-    SmallVector<int64_t> srcIdx(srcRank);
-    SmallVector<int64_t> resIdx(resRank);
+    SmallVector<int64_t> srcIdx(srcRank, 0);
+    SmallVector<int64_t> resIdx(resRank, 0);
 
     // TODO: Try rewriting this with StaticTileOffsetRange (from IndexingUtils)
     // once D150000 lands.
@@ -334,8 +350,8 @@ class ScalableShapeCastOpRewritePattern
 
       // 4. Increment the insert/extract indices, stepping by minExtractionSize
       // for the trailing dimensions.
-      incIdx(srcIdx, sourceVectorType, srcRank - 1, minExtractionSize);
-      incIdx(resIdx, resultVectorType, resRank - 1, minExtractionSize);
+      incIdx(srcIdx, sourceVectorType, /*step=*/minExtractionSize);
+      incIdx(resIdx, resultVectorType, /*step=*/minExtractionSize);
     }
 
     rewriter.replaceOp(op, result);
@@ -352,8 +368,8 @@ class ScalableShapeCastOpRewritePattern
 
 void mlir::vector::populateVectorShapeCastLoweringPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<ShapeCastOp2DDownCastRewritePattern,
-               ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern,
+  patterns.add<ShapeCastOpNDDownCastRewritePattern,
+               ShapeCastOpNDUpCastRewritePattern, ShapeCastOpRewritePattern,
                ScalableShapeCastOpRewritePattern>(patterns.getContext(),
                                                   benefit);
 }
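For reference, this is roughly how a pass would pull the renamed patterns in. `lowerShapeCasts` is a made-up helper name for illustration; `populateVectorShapeCastLoweringPatterns` and `applyPatternsAndFoldGreedily` are the existing MLIR entry points at the time of this change, and the exact driver name and includes may differ in other revisions:

```cpp
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Sketch of a driver: collect the shape_cast lowering patterns and apply them
// greedily to `root`. Typically this lives in a pass's runOnOperation().
static mlir::LogicalResult lowerShapeCasts(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  mlir::vector::populateVectorShapeCastLoweringPatterns(patterns);
  return mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}
```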