12
12
// /
13
13
// ===----------------------------------------------------------------------===//
14
14
15
+ #include " mlir/Dialect/Vector/IR/VectorOps.h"
16
+ #include " mlir/IR/BuiltinOps.h"
17
+ #include " mlir/IR/BuiltinTypes.h"
18
+ #include " mlir/Support/LogicalResult.h"
19
+ #include " llvm/ADT/ArrayRef.h"
20
+ #include " llvm/Transforms/Utils/AddDiscriminators.h"
21
+ #include < cstdint>
15
22
#include < imex/Transforms/Passes.h>
16
23
17
24
#include < mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h>
18
25
#include < mlir/Pass/Pass.h>
19
26
#include < mlir/Transforms/DialectConversion.h>
27
+ #include < numeric>
20
28
21
29
namespace imex {
22
30
#define GEN_PASS_DEF_VECTORLINEARIZE
23
31
#include " imex/Transforms/Passes.h.inc"
24
32
} // namespace imex
25
33
26
34
namespace {
35
+
36
+ struct VectorExtractStridedSliceConversion final
37
+ : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
38
+ using mlir::OpConversionPattern<
39
+ mlir::vector::ExtractStridedSliceOp>::OpConversionPattern;
40
+
41
+ mlir::LogicalResult
42
+ matchAndRewrite (mlir::vector::ExtractStridedSliceOp extractOp,
43
+ OpAdaptor adaptor,
44
+ mlir::ConversionPatternRewriter &rewriter) const override {
45
+ auto dstType = getTypeConverter ()->convertType (extractOp.getType ());
46
+ auto loc = extractOp.getLoc ();
47
+ if (!dstType)
48
+ return rewriter.notifyMatchFailure (loc, " cannot convert type." );
49
+
50
+ if (extractOp.getVector ().getType ().isScalable () ||
51
+ dstType.cast <mlir::VectorType>().isScalable ())
52
+ return rewriter.notifyMatchFailure (loc,
53
+ " scalable vectors are not supported." );
54
+
55
+ auto offsets = extractOp.getOffsets ().getValue ();
56
+ auto sizes = extractOp.getSizes ().getValue ();
57
+ auto strides = extractOp.getStrides ().getValue ();
58
+
59
+ if (!mlir::isConstantIntValue (strides[0 ], 1 ))
60
+ return rewriter.notifyMatchFailure (
61
+ extractOp, " Strided slice with stride != 1 is not supported." );
62
+
63
+ mlir::Value srcVector = adaptor.getVector ();
64
+
65
+ // if kD offsets are specified for nd source vector (n > k), the granularity
66
+ // of the extraction is greater than 1. In this case last (n-k) dimensions
67
+ // form the extraction granularity. example : %0 =
68
+ // vector.extract_strided_slice %src { offsets = [0, 0], sizes = [2, 2],
69
+ // strides = [1, 1]} : vector<4x8x8xf32> to vector<2x2x8xf32>
70
+ // here, extraction granularity is 8.
71
+ int64_t extractSliceLen = 1 ;
72
+ auto n = extractOp.getSourceVectorType ().getRank ();
73
+ auto k = (int64_t )offsets.size ();
74
+ if (n > k) {
75
+ for (unsigned i = 0 ; i < n - k; i++) {
76
+ extractSliceLen *= extractOp.getSourceVectorType ().getShape ()[i + k];
77
+ }
78
+ }
79
+
80
+ // get total number of extracted slices
81
+ int64_t nExtractedSlices = 1 ;
82
+ for (auto size : sizes) {
83
+ nExtractedSlices *= size.cast <mlir::IntegerAttr>().getInt ();
84
+ }
85
+
86
+ // compute the strides of the source vector considering first k dimensions
87
+ llvm::SmallVector<int64_t , 4 > sourceStrides (k, extractSliceLen);
88
+ for (int i = k - 2 ; i >= 0 ; --i) {
89
+ sourceStrides[i] = sourceStrides[i + 1 ] *
90
+ extractOp.getSourceVectorType ().getShape ()[i + 1 ];
91
+ }
92
+ // final shuffle indices has nExtractedElems * extractSliceLen elements
93
+ llvm::SmallVector<int64_t , 4 > indices (nExtractedSlices * extractSliceLen);
94
+ // compute the strides of the extracted kD vector
95
+ llvm::SmallVector<int64_t , 4 > extractedStrides (k, 1 );
96
+ // compute extractedStrides
97
+ for (int i = k - 2 ; i >= 0 ; --i) {
98
+ extractedStrides[i] = extractedStrides[i + 1 ] *
99
+ sizes[i + 1 ].cast <mlir::IntegerAttr>().getInt ();
100
+ }
101
+ // iterate over all extracted slices from 0 to nExtractedElems-1
102
+ // and compute the multi-dimensional index and the corresponding linearized
103
+ // index within the source vector
104
+ for (int64_t i = 0 ; i < nExtractedSlices; ++i) {
105
+ int64_t index = i;
106
+ // compute the corresponding multi-dimensional index
107
+ llvm::SmallVector<int64_t , 4 > multiDimIndex (k, 0 );
108
+ for (int64_t j = 0 ; j < k; ++j) {
109
+ multiDimIndex[j] = (index / extractedStrides[j]);
110
+ index -= multiDimIndex[j] * extractedStrides[j];
111
+ }
112
+ // compute the corresponding linearized index in the source vector
113
+ // i.e. shift the multiDimIndex by the offsets
114
+ int64_t linearizedIndex = 0 ;
115
+ for (int64_t j = 0 ; j < k; ++j) {
116
+ linearizedIndex +=
117
+ (offsets[j].cast <mlir::IntegerAttr>().getInt () + multiDimIndex[j]) *
118
+ sourceStrides[j];
119
+ }
120
+ // fill the indices array form linearizedIndex to linearizedIndex +
121
+ // sliceLen
122
+ for (int64_t j = 0 ; j < extractSliceLen; ++j) {
123
+ indices[i * extractSliceLen + j] = linearizedIndex + j;
124
+ }
125
+ }
126
+ // perform a shuffle to extract the kD vector
127
+ rewriter.replaceOpWithNewOp <mlir::vector::ShuffleOp>(
128
+ extractOp, dstType, srcVector, srcVector,
129
+ rewriter.getI64ArrayAttr (indices));
130
+
131
+ return mlir::success ();
132
+ }
133
+ };
134
+
135
+ struct VectorShffleOpConversion final
136
+ : public mlir::OpConversionPattern<mlir::vector::ShuffleOp> {
137
+ using mlir::OpConversionPattern<mlir::vector::ShuffleOp>::OpConversionPattern;
138
+
139
+ mlir::LogicalResult
140
+ matchAndRewrite (mlir::vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
141
+ mlir::ConversionPatternRewriter &rewriter) const override {
142
+ auto dstType = getTypeConverter ()->convertType (shuffleOp.getType ());
143
+ auto loc = shuffleOp.getLoc ();
144
+ if (!dstType)
145
+ return rewriter.notifyMatchFailure (loc, " cannot convert type." );
146
+
147
+ auto vec1 = adaptor.getV1 ();
148
+ auto vec2 = adaptor.getV2 ();
149
+
150
+ int shuffleSliceLen = 1 ;
151
+ int rank = shuffleOp.getV1 ().getType ().getRank ();
152
+
153
+ // if rank > 1, we need to do the shuffle in the granularity of slices
154
+ // instead of scalars. Size of the slice is equal to the rank-1 innermost
155
+ // dims. Mask of the shuffle op specifies which slice to take from the
156
+ // outermost dim.
157
+ if (rank > 1 ) {
158
+ auto shape = shuffleOp.getV1 ().getType ().getShape ();
159
+ for (unsigned i = 1 ; i < shape.size (); i++) {
160
+ shuffleSliceLen *= shape[i];
161
+ }
162
+ }
163
+
164
+ auto mask = shuffleOp.getMask ();
165
+ auto totalSize = mask.size () * shuffleSliceLen;
166
+
167
+ llvm::SmallVector<int64_t , 2 > indices (totalSize);
168
+ for (auto [i, value] :
169
+ llvm::enumerate (mask.getAsValueRange <mlir::IntegerAttr>())) {
170
+
171
+ int64_t v = value.getZExtValue ();
172
+ std::iota (indices.begin () + shuffleSliceLen * i,
173
+ indices.begin () + shuffleSliceLen * (i + 1 ),
174
+ shuffleSliceLen * v);
175
+ }
176
+
177
+ rewriter.replaceOpWithNewOp <mlir::vector::ShuffleOp>(
178
+ shuffleOp, dstType, vec1, vec2, rewriter.getI64ArrayAttr (indices));
179
+
180
+ return mlir::success ();
181
+ }
182
+ };
183
+
184
+ struct VectorExtractOpConversion final
185
+ : public mlir::OpConversionPattern<mlir::vector::ExtractOp> {
186
+ using OpConversionPattern::OpConversionPattern;
187
+ mlir::LogicalResult
188
+ matchAndRewrite (mlir::vector::ExtractOp extractOp, OpAdaptor adaptor,
189
+ mlir::ConversionPatternRewriter &rewriter) const override {
190
+ auto dstTy = getTypeConverter ()->convertType (extractOp.getType ());
191
+ if (!dstTy)
192
+ return rewriter.notifyMatchFailure (extractOp, " cannot convert type." );
193
+
194
+ // dynamic position is not supported
195
+ if (extractOp.hasDynamicPosition ())
196
+ return rewriter.notifyMatchFailure (extractOp,
197
+ " dynamic position is not supported." );
198
+
199
+ auto shape = extractOp.getVector ().getType ().getShape ();
200
+ auto size = extractOp.getVector ().getType ().getNumElements ();
201
+
202
+ // compute linearized offset
203
+ int64_t linearizedOffset = 0 ;
204
+ auto offsets = extractOp.getStaticPosition ();
205
+ for (auto [i, off] : llvm::enumerate (offsets)) {
206
+ size /= shape[i];
207
+ linearizedOffset += offsets[i] * size;
208
+ }
209
+
210
+ llvm::SmallVector<int64_t , 2 > indices (size);
211
+ std::iota (indices.begin (), indices.end (), linearizedOffset);
212
+ rewriter.replaceOpWithNewOp <mlir::vector::ShuffleOp>(
213
+ extractOp, dstTy, adaptor.getVector (), adaptor.getVector (),
214
+ rewriter.getI64ArrayAttr (indices));
215
+
216
+ return mlir::success ();
217
+ }
218
+ };
219
+
27
220
struct VectorLinearizePass final
28
221
: public imex::impl::VectorLinearizeBase<VectorLinearizePass> {
29
222
@@ -34,6 +227,14 @@ struct VectorLinearizePass final
34
227
mlir::RewritePatternSet patterns (context);
35
228
mlir::ConversionTarget target (*context);
36
229
230
+ target.addDynamicallyLegalOp <mlir::vector::ShuffleOp>([&](mlir::Operation
231
+ *op) {
232
+ return op->getResult (0 ).getType ().cast <mlir::VectorType>().getRank () == 1 ;
233
+ });
234
+
235
+ patterns.add <VectorExtractStridedSliceConversion, VectorShffleOpConversion,
236
+ VectorExtractOpConversion>(typeConverter, context);
237
+
37
238
typeConverter.addConversion ([](mlir::Type type) { return type; });
38
239
mlir::vector::populateVectorLinearizeTypeConversionsAndLegality (
39
240
typeConverter, patterns, target);
0 commit comments