Skip to content

Commit e3994c7

Browse files
committed
Merge remote-tracking branch 'origin/main' into xegpu-squeeze
2 parents 21ee00b + 2d51c4e commit e3994c7

File tree

3 files changed

+60
-31
lines changed

3 files changed

+60
-31
lines changed

lib/gc/Transforms/GPU/GpuTilingAndFusion.cpp

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -101,17 +101,16 @@ struct GpuTilingAndFusion final
101101
auto sizePerThread = numIterations / numThreads * elementSize;
102102
auto totalSize = std::max(sizePerThread, cachePerThread);
103103
totalSize = std::max(totalSize / elementSize, 64L);
104-
int64_t minTileSize = 1;
104+
bool xeGpu = canLowerToXeGPU(op);
105105

106106
// If the operation could be lowered to XeGPU, make the tiles
107-
// multiple of the vector width and the minimum tile size 8.
108-
if (canLowerToXeGPU(op)) {
109-
minTileSize = 8;
107+
// multiple of the vector width.
108+
if (xeGpu) {
110109
totalSize = std::max(totalSize / vectorWidth, 1L) * vectorWidth;
111110
}
112111

113112
SmallVector<int64_t> tiles = sizes;
114-
adjustTiles(totalSize, tiles, minTileSize);
113+
adjustTiles(totalSize, tiles, xeGpu);
115114

116115
// If the tiles are equal to the sizes, split the largest tile
117116
// to avoid loops elimination by the canonicalizer pass.
@@ -356,16 +355,12 @@ struct GpuTilingAndFusion final
356355
return false;
357356
}
358357

359-
auto shape = type.getShape();
360-
if (isOutput) {
361-
if (shape.size() != 2 || shape[0] * shape[1] < 16) {
362-
return false;
363-
}
364-
} else if (shape.size() > 2) {
365-
return false;
358+
if (auto shape = type.getShape(); shape.size() >= 2) {
359+
return !isOutput ||
360+
std::accumulate(shape.begin() + 1, shape.end(), shape[0],
361+
std::multiplies<>()) >= 16;
366362
}
367-
368-
return true;
363+
return false;
369364
};
370365

371366
if (auto inits = op.getDpsInits();

lib/gc/Transforms/GPU/GpuUtils.h

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -139,9 +139,20 @@ template <typename T> T findFactor(T number, T closeTo) {
139139
return closeTo;
140140
}
141141

142+
namespace impl {
143+
// Controls the adjustment in case of more than 2 tiles.
144+
enum class AdjustTilesMode {
145+
// Sort the input and switch to the First mode.
146+
Sort,
147+
// Adjust the first tile and call adjustTiles() recursively for the rest.
148+
First,
149+
// To allow for squeezing, set 1's for all tiles except the last 2.
150+
XeGpu,
151+
};
152+
142153
template <typename T>
143154
static void adjustTwoTiles(T totalSize, T *aPtr, T *bPtr,
144-
T minSize = static_cast<T>(1)) {
155+
AdjustTilesMode mode) {
145156
T a = *aPtr;
146157
T b = *bPtr;
147158
assert(a >= b);
@@ -150,6 +161,7 @@ static void adjustTwoTiles(T totalSize, T *aPtr, T *bPtr,
150161
return;
151162
}
152163

164+
T minSize = static_cast<T>(mode == AdjustTilesMode::XeGpu ? 8 : 1);
153165
bool aPow2 = isPow2(a);
154166
bool bPow2 = isPow2(b);
155167
double ratio = static_cast<double>(a) / static_cast<double>(b);
@@ -208,14 +220,14 @@ static void adjustTwoTiles(T totalSize, T *aPtr, T *bPtr,
208220
// and, if possible, is a power of 2.
209221
template <typename T>
210222
static void adjustTiles(T totalSize, T *begin, T *end,
211-
T minSize = static_cast<T>(1), bool isSorted = false) {
212-
assert((minSize & (minSize - 1)) == 0 && "minSize must be a power of 2");
223+
AdjustTilesMode mode = AdjustTilesMode::Sort) {
213224
auto count = end - begin;
214225
if (count == 0) {
215226
return;
216227
}
217228

218229
if (count == 1) {
230+
T minSize = static_cast<T>(mode == AdjustTilesMode::XeGpu ? 8 : 1);
219231
if (T a = *begin; isPow2(a)) {
220232
*begin = std::min(std::max(ceilPow2(a), minSize), floorPow2(totalSize));
221233
} else {
@@ -225,15 +237,29 @@ static void adjustTiles(T totalSize, T *begin, T *end,
225237
}
226238

227239
if (count > 2) {
240+
if (mode == AdjustTilesMode::XeGpu) {
241+
for (unsigned i = 0; i < count - 2; ++i) {
242+
*(begin + i) = 1;
243+
}
244+
T *aPtr = end - 2;
245+
T *bPtr = end - 1;
246+
if (*aPtr < *bPtr) {
247+
std::swap(aPtr, bPtr);
248+
}
249+
adjustTwoTiles(totalSize, aPtr, bPtr, mode);
250+
return;
251+
}
252+
228253
SmallVector<T> sorted;
229254
SmallVector<unsigned> indices;
230255
T *head;
231256
T *tail;
232257

233-
if (isSorted) {
258+
if (mode == AdjustTilesMode::First) {
234259
head = begin;
235260
tail = end;
236261
} else {
262+
assert(mode == AdjustTilesMode::Sort);
237263
SmallVector<std::pair<T, unsigned>> pairs;
238264
pairs.reserve(count);
239265
for (unsigned i = 0; i < count; ++i) {
@@ -254,26 +280,29 @@ static void adjustTiles(T totalSize, T *begin, T *end,
254280
// first one and the product of the rest. The second one is the rest.
255281
T first[] = {*head, std::accumulate(head + 2, tail, *(head + 1),
256282
std::multiplies<>())};
257-
adjustTiles(totalSize, first, first + 2, minSize, true);
258-
adjustTiles(totalSize / *first, head + 1, tail, minSize, true);
283+
adjustTiles(totalSize, first, first + 2, AdjustTilesMode::First);
284+
adjustTiles(totalSize / *first, head + 1, tail, AdjustTilesMode::First);
259285
*head = *first;
260286

261-
if (!isSorted) {
287+
if (mode == AdjustTilesMode::Sort) {
262288
for (unsigned i = 0; i < count; ++i) {
263289
*(begin + indices[i]) = sorted[i];
264290
}
265291
}
266292
} else if (*begin >= *(end - 1)) {
267-
adjustTwoTiles(totalSize, begin, end - 1, minSize);
293+
adjustTwoTiles(totalSize, begin, end - 1, mode);
268294
} else {
269-
adjustTwoTiles(totalSize, end - 1, begin, minSize);
295+
adjustTwoTiles(totalSize, end - 1, begin, mode);
270296
}
271297
}
298+
} // namespace impl
272299

273300
template <typename T, unsigned N>
274301
static void adjustTiles(T totalSize, SmallVector<T, N> &tiles,
275-
T minSize = static_cast<T>(1)) {
276-
adjustTiles(totalSize, tiles.begin(), tiles.end(), minSize);
302+
bool xeGpuMode = false) {
303+
impl::adjustTiles(totalSize, tiles.begin(), tiles.end(),
304+
xeGpuMode ? impl::AdjustTilesMode::XeGpu
305+
: impl::AdjustTilesMode::Sort);
277306
}
278307

279308
// Check recursively if the specified operation has an operand that

test/mlir/unittests/Transforms/GPU/GpuUtilsTest.cpp

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
TEST(testAdjustTiles, GputUtilsTest) {
1616
bool print = false;
1717
auto testAdjust = [print](int64_t totalSize, SmallVector<int64_t> &tiles,
18-
const SmallVector<int64_t> &expected) {
18+
const SmallVector<int64_t> &expected,
19+
bool xeGpu = false) {
1920
if (print) {
2021
std::cout << totalSize << ": [";
2122
for (unsigned i = 0; i < tiles.size(); i++) {
@@ -24,7 +25,7 @@ TEST(testAdjustTiles, GputUtilsTest) {
2425
std::cout << "] -> [";
2526
}
2627

27-
gc::adjustTiles(totalSize, tiles);
28+
gc::adjustTiles(totalSize, tiles, xeGpu);
2829

2930
if (print) {
3031
for (unsigned i = 0; i < tiles.size(); i++) {
@@ -36,15 +37,15 @@ TEST(testAdjustTiles, GputUtilsTest) {
3637
EXPECT_EQ(tiles, expected);
3738
};
3839
auto test = [testAdjust](int64_t totalSize, SmallVector<int64_t> tiles,
39-
SmallVector<int64_t> expected) {
40+
SmallVector<int64_t> expected, bool xeGpu = false) {
4041
if (tiles.size() != 2 || tiles[0] == tiles[1]) {
41-
testAdjust(totalSize, tiles, expected);
42+
testAdjust(totalSize, tiles, expected, xeGpu);
4243
return;
4344
}
4445
SmallVector<int64_t> reversed(tiles.rbegin(), tiles.rend());
45-
testAdjust(totalSize, tiles, expected);
46+
testAdjust(totalSize, tiles, expected, xeGpu);
4647
std::reverse(expected.begin(), expected.end());
47-
testAdjust(totalSize, reversed, expected);
48+
testAdjust(totalSize, reversed, expected, xeGpu);
4849
};
4950

5051
test(8, {1, 1}, {1, 1});
@@ -91,4 +92,8 @@ TEST(testAdjustTiles, GputUtilsTest) {
9192
test(16384, {60, 128, 512}, {4, 32, 128});
9293
test(16384, {119, 256, 512}, {7, 32, 64});
9394
test(16384, {109, 256, 512}, {109, 8, 16});
95+
96+
test(16384, {8, 16, 32, 256, 512}, {1, 1, 1, 128, 128}, true);
97+
test(16384, {8, 16, 32, 1024, 256}, {1, 1, 1, 256, 64}, true);
98+
test(16384, {8, 16, 32, 16, 4096}, {1, 1, 1, 8, 2048}, true);
9499
}

0 commit comments

Comments
 (0)