@@ -139,9 +139,20 @@ template <typename T> T findFactor(T number, T closeTo) {
139139 return closeTo;
140140}
141141
142+ namespace impl {
143+ // Controls the adjustment in case of more than 2 tiles.
144+ enum class AdjustTilesMode {
145+ // Sort the input and switch to the First mode.
146+ Sort,
147+ // Adjust the first and call adjustTiles() recursively for the rest.
148+ First,
149+ // To allow for squeezing, set 1's for all tiles except the last 2.
150+ XeGpu,
151+ };
152+
142153template <typename T>
143154static void adjustTwoTiles (T totalSize, T *aPtr, T *bPtr,
144- T minSize = static_cast <T>( 1 ) ) {
155+ AdjustTilesMode mode ) {
145156 T a = *aPtr;
146157 T b = *bPtr;
147158 assert (a >= b);
@@ -150,6 +161,7 @@ static void adjustTwoTiles(T totalSize, T *aPtr, T *bPtr,
150161 return ;
151162 }
152163
164+ T minSize = static_cast <T>(mode == AdjustTilesMode::XeGpu ? 8 : 1 );
153165 bool aPow2 = isPow2 (a);
154166 bool bPow2 = isPow2 (b);
155167 double ratio = static_cast <double >(a) / static_cast <double >(b);
@@ -208,14 +220,14 @@ static void adjustTwoTiles(T totalSize, T *aPtr, T *bPtr,
208220// and, if possible, is a power of 2.
209221template <typename T>
210222static void adjustTiles (T totalSize, T *begin, T *end,
211- T minSize = static_cast <T>(1 ), bool isSorted = false) {
212- assert ((minSize & (minSize - 1 )) == 0 && " minSize must be a power of 2" );
223+ AdjustTilesMode mode = AdjustTilesMode::Sort) {
213224 auto count = end - begin;
214225 if (count == 0 ) {
215226 return ;
216227 }
217228
218229 if (count == 1 ) {
230+ T minSize = static_cast <T>(mode == AdjustTilesMode::XeGpu ? 8 : 1 );
219231 if (T a = *begin; isPow2 (a)) {
220232 *begin = std::min (std::max (ceilPow2 (a), minSize), floorPow2 (totalSize));
221233 } else {
@@ -225,15 +237,29 @@ static void adjustTiles(T totalSize, T *begin, T *end,
225237 }
226238
227239 if (count > 2 ) {
240+ if (mode == AdjustTilesMode::XeGpu) {
241+ for (unsigned i = 0 ; i < count - 2 ; ++i) {
242+ *(begin + i) = 1 ;
243+ }
244+ T *aPtr = end - 2 ;
245+ T *bPtr = end - 1 ;
246+ if (*aPtr < *bPtr) {
247+ std::swap (aPtr, bPtr);
248+ }
249+ adjustTwoTiles (totalSize, aPtr, bPtr, mode);
250+ return ;
251+ }
252+
228253 SmallVector<T> sorted;
229254 SmallVector<unsigned > indices;
230255 T *head;
231256 T *tail;
232257
233- if (isSorted ) {
258+ if (mode == AdjustTilesMode::First ) {
234259 head = begin;
235260 tail = end;
236261 } else {
262+ assert (mode == AdjustTilesMode::Sort);
237263 SmallVector<std::pair<T, unsigned >> pairs;
238264 pairs.reserve (count);
239265 for (unsigned i = 0 ; i < count; ++i) {
@@ -254,26 +280,29 @@ static void adjustTiles(T totalSize, T *begin, T *end,
254280 // first one and the product of the rest. The second one is the rest.
255281 T first[] = {*head, std::accumulate (head + 2 , tail, *(head + 1 ),
256282 std::multiplies<>())};
257- adjustTiles (totalSize, first, first + 2 , minSize, true );
258- adjustTiles (totalSize / *first, head + 1 , tail, minSize, true );
283+ adjustTiles (totalSize, first, first + 2 , AdjustTilesMode::First );
284+ adjustTiles (totalSize / *first, head + 1 , tail, AdjustTilesMode::First );
259285 *head = *first;
260286
261- if (!isSorted ) {
287+ if (mode == AdjustTilesMode::Sort ) {
262288 for (unsigned i = 0 ; i < count; ++i) {
263289 *(begin + indices[i]) = sorted[i];
264290 }
265291 }
266292 } else if (*begin >= *(end - 1 )) {
267- adjustTwoTiles (totalSize, begin, end - 1 , minSize );
293+ adjustTwoTiles (totalSize, begin, end - 1 , mode );
268294 } else {
269- adjustTwoTiles (totalSize, end - 1 , begin, minSize );
295+ adjustTwoTiles (totalSize, end - 1 , begin, mode );
270296 }
271297}
298+ } // namespace impl
272299
273300template <typename T, unsigned N>
274301static void adjustTiles (T totalSize, SmallVector<T, N> &tiles,
275- T minSize = static_cast <T>(1 )) {
276- adjustTiles (totalSize, tiles.begin (), tiles.end (), minSize);
302+ bool xeGpuMode = false ) {
303+ impl::adjustTiles (totalSize, tiles.begin (), tiles.end (),
304+ xeGpuMode ? impl::AdjustTilesMode::XeGpu
305+ : impl::AdjustTilesMode::Sort);
277306}
278307
279308// Check recursively if the specified operation has an operand that
0 commit comments