17
17
18
18
namespace imex {
19
19
20
// Convenience alias for an SSA value statically known to carry a VectorType.
using VectorTypedValue = mlir::TypedValue<mlir::VectorType>;

// Signature shared by the vector-merging helpers (e.g. `concat`) so they can
// be passed to `mergeVectorsWrapper` as a std::function.
using funcTy = VectorTypedValue(mlir::Value, mlir::Value, mlir::Location,
                                mlir::PatternRewriter &);

// see its description in XeTileOpConversion.cpp
extern VectorTypedValue concat(mlir::Value v1, mlir::Value v2,
                               mlir::Location loc,
                               mlir::PatternRewriter &rewriter);

// see its description in XeTileOpConversion.cpp
extern mlir::Value mergeVectorsWrapper(mlir::ValueRange ins,
                                       std::function<funcTy> transFunc,
                                       mlir::Location loc,
                                       XeGPUOneToNPatterRewriter &rewriter);
35
+ static mlir::Value createBinOp (mlir::vector::CombiningKind kind,
36
+ mlir::Value lhs, mlir::Value rhs,
37
+ mlir::Type elemTy, mlir::Location &loc,
38
+ XeGPUOneToNPatterRewriter &rewriter) {
39
+
40
+ // ADD and MUL are defined for both Integers and Floats,
41
+ // need to generate code based on element data type.
42
+ if (kind == mlir::vector::CombiningKind::ADD) {
43
+ if (elemTy.isa <mlir::FloatType>()) {
44
+ return rewriter.create <mlir::arith::AddFOp>(loc, lhs, rhs);
45
+ }
46
+ if (elemTy.isa <mlir::IntegerType>()) {
47
+ return rewriter.create <mlir::arith::AddIOp>(loc, lhs, rhs);
48
+ }
49
+ }
50
+
51
+ if (kind == mlir::vector::CombiningKind::MUL) {
52
+ if (elemTy.isa <mlir::FloatType>()) {
53
+ return rewriter.create <mlir::arith::MulFOp>(loc, lhs, rhs);
54
+ }
55
+ if (elemTy.isa <mlir::IntegerType>()) {
56
+ return rewriter.create <mlir::arith::MulIOp>(loc, lhs, rhs);
57
+ }
58
+ }
59
+
60
+ switch (kind) {
61
+ // the following are for ints only
62
+ case mlir::vector::CombiningKind::MINUI:
63
+ return rewriter.create <mlir::arith::MinUIOp>(loc, lhs, rhs);
64
+ case mlir::vector::CombiningKind::MINSI:
65
+ return rewriter.create <mlir::arith::MinSIOp>(loc, lhs, rhs);
66
+ case mlir::vector::CombiningKind::MAXUI:
67
+ return rewriter.create <mlir::arith::MaxUIOp>(loc, lhs, rhs);
68
+ case mlir::vector::CombiningKind::MAXSI:
69
+ return rewriter.create <mlir::arith::MaxSIOp>(loc, lhs, rhs);
70
+ case mlir::vector::CombiningKind::AND:
71
+ return rewriter.create <mlir::arith::AndIOp>(loc, lhs, rhs);
72
+ case mlir::vector::CombiningKind::OR:
73
+ return rewriter.create <mlir::arith::OrIOp>(loc, lhs, rhs);
74
+ case mlir::vector::CombiningKind::XOR:
75
+ return rewriter.create <mlir::arith::XOrIOp>(loc, lhs, rhs);
76
+ // the following are for floats only
77
+ case mlir::vector::CombiningKind::MINNUMF:
78
+ return rewriter.create <mlir::arith::MinNumFOp>(loc, lhs, rhs);
79
+ case mlir::vector::CombiningKind::MAXNUMF:
80
+ return rewriter.create <mlir::arith::MaxNumFOp>(loc, lhs, rhs);
81
+ case mlir::vector::CombiningKind::MINIMUMF:
82
+ return rewriter.create <mlir::arith::MinimumFOp>(loc, lhs, rhs);
83
+ case mlir::vector::CombiningKind::MAXIMUMF:
84
+ return rewriter.create <mlir::arith::MaximumFOp>(loc, lhs, rhs);
85
+ default :
86
+ llvm_unreachable (" Unexpected CombiningKind." );
87
+ return lhs;
88
+ }
89
+ }
90
+
91
+ llvm::SmallVector<mlir::Value>
92
+ lowerOuterReduction (mlir::ValueRange sources, llvm::ArrayRef<int64_t > shape,
93
+ mlir::vector::CombiningKind kind, mlir::Location loc,
94
+ mlir::Type elemTy, XeGPUOneToNPatterRewriter &rewriter) {
95
+ assert (shape.size () == 4 && " shape should be 4D." );
96
+ llvm::SmallVector<mlir::Value> intermediates;
97
+ for (auto j = 0 ; j < shape[1 ]; j++) {
98
+ auto combiningVal = sources[j];
99
+ for (auto i = 1 ; i < shape[0 ]; i++) {
100
+ combiningVal = createBinOp (kind, combiningVal, sources[i * shape[1 ] + j],
101
+ elemTy, loc, rewriter);
102
+ }
103
+ {
104
+ // TODO: After blocking If the first dimension of the small block is not
105
+ // 1, the combiningVal is now in shape as, e.g., vector<4x16xf16> instead
106
+ // of vector<1x16xf16> then more reductions are needed in dim0, to make it
107
+ // as vector<1x16xf16>. Currently, this is not implemented, since we are
108
+ // now restricted blocking pass to set it as 1 now. It may cannot achieve
109
+ // peak performance in some cases.
110
+ assert (shape[2 ] == 1 &&
111
+ " more reductions is needed in dim0, but not supported." );
112
+ }
113
+ intermediates.push_back (combiningVal);
114
+ }
115
+ return intermediates;
116
+ }
117
+
118
+ // expected input is type of vector<ixjx1xnxf16>, where i and n is power of 2
119
+ // and the third dim is always 1, which should be set by the blocking pass.
120
+ llvm::SmallVector<mlir::Value> lowerInnerReductionWithIntraVectorShuffles (
121
+ mlir::ValueRange sources, llvm::ArrayRef<int64_t > shape,
122
+ mlir::vector::CombiningKind kind, mlir::Location loc, mlir::Type elemTy,
123
+ XeGPUOneToNPatterRewriter &rewriter) {
124
+
125
+ assert (shape.size () == 4 && " shape should be 4D." );
126
+
127
+ auto isPowerOfTwo = [](auto n) { return (n & (n - 1 )) == 0 ; };
128
+
129
+ // make sure the dim0 of the block is 1 in blocking pass
130
+ // different from outer reduction, this is strictly required
131
+ // for this method.
132
+ assert (shape[2 ] == 1 && " dim0 of the block has to be 1." );
133
+ assert (isPowerOfTwo (shape[0 ]) && isPowerOfTwo (shape[3 ]) &&
134
+ " sizes of dim1 and dim4 should be power of 2." );
135
+
136
+ auto genShuffleMasks = [&](int blkSize, int vecSize) {
137
+ llvm::SmallVector<int64_t > mask1;
138
+ llvm::SmallVector<int64_t > mask2;
139
+ auto s1 = 0 , s2 = blkSize;
140
+ for (auto i = 0 ; i < vecSize; i++) {
141
+ if (i && i % blkSize == 0 ) {
142
+ s1 += blkSize;
143
+ s2 += blkSize;
144
+ }
145
+
146
+ mask1.push_back (s1);
147
+ mask2.push_back (s2);
148
+ s1++;
149
+ s2++;
150
+ }
151
+ return std::make_pair (mask1, mask2);
152
+ };
153
+
154
+ // Stage 1: vector<ixjx1xnxf16> equals to a grid of ixj of vector<1xnxf16>
155
+ // after lowering to xegpu. This stage performs j-1 reduction operations on
156
+ // j dim of the grid, the result is a vector of vector<mxnxf16> with size i.
157
+ llvm::SmallVector<mlir::Value> intermediates (shape[0 ]);
158
+ for (auto i = 0 ; i < shape[0 ]; i++) {
159
+ auto combiningVal = sources[i * shape[1 ]];
160
+ for (auto j = 1 ; j < shape[1 ]; j++) {
161
+ combiningVal = createBinOp (kind, combiningVal, sources[i * shape[1 ] + j],
162
+ elemTy, loc, rewriter);
163
+ }
164
+ // cast the result of e.g., vector<1x16xf16> into vector<16xf16>
165
+ auto targetTy = mlir::VectorType::get ({shape[3 ]}, elemTy);
166
+ combiningVal =
167
+ rewriter.create <mlir::vector::ShapeCastOp>(loc, targetTy, combiningVal);
168
+ intermediates[i] = combiningVal;
169
+ }
170
+
171
+ // Stage 2: doing intra vector reduction with shuffle Ops.
172
+ // Each vector in the result of stage 1 can be viewed as a row
173
+ // each row has e.g., 32 elements:
174
+ // v1 = [a0 a1 a2 a3 a4 a5 a6 a7 a8 a9 ... a31]
175
+ // v2 = [b0 b1 b2 b3 b4 b5 b6 b7 b8 b9 ... b31]
176
+ // ...
177
+ // vn = [p0 p1 p2 p3 p4 p5 p6 p7 p8 p9 ... p31]
178
+ // it will repeately doing shuffle between two consecutive vectors
179
+ // v1 and v2, v3 and v4, ..., vn-1 and vn with a block size. Such
180
+ // that we can get two new vectors. The block size is typically
181
+ // starts with half of the vector size. For example, for v1 and v2,
182
+ // it is 16, and we can get:
183
+ // nv1 = [a0, .., a15, b0, .., b15]
184
+ // nv2 = [a16, .., a31, b16, .., b31]
185
+ // and we then performs nv1 + nv2 (if reduction op is add)
186
+ // such that the left half of the vector contains the partial reduction
187
+ // of v1, and the right half contains the partial reduction of v2.
188
+ // and the the number of vectors is reduced by half after one iteration.
189
+ // and we reduce the block size by half, and repeat the process until
190
+ // the block size is 1.
191
+ // The intermediate result of this stage is an array of vectors with
192
+ // type, e.g., vector<nxf16>, array size is `i/n`. And these vectors
193
+ // will be merged into a single vector with type vector<ixf16>.
194
+ auto blkSize = shape[3 ] / 2 ;
195
+ while (blkSize) {
196
+ auto workList = intermediates;
197
+ intermediates.clear ();
198
+ assert (workList.size () % 2 == 0 && " The size should be divisible by 2." );
199
+ auto masks = genShuffleMasks (blkSize, shape[3 ]);
200
+ for (size_t i = 0 ; i < workList.size (); i += 2 ) {
201
+ auto v1 = workList[i];
202
+ auto v2 = workList[i + 1 ];
203
+ auto shuffleOp1 =
204
+ rewriter.create <mlir::vector::ShuffleOp>(loc, v1, v2, masks.first );
205
+ auto shuffleOp2 =
206
+ rewriter.create <mlir::vector::ShuffleOp>(loc, v1, v2, masks.second );
207
+ auto reductionVal =
208
+ createBinOp (kind, shuffleOp1, shuffleOp2, elemTy, loc, rewriter);
209
+ intermediates.push_back (reductionVal);
210
+ }
211
+ blkSize /= 2 ;
212
+ }
213
+ return intermediates;
214
+ }
215
+
216
// Lowers a 4D vector.multi_dim_reduction (as produced by the blocking pass)
// into per-block binary ops and shuffles, dispatching on whether the
// reduction runs over the outer dims [0, 2] or the inner dims [1, 3].
// Any other dim combination is reported as an error.
class SgVectorMultiDimReductionOpPattern
    : public SgXeTileToXeGPUConversion<mlir::vector::MultiDimReductionOp> {
  using SgXeTileToXeGPUConversion<
      mlir::vector::MultiDimReductionOp>::SgXeTileToXeGPUConversion;

  mlir::LogicalResult
  matchAndRewrite(mlir::vector::MultiDimReductionOp op, OpAdaptor adaptor,
                  XeGPUOneToNPatterRewriter &rewriter) const override {
    auto srcTy = op.getSource().getType();
    auto elemTy = srcTy.getElementType();
    auto dims = op.getReductionDims();
    // its input should be a 4D vector, and has 2 reduction dims,
    // otherwise run the blocking pass first.
    if (dims.size() != 2 || srcTy.getRank() != 4)
      return mlir::failure();

    auto loc = op.getLoc();
    auto shape = srcTy.getShape();
    // the adaptor carries the already-converted (unrolled) small vectors.
    auto sources = adaptor.getSource();

    rewriter.setInsertionPoint(op);
    // doing reduction on outer dimension
    if (mlir::isConstantIntValue(dims[0], 0) &&
        mlir::isConstantIntValue(dims[1], 2)) {
      auto intermediates = lowerOuterReduction(sources, shape, op.getKind(),
                                               loc, elemTy, rewriter);
      {
        // TODO: need a better way to represent the result (align with
        // unpack/pack logic). currently we just shuffle them and cast it to the
        // type/shape in xetile program.
        auto reducedVal =
            mergeVectorsWrapper(intermediates, concat, loc, rewriter);
        // outer reduction collapses dims 0 and 2, leaving {shape[1], shape[3]}.
        auto targetTy = mlir::VectorType::get({shape[1], shape[3]}, elemTy);
        auto newOp = rewriter.create<mlir::vector::ShapeCastOp>(loc, targetTy,
                                                                reducedVal);
        rewriter.replaceOp(op, newOp);
      }
      return mlir::success();
    }

    // doing reduction on inner dimension
    if (mlir::isConstantIntValue(dims[0], 1) &&
        mlir::isConstantIntValue(dims[1], 3)) {
      auto intermediates = lowerInnerReductionWithIntraVectorShuffles(
          sources, shape, op.getKind(), loc, elemTy, rewriter);

      { // TODO: need a better way to represent the result (align with
        // unpack/pack logic).
        // currently we just shuffle them and cast it to the type/shape in
        // xetile program.
        auto reductionVal =
            mergeVectorsWrapper(intermediates, concat, loc, rewriter);
        // inner reduction collapses dims 1 and 3, leaving {shape[0], shape[2]}.
        auto targetTy = mlir::VectorType::get({shape[0], shape[2]}, elemTy);
        auto newOp = rewriter.create<mlir::vector::ShapeCastOp>(loc, targetTy,
                                                                reductionVal);
        rewriter.replaceOp(op, newOp);
      }
      return mlir::success();
    }

    // something is wrong
    return op.emitError("unsupported reduction operation.");
  }
};
280
+
20
281
class SgArithConstantOpPattern
21
282
: public SgXeTileToXeGPUConversion<mlir::arith::ConstantOp> {
22
283
using SgXeTileToXeGPUConversion<
@@ -26,8 +287,7 @@ class SgArithConstantOpPattern
26
287
matchAndRewrite (mlir::arith::ConstantOp op, OpAdaptor adaptor,
27
288
XeGPUOneToNPatterRewriter &rewriter) const override {
28
289
auto loc = op.getLoc ();
29
- auto value =
30
- llvm::dyn_cast_if_present<mlir::DenseElementsAttr>(op.getValue ());
290
+ auto value = llvm::dyn_cast<mlir::DenseElementsAttr>(op.getValue ());
31
291
32
292
// We only interesting 4D vectors
33
293
if (!value || value.getType ().getRank () != 4 )
@@ -38,8 +298,8 @@ class SgArithConstantOpPattern
38
298
value.value_end <mlir::Attribute>());
39
299
40
300
auto shape = value.getType ().getShape ();
41
- auto vecTy =
42
- mlir::VectorType::get ({shape[2 ], shape[3 ]}, value. getElementType () );
301
+ auto elemTy = value. getElementType ();
302
+ auto vecTy = mlir::VectorType::get ({shape[2 ], shape[3 ]}, elemTy );
43
303
44
304
// slice a block of (shape[2], shape[3]) from elems.
45
305
auto slice = [&](int i, int j) {
@@ -83,8 +343,8 @@ bool isLegalArithOp(mlir::Operation *op) {
83
343
// Registers the arith/vector lowering patterns of this file into `patterns`.
void populateArithOpConversionPatterns(imex::XeGPUTypeConverter &converter,
                                       mlir::RewritePatternSet &patterns,
                                       TileUsageAnalysis &analysis) {
  patterns.add<SgArithConstantOpPattern, SgVectorMultiDimReductionOpPattern>(
      patterns.getContext(), converter, analysis);
}
89
349
90
350
} // namespace imex
0 commit comments