@@ -100,7 +100,8 @@ SmallVector<int32_t> nullspaceBasis(ArrayRef<int32_t> vectors, int32_t dim) {
100100// without sacrificing vectorisation and split it into its own
101101// `reps` dimension
102102LinearLayout buildReps (MLIRContext *ctx, const LinearLayout &src,
103- const LinearLayout &dst, const LinearLayout &smem) {
103+ const LinearLayout &dst, const LinearLayout &smem,
104+ int32_t leaveReps) {
104105 auto kVec = StringAttr::get (ctx, " vector" );
105106 auto kBank = StringAttr::get (ctx, " bank" );
106107 auto kSegment = StringAttr::get (ctx, " segment" );
@@ -116,8 +117,16 @@ LinearLayout buildReps(MLIRContext *ctx, const LinearLayout &src,
116117 SetVector<int32_t > segment;
117118 SetVector<int32_t > reps;
118119 for (auto s : smemSegment) {
120+ // Do not move the first leaveReps bases from reps to segment
121+ // as we need them to vectorise the instructions (think .x2 and .x4 in
122+ // ldmatrix)
119123 if (srcRegs.contains (s) && dstRegs.contains (s)) {
120- reps.insert (s);
124+ if (leaveReps > 0 ) {
125+ leaveReps--;
126+ segment.insert (s);
127+ } else {
128+ reps.insert (s);
129+ }
121130 } else {
122131 segment.insert (s);
123132 }
@@ -376,11 +385,12 @@ std::optional<SmallVector<int32_t>> optimalSwizzlingTile(
376385 return vbasis;
377386}
378387
379- LinearLayout
380- optimalSwizzling (const LinearLayout &src, const LinearLayout &dst,
381- int32_t bitwidth, ArrayRef<int32_t > vbasis,
382- ArrayRef<int32_t > tileSrc, ArrayRef<int32_t > tileDst,
383- ArrayRef<std::pair<StringAttr, int32_t >> outDims) {
388+ LinearLayout optimalSwizzling (const LinearLayout &src, const LinearLayout &dst,
389+ int32_t bitwidth, ArrayRef<int32_t > vbasis,
390+ ArrayRef<int32_t > tileSrc,
391+ ArrayRef<int32_t > tileDst,
392+ ArrayRef<std::pair<StringAttr, int32_t >> outDims,
393+ int32_t leaveReps = 0 ) {
384394 // We work on the flattened tensors as the tensor dimensions are not relevant
385395 assert (src.getNumOutDims () == 1 && dst.getNumOutDims () == 1 &&
386396 " src and dst must have a single output dimension" );
@@ -439,7 +449,7 @@ optimalSwizzling(const LinearLayout &src, const LinearLayout &dst,
439449 {bankAttr, unflatten (bbasis)},
440450 {segAttr, unflatten (sbasis)}},
441451 src.getOutDims (), /* requireSurjective=*/ true );
442- basis1D = buildReps (ctx, src, dst, basis1D);
452+ basis1D = buildReps (ctx, src, dst, basis1D, leaveReps );
443453
444454 return basis1D.reshapeOuts (outDims);
445455}
@@ -649,7 +659,7 @@ optimalSwizzling(const LinearLayout &src, const LinearLayout &dst,
649659
650660 // Get the associated src/dst tiles for each instruction if they exist
651661 SmallVector<std::tuple<std::pair<int32_t , int32_t >, SmallVector<int32_t >,
652- SmallVector<int32_t >, SmallVector<int32_t >>>
662+ SmallVector<int32_t >, SmallVector<int32_t >, int32_t >>
653663 tiles;
654664 for (auto [instrs, vbasis] : instr) {
655665 auto maybeTileSrc =
@@ -659,22 +669,31 @@ optimalSwizzling(const LinearLayout &src, const LinearLayout &dst,
659669 if (!maybeTileSrc.has_value () || !maybeTileDst.has_value ()) {
660670 continue ;
661671 }
672+ // Regs bases missing to get full vectorisation
673+ auto regsMissing = [](const LocalMemOpTile &instr) {
674+ return instr.laneContig .size () + instr.laneAddr .size () - 3 ;
675+ };
676+ // We leave 2 reps for combinations of ldmatrix/stmatrix instructions
677+ // to be able to fully vectorise them
678+ int32_t leaveReps = std::min (regsMissing (srcTiles[instrs.first ]),
679+ regsMissing (dstTiles[instrs.second ]));
680+ assert ((leaveReps == 0 || leaveReps == 2 ) && " leaveReps must be 0 or 2" );
662681 tiles.push_back ({instrs, std::move (vbasis), std::move (*maybeTileSrc),
663- std::move (*maybeTileDst)});
682+ std::move (*maybeTileDst), leaveReps });
664683 }
665684
666685 if (tiles.empty ()) {
667686 // We lower to an ld / st, but can't use LDS128/STS128
668687 auto smem = optimalSwizzlingLdSt (src, dst, bitwidth);
669688 return {smem, {0 , 0 }};
670689 } else {
671- // We choose the pair of instructions that minimises the total bank
672- // conflicts
673690 SmallVector<std::tuple<int , LinearLayout, std::pair<int32_t , int32_t >>>
674691 smems;
675- for (auto [instrs, vbasis, tileSrc, tileDst] : tiles) {
692+ // We choose the pair of instructions that minimises the total bank
693+ // conflicts
694+ for (auto [instrs, vbasis, tileSrc, tileDst, leaveReps] : tiles) {
676695 auto smem = optimalSwizzling (srcFlat, dstFlat, bitwidth, vbasis, tileSrc,
677- tileDst, src.getOutDims ());
696+ tileDst, src.getOutDims (), leaveReps );
678697 auto [read, write] = bankConflicts (tileSrc, tileDst, smem);
679698 smems.push_back ({read + write, smem, {instrs.first , instrs.second }});
680699 }
0 commit comments