Commit 0b3015c

Merge remote-tracking branch 'upstream/main'
2 parents: 9539371 + 5898c0b

File tree

5 files changed (+961 −165 lines)


lib/Conversion/XeTileToXeGPU/ArithOpConversion.cpp

Lines changed: 266 additions & 6 deletions
@@ -17,6 +17,267 @@
 
 namespace imex {
 
+using VectorTypedValue = mlir::TypedValue<mlir::VectorType>;
+using funcTy = VectorTypedValue(mlir::Value, mlir::Value, mlir::Location,
+                                mlir::PatternRewriter &);
+
+// see its description in XeTileOpConversion.cpp
+extern VectorTypedValue concat(mlir::Value v1, mlir::Value v2,
+                               mlir::Location loc,
+                               mlir::PatternRewriter &rewriter);
+
+// see its description in XeTileOpConversion.cpp
+extern mlir::Value mergeVectorsWrapper(mlir::ValueRange ins,
+                                       std::function<funcTy> transFunc,
+                                       mlir::Location loc,
+                                       XeGPUOneToNPatterRewriter &rewriter);
+
+static mlir::Value createBinOp(mlir::vector::CombiningKind kind,
+                               mlir::Value lhs, mlir::Value rhs,
+                               mlir::Type elemTy, mlir::Location &loc,
+                               XeGPUOneToNPatterRewriter &rewriter) {
+
+  // ADD and MUL are defined for both integers and floats, so the
+  // generated op depends on the element data type.
+  if (kind == mlir::vector::CombiningKind::ADD) {
+    if (elemTy.isa<mlir::FloatType>()) {
+      return rewriter.create<mlir::arith::AddFOp>(loc, lhs, rhs);
+    }
+    if (elemTy.isa<mlir::IntegerType>()) {
+      return rewriter.create<mlir::arith::AddIOp>(loc, lhs, rhs);
+    }
+  }
+
+  if (kind == mlir::vector::CombiningKind::MUL) {
+    if (elemTy.isa<mlir::FloatType>()) {
+      return rewriter.create<mlir::arith::MulFOp>(loc, lhs, rhs);
+    }
+    if (elemTy.isa<mlir::IntegerType>()) {
+      return rewriter.create<mlir::arith::MulIOp>(loc, lhs, rhs);
+    }
+  }
+
+  switch (kind) {
+  // the following are for ints only
+  case mlir::vector::CombiningKind::MINUI:
+    return rewriter.create<mlir::arith::MinUIOp>(loc, lhs, rhs);
+  case mlir::vector::CombiningKind::MINSI:
+    return rewriter.create<mlir::arith::MinSIOp>(loc, lhs, rhs);
+  case mlir::vector::CombiningKind::MAXUI:
+    return rewriter.create<mlir::arith::MaxUIOp>(loc, lhs, rhs);
+  case mlir::vector::CombiningKind::MAXSI:
+    return rewriter.create<mlir::arith::MaxSIOp>(loc, lhs, rhs);
+  case mlir::vector::CombiningKind::AND:
+    return rewriter.create<mlir::arith::AndIOp>(loc, lhs, rhs);
+  case mlir::vector::CombiningKind::OR:
+    return rewriter.create<mlir::arith::OrIOp>(loc, lhs, rhs);
+  case mlir::vector::CombiningKind::XOR:
+    return rewriter.create<mlir::arith::XOrIOp>(loc, lhs, rhs);
+  // the following are for floats only
+  case mlir::vector::CombiningKind::MINNUMF:
+    return rewriter.create<mlir::arith::MinNumFOp>(loc, lhs, rhs);
+  case mlir::vector::CombiningKind::MAXNUMF:
+    return rewriter.create<mlir::arith::MaxNumFOp>(loc, lhs, rhs);
+  case mlir::vector::CombiningKind::MINIMUMF:
+    return rewriter.create<mlir::arith::MinimumFOp>(loc, lhs, rhs);
+  case mlir::vector::CombiningKind::MAXIMUMF:
+    return rewriter.create<mlir::arith::MaximumFOp>(loc, lhs, rhs);
+  default:
+    llvm_unreachable("Unexpected CombiningKind.");
+    return lhs;
+  }
+}
+
+llvm::SmallVector<mlir::Value>
+lowerOuterReduction(mlir::ValueRange sources, llvm::ArrayRef<int64_t> shape,
+                    mlir::vector::CombiningKind kind, mlir::Location loc,
+                    mlir::Type elemTy, XeGPUOneToNPatterRewriter &rewriter) {
+  assert(shape.size() == 4 && "shape should be 4D.");
+  llvm::SmallVector<mlir::Value> intermediates;
+  for (auto j = 0; j < shape[1]; j++) {
+    auto combiningVal = sources[j];
+    for (auto i = 1; i < shape[0]; i++) {
+      combiningVal = createBinOp(kind, combiningVal, sources[i * shape[1] + j],
+                                 elemTy, loc, rewriter);
+    }
+    {
+      // TODO: after blocking, if the first dimension of the small block is
+      // not 1, combiningVal has a shape like vector<4x16xf16> instead of
+      // vector<1x16xf16>, and more reductions along dim0 are needed to turn
+      // it into vector<1x16xf16>. This is not implemented yet; the blocking
+      // pass is currently restricted to set this dimension to 1, which may
+      // not achieve peak performance in some cases.
+      assert(shape[2] == 1 &&
+             "more reductions are needed in dim0, but not supported.");
+    }
+    intermediates.push_back(combiningVal);
+  }
+  return intermediates;
+}
+
+// The expected input has a type like vector<ixjx1xnxf16>, where i and n are
+// powers of 2 and the third dim is always 1, which should be set by the
+// blocking pass.
+llvm::SmallVector<mlir::Value> lowerInnerReductionWithIntraVectorShuffles(
+    mlir::ValueRange sources, llvm::ArrayRef<int64_t> shape,
+    mlir::vector::CombiningKind kind, mlir::Location loc, mlir::Type elemTy,
+    XeGPUOneToNPatterRewriter &rewriter) {
+
+  assert(shape.size() == 4 && "shape should be 4D.");
+
+  auto isPowerOfTwo = [](auto n) { return (n & (n - 1)) == 0; };
+
+  // Make sure dim0 of the block is set to 1 by the blocking pass.
+  // Different from the outer reduction, this is strictly required
+  // for this method.
+  assert(shape[2] == 1 && "dim0 of the block has to be 1.");
+  assert(isPowerOfTwo(shape[0]) && isPowerOfTwo(shape[3]) &&
+         "sizes of dim1 and dim4 should be powers of 2.");
+
+  auto genShuffleMasks = [&](int blkSize, int vecSize) {
+    llvm::SmallVector<int64_t> mask1;
+    llvm::SmallVector<int64_t> mask2;
+    auto s1 = 0, s2 = blkSize;
+    for (auto i = 0; i < vecSize; i++) {
+      if (i && i % blkSize == 0) {
+        s1 += blkSize;
+        s2 += blkSize;
+      }
+
+      mask1.push_back(s1);
+      mask2.push_back(s2);
+      s1++;
+      s2++;
+    }
+    return std::make_pair(mask1, mask2);
+  };
+
+  // Stage 1: a vector<ixjx1xnxf16> is equivalent to an ixj grid of
+  // vector<1xnxf16> after lowering to xegpu. This stage performs j-1
+  // reduction operations along the j dim of the grid; the result is an
+  // array of i values of type vector<1xnxf16>.
+  llvm::SmallVector<mlir::Value> intermediates(shape[0]);
+  for (auto i = 0; i < shape[0]; i++) {
+    auto combiningVal = sources[i * shape[1]];
+    for (auto j = 1; j < shape[1]; j++) {
+      combiningVal = createBinOp(kind, combiningVal, sources[i * shape[1] + j],
+                                 elemTy, loc, rewriter);
+    }
+    // cast the result of, e.g., vector<1x16xf16> into vector<16xf16>
+    auto targetTy = mlir::VectorType::get({shape[3]}, elemTy);
+    combiningVal =
+        rewriter.create<mlir::vector::ShapeCastOp>(loc, targetTy, combiningVal);
+    intermediates[i] = combiningVal;
+  }
+
+  // Stage 2: doing intra-vector reduction with shuffle ops.
+  // Each vector in the result of stage 1 can be viewed as a row with,
+  // e.g., 32 elements:
+  //   v1 = [a0 a1 a2 a3 a4 a5 a6 a7 a8 a9 ... a31]
+  //   v2 = [b0 b1 b2 b3 b4 b5 b6 b7 b8 b9 ... b31]
+  //   ...
+  //   vn = [p0 p1 p2 p3 p4 p5 p6 p7 p8 p9 ... p31]
+  // It repeatedly shuffles two consecutive vectors (v1 and v2, v3 and v4,
+  // ..., vn-1 and vn) with a given block size, producing two new vectors.
+  // The block size starts at half of the vector size. For example, for
+  // v1 and v2 it is 16, and we get:
+  //   nv1 = [a0, .., a15, b0, .., b15]
+  //   nv2 = [a16, .., a31, b16, .., b31]
+  // We then perform nv1 + nv2 (if the reduction op is add), so that the
+  // left half of the result holds the partial reduction of v1 and the
+  // right half holds the partial reduction of v2. Thus the number of
+  // vectors is halved after each iteration; the block size is also halved
+  // and the process repeats until the block size is 1.
+  // The intermediate result of this stage is an array of vectors of type,
+  // e.g., vector<nxf16>, with array size `i/n`. These vectors will be
+  // merged into a single vector of type vector<ixf16>.
+  auto blkSize = shape[3] / 2;
+  while (blkSize) {
+    auto workList = intermediates;
+    intermediates.clear();
+    assert(workList.size() % 2 == 0 && "The size should be divisible by 2.");
+    auto masks = genShuffleMasks(blkSize, shape[3]);
+    for (size_t i = 0; i < workList.size(); i += 2) {
+      auto v1 = workList[i];
+      auto v2 = workList[i + 1];
+      auto shuffleOp1 =
+          rewriter.create<mlir::vector::ShuffleOp>(loc, v1, v2, masks.first);
+      auto shuffleOp2 =
+          rewriter.create<mlir::vector::ShuffleOp>(loc, v1, v2, masks.second);
+      auto reductionVal =
+          createBinOp(kind, shuffleOp1, shuffleOp2, elemTy, loc, rewriter);
+      intermediates.push_back(reductionVal);
+    }
+    blkSize /= 2;
+  }
+  return intermediates;
+}
+
+class SgVectorMultiDimReductionOpPattern
+    : public SgXeTileToXeGPUConversion<mlir::vector::MultiDimReductionOp> {
+  using SgXeTileToXeGPUConversion<
+      mlir::vector::MultiDimReductionOp>::SgXeTileToXeGPUConversion;
+
+  mlir::LogicalResult
+  matchAndRewrite(mlir::vector::MultiDimReductionOp op, OpAdaptor adaptor,
+                  XeGPUOneToNPatterRewriter &rewriter) const override {
+    auto srcTy = op.getSource().getType();
+    auto elemTy = srcTy.getElementType();
+    auto dims = op.getReductionDims();
+    // The input should be a 4D vector with 2 reduction dims;
+    // otherwise run the blocking pass first.
+    if (dims.size() != 2 || srcTy.getRank() != 4)
+      return mlir::failure();
+
+    auto loc = op.getLoc();
+    auto shape = srcTy.getShape();
+    auto sources = adaptor.getSource();
+
+    rewriter.setInsertionPoint(op);
+    // doing reduction on the outer dimension
+    if (mlir::isConstantIntValue(dims[0], 0) &&
+        mlir::isConstantIntValue(dims[1], 2)) {
+      auto intermediates = lowerOuterReduction(sources, shape, op.getKind(),
+                                               loc, elemTy, rewriter);
+      {
+        // TODO: need a better way to represent the result (align with
+        // unpack/pack logic). Currently we just shuffle them and cast to
+        // the type/shape in the xetile program.
+        auto reducedVal =
+            mergeVectorsWrapper(intermediates, concat, loc, rewriter);
+        auto targetTy = mlir::VectorType::get({shape[1], shape[3]}, elemTy);
+        auto newOp = rewriter.create<mlir::vector::ShapeCastOp>(loc, targetTy,
+                                                                reducedVal);
+        rewriter.replaceOp(op, newOp);
+      }
+      return mlir::success();
+    }
+
+    // doing reduction on the inner dimension
+    if (mlir::isConstantIntValue(dims[0], 1) &&
+        mlir::isConstantIntValue(dims[1], 3)) {
+      auto intermediates = lowerInnerReductionWithIntraVectorShuffles(
+          sources, shape, op.getKind(), loc, elemTy, rewriter);
+
+      { // TODO: need a better way to represent the result (align with
+        // unpack/pack logic). Currently we just shuffle them and cast to
+        // the type/shape in the xetile program.
+        auto reductionVal =
+            mergeVectorsWrapper(intermediates, concat, loc, rewriter);
+        auto targetTy = mlir::VectorType::get({shape[0], shape[2]}, elemTy);
+        auto newOp = rewriter.create<mlir::vector::ShapeCastOp>(loc, targetTy,
+                                                                reductionVal);
+        rewriter.replaceOp(op, newOp);
+      }
+      return mlir::success();
+    }
+
+    // something is wrong
+    return op.emitError("unsupported reduction operation.");
+  }
+};
+
 class SgArithConstantOpPattern
     : public SgXeTileToXeGPUConversion<mlir::arith::ConstantOp> {
   using SgXeTileToXeGPUConversion<
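For intuition about the hunk above, the following is a standalone scalar model (not part of the commit) of the genShuffleMasks index pattern and the Stage 2 pairwise shuffle-and-combine loop, using plain std::vector<float> in place of MLIR values. The 8x8 input data, the ADD combining kind, and the choice of rows == n (so each iteration pairs an even number of vectors, mirroring the assert in the diff) are made-up example assumptions.

// Scalar sketch of the shuffle-based inner reduction; not MLIR code.
#include <cassert>
#include <cstdio>
#include <utility>
#include <vector>

using Vec = std::vector<float>;
using Masks = std::pair<std::vector<int>, std::vector<int>>;

// Same index pattern as the genShuffleMasks lambda in the diff.
static Masks genShuffleMasks(int blkSize, int vecSize) {
  std::vector<int> mask1, mask2;
  int s1 = 0, s2 = blkSize;
  for (int i = 0; i < vecSize; i++) {
    if (i && i % blkSize == 0) {
      s1 += blkSize;
      s2 += blkSize;
    }
    mask1.push_back(s1);
    mask2.push_back(s2);
    s1++;
    s2++;
  }
  return {mask1, mask2};
}

// Models vector.shuffle: indices < n pick from v1, indices >= n pick from v2.
static Vec shuffle(const Vec &v1, const Vec &v2, const std::vector<int> &mask) {
  const int n = (int)v1.size();
  Vec out;
  for (int idx : mask)
    out.push_back(idx < n ? v1[idx] : v2[idx - n]);
  return out;
}

int main() {
  const int rows = 8, n = 8; // both powers of 2; rows == n keeps pairs even
  std::vector<Vec> work(rows, Vec(n));
  for (int r = 0; r < rows; r++)
    for (int c = 0; c < n; c++)
      work[r][c] = float(r * 100 + c); // distinct made-up data

  // Stage 2 of the lowering: shuffle + add pairs, halving the block size.
  for (int blk = n / 2; blk; blk /= 2) {
    auto masks = genShuffleMasks(blk, n);
    std::vector<Vec> next;
    for (size_t i = 0; i < work.size(); i += 2) {
      Vec a = shuffle(work[i], work[i + 1], masks.first);
      Vec b = shuffle(work[i], work[i + 1], masks.second);
      Vec sum(n);
      for (int c = 0; c < n; c++)
        sum[c] = a[c] + b[c]; // CombiningKind::ADD
      next.push_back(sum);
    }
    work = next;
  }

  // After the loop the row reductions are packed consecutively: element r of
  // the flattened result is the full reduction of original row r.
  for (int r = 0; r < rows; r++) {
    float expected = 0;
    for (int c = 0; c < n; c++)
      expected += float(r * 100 + c);
    float got = work[r / n][r % n];
    std::printf("row %d: got %.0f expected %.0f\n", r, got, expected);
    assert(got == expected);
  }
  return 0;
}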
@@ -26,8 +287,7 @@ class SgArithConstantOpPattern
   matchAndRewrite(mlir::arith::ConstantOp op, OpAdaptor adaptor,
                   XeGPUOneToNPatterRewriter &rewriter) const override {
     auto loc = op.getLoc();
-    auto value =
-        llvm::dyn_cast_if_present<mlir::DenseElementsAttr>(op.getValue());
+    auto value = llvm::dyn_cast<mlir::DenseElementsAttr>(op.getValue());
 
     // We are only interested in 4D vectors
     if (!value || value.getType().getRank() != 4)
@@ -38,8 +298,8 @@ class SgArithConstantOpPattern
                              value.value_end<mlir::Attribute>());
 
     auto shape = value.getType().getShape();
-    auto vecTy =
-        mlir::VectorType::get({shape[2], shape[3]}, value.getElementType());
+    auto elemTy = value.getElementType();
+    auto vecTy = mlir::VectorType::get({shape[2], shape[3]}, elemTy);
 
     // slice a block of (shape[2], shape[3]) from elems.
     auto slice = [&](int i, int j) {
@@ -83,8 +343,8 @@ bool isLegalArithOp(mlir::Operation *op) {
 void populateArithOpConversionPatterns(imex::XeGPUTypeConverter &converter,
                                        mlir::RewritePatternSet &patterns,
                                        TileUsageAnalysis &analysis) {
-  patterns.add<SgArithConstantOpPattern>(patterns.getContext(), converter,
-                                         analysis);
+  patterns.add<SgArithConstantOpPattern, SgVectorMultiDimReductionOpPattern>(
+      patterns.getContext(), converter, analysis);
 }
 
 } // namespace imex
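As a quick reference for the new SgVectorMultiDimReductionOpPattern registered above, this minimal sketch (not part of the commit) mirrors its shape dispatch on plain integers: reduction dims {0, 2} take the outer path and yield a {shape[1], shape[3]} result, while dims {1, 3} take the shuffle-based inner path and yield {shape[0], shape[2]}. The concrete shape below is a made-up example.

#include <array>
#include <cstdint>
#include <cstdio>

// Mirrors the result-shape selection in SgVectorMultiDimReductionOpPattern.
// Returns {-1, -1} for unsupported dim pairs (the pattern emits an error).
static std::array<int64_t, 2> reducedShape(const std::array<int64_t, 4> &shape,
                                           const std::array<int64_t, 2> &dims) {
  if (dims[0] == 0 && dims[1] == 2) // outer reduction
    return {shape[1], shape[3]};
  if (dims[0] == 1 && dims[1] == 3) // inner reduction
    return {shape[0], shape[2]};
  return {-1, -1};
}

int main() {
  std::array<int64_t, 4> shape = {8, 4, 1, 16}; // e.g. a blocked vector<8x4x1x16xf16>
  auto outer = reducedShape(shape, {0, 2});     // -> {4, 16}
  auto inner = reducedShape(shape, {1, 3});     // -> {8, 1}
  std::printf("outer: %lldx%lld, inner: %lldx%lld\n",
              (long long)outer[0], (long long)outer[1],
              (long long)inner[0], (long long)inner[1]);
  return 0;
}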
