@@ -47,6 +47,16 @@ SmallVector<int32_t> flatten(const LinearLayout &ll, StringAttr dim) {
47
47
return vec;
48
48
};
49
49
50
+ SmallVector<int32_t > removeZeros (ArrayRef<int32_t > vec) {
51
+ SmallVector<int32_t > result;
52
+ for (int32_t r : vec) {
53
+ if (r != 0 ) {
54
+ result.push_back (r);
55
+ }
56
+ }
57
+ return result;
58
+ }
59
+
50
60
// [1, 2, 4, 8] -> [[1], [2], [4], [8]]
51
61
std::vector<std::vector<int32_t >> unflatten (ArrayRef<int32_t > basis) {
52
62
std::vector<std::vector<int32_t >> unflattened;
@@ -279,6 +289,7 @@ LinearLayout optimalSwizzling(const LinearLayout &src, const LinearLayout &dst,
279
289
auto *ctx = src.getInDimNames ().begin ()->getContext ();
280
290
auto kReg = StringAttr::get (ctx, " register" );
281
291
auto kLane = StringAttr::get (ctx, " lane" );
292
+ auto kWarp = StringAttr::get (ctx, " warp" );
282
293
283
294
// We work on the flattened tensors as the tensor dimensions are not relevant
284
295
const LinearLayout srcFlat = src.flattenOuts ();
@@ -307,6 +318,65 @@ LinearLayout optimalSwizzling(const LinearLayout &src, const LinearLayout &dst,
307
318
if (vbasis.size () > maxVecBases) {
308
319
vbasis.resize (maxVecBases);
309
320
}
321
+ // We fill-up vbasis until it has 32 bits as best we can
322
+ auto vecFillsBank = (1 << vbasis.size ()) * bitwidth >= 32 ;
323
+ if (!vecFillsBank) {
324
+ auto warpSrc = removeZeros (flatten (srcFlat, kWarp ));
325
+ auto warpDst = removeZeros (flatten (dstFlat, kWarp ));
326
+ auto removeVec = [&vbasis](ArrayRef<int32_t > vec) {
327
+ SmallVector<int32_t > result;
328
+ for (int32_t r : vec) {
329
+ if (!llvm::is_contained (vbasis, r)) {
330
+ result.push_back (r);
331
+ }
332
+ }
333
+ return result;
334
+ };
335
+ auto regSrcWarp = intersectionBasis (removeVec (regSrc), warpDst, dim);
336
+ auto regDstWarp = intersectionBasis (removeVec (regDst), warpSrc, dim);
337
+ // Maximise vectorisation in the load or the store without creating
338
+ // conflicts
339
+ SmallVector<int32_t > largest;
340
+ if (regSrcWarp.size () == regDstWarp.size () && regSrcWarp.size () > 0 ) {
341
+ // We choose the one with the lowest basis in the hope that it will
342
+ // avoid PRMTs. The comparison of the mins will be strict as the sets
343
+ // removeVec(regSrc) and removeVec(regDst) don't intersect
344
+ if (*llvm::min_element (regSrcWarp) < *llvm::min_element (regDstWarp)) {
345
+ largest = regSrcWarp;
346
+ } else {
347
+ largest = regDstWarp;
348
+ }
349
+ } else {
350
+ largest = regSrcWarp.size () > regDstWarp.size () ? regSrcWarp : regDstWarp;
351
+ }
352
+ vbasis.append (largest.begin (), largest.end ());
353
+ if (vbasis.size () > maxVecBases) {
354
+ vbasis.resize (maxVecBases);
355
+ }
356
+ // We allow vbasis.size > Log2_32(32 / bitwidth) at this point, as it is in
357
+ // general good, but one should note
358
+ if (vbasis.size () < llvm::Log2_32 (32 / bitwidth)) {
359
+ // Pad the vectorisation to 32 bits with warp bases
360
+ auto warpSrcWarp = intersectionBasis (warpSrc, warpDst, dim);
361
+ vbasis.append (warpSrcWarp.begin (), warpSrcWarp.end ());
362
+ }
363
+
364
+ int i = 0 ;
365
+ while (vbasis.size () < llvm::Log2_32 (32 / bitwidth) &&
366
+ (i < warpSrc.size () || i < warpDst.size ())) {
367
+ // If we have not filled up a whole bank, we add more warp bases
368
+ // until we have 32 bits. They will at least avoid bank conflicts in one
369
+ // direction
370
+ if (i < warpSrc.size () && !llvm::is_contained (vbasis, warpSrc[i])) {
371
+ vbasis.push_back (warpSrc[i]);
372
+ }
373
+ if (vbasis.size () < llvm::Log2_32 (32 / bitwidth) && i < warpDst.size () &&
374
+ !llvm::is_contained (vbasis, warpDst[i])) {
375
+ vbasis.push_back (warpDst[i]);
376
+ }
377
+ ++i;
378
+ }
379
+ }
310
380
311
381
// Bits in a bank segment: 32 banks x 32 bits
312
382
constexpr int32_t bankBits = 32 * 32 ;
@@ -321,8 +391,11 @@ LinearLayout optimalSwizzling(const LinearLayout &src, const LinearLayout &dst,
321
391
auto bankDst = llvm::to_vector (llvm::concat<int32_t >(vbasis, laneDst));
322
392
323
393
// Whether we'll use b32.v1 / b32.v2 / b32.v4
324
- auto b32Vec =
325
- llvm::Log2_32 (std::max<int32_t >((1 << vbasis.size ()) * bitwidth / 32 , 1 ));
394
+ // FIXME: With !vecFillsBank we may use b32.v2 or b32.v4 for the load or
395
+ // store, but we pessimistically assume we don't.
396
+ auto b32Vec = !vecFillsBank ? 0
397
+ : llvm::Log2_32 (std::max<int32_t >(
398
+ (1 << vbasis.size ()) * bitwidth / 32 , 1 ));
326
399
// Drop the last vec bases of the banks
327
400
bankSrc.resize (bankSrc.size () - b32Vec);
328
401
bankDst.resize (bankDst.size () - b32Vec);
0 commit comments