@@ -288,60 +288,71 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
288288 return rewriter.notifyMatchFailure (
289289 op, " NYI. srcTy and/or dstTy don't implement LLs yet" );
290290 }
291+ LinearLayout srcLayout =
292+ *toLinearLayout (srcTy.getShape (), srcTy.getEncoding ());
293+ LinearLayout dstLayout =
294+ *toLinearLayout (dstTy.getShape (), dstTy.getEncoding ());
295+
296+ StringAttr kBlock = str_attr (" block" );
297+ StringAttr kWarp = str_attr (" warp" );
298+ StringAttr kLane = str_attr (" lane" );
299+ StringAttr kRegister = str_attr (" register" );
291300
292301 assert (to_vector (conversion->getInDimNames ()) ==
293302 to_vector (conversion->getOutDimNames ()));
294303 auto dims = conversion->getInDimNames ();
295- if (llvm::is_contained (dims, str_attr ( " block " ) )) {
304+ if (llvm::is_contained (dims, kBlock )) {
296305 // Case 1: Transfer between values in different CTAs.
297306 // This requires moving values through distributed shared memory.
298307 return rewriter.notifyMatchFailure (
299308 op, " NYI: Transfer between different CTAs" );
300- } else if (llvm::is_contained (dims, str_attr ( " warp " ) )) {
309+ } else if (llvm::is_contained (dims, kWarp )) {
301310 // Case 2: Transfer between values in the same CTA, in which case we move
302311 // values through shared memory.
303- LinearLayout srcLayout =
304- *toLinearLayout (srcTy.getShape (), srcTy.getEncoding ());
305- LinearLayout dstLayout =
306- *toLinearLayout (dstTy.getShape (), dstTy.getEncoding ());
307312 return transferWithinBlock (op, srcLayout, dstLayout, adaptor, rewriter);
308- } else if (llvm::is_contained (dims, str_attr ( " lane " ) )) {
313+ } else if (llvm::is_contained (dims, kLane )) {
309314 // Case 3. Transfer between values in the same warp, in which case we try
310315 // to move values using warp shuffles, though if the pattern is
311316 // complicated enough we may fall back to using shared memory
312317 // TODO(Keren): implement warp shuffle instead of using the general
313318 // approach that uses shared memory
314- LinearLayout srcLayout =
315- *toLinearLayout (srcTy.getShape (), srcTy.getEncoding ());
316- LinearLayout dstLayout =
317- *toLinearLayout (dstTy.getShape (), dstTy.getEncoding ());
318319 return transferWithinBlock (op, srcLayout, dstLayout, adaptor, rewriter);
319- } else if (llvm::is_contained (dims, str_attr (" register" ))) {
320+ } else if (llvm::is_contained (dims, kRegister ) ||
321+ dstLayout.getInDimSize (kRegister ) !=
322+ srcLayout.getInDimSize (kRegister )) {
320323 // Case 4. Transfer between values in the same thread, in which case we
321324 // simply reorder the elements of adaptor.getSrc().
322- return transferWithinThread (op, *conversion, adaptor, rewriter);
325+ return transferWithinThread (
326+ op, dstLayout.getFreeVariableMasks ()[kRegister ],
327+ dstLayout.getInDimSize (kRegister ), *conversion, adaptor, rewriter);
323328 } else {
324- // The two layouts are equivalent. We should probably remove these in
325- // RemoveLayoutConversion.
329+ // Cast 5. The two layouts are equivalent. We should probably remove
330+ // these in RemoveLayoutConversion.
326331 rewriter.replaceOp (op, adaptor.getSrc ());
327332 return success ();
328333 }
329334 }
330335
331336 LogicalResult
332- transferWithinThread (ConvertLayoutOp op, const LinearLayout &conversion ,
333- OpAdaptor adaptor,
337+ transferWithinThread (ConvertLayoutOp op, int32_t regMasks, int32_t numRegs ,
338+ const LinearLayout &conversion, OpAdaptor adaptor,
334339 ConversionPatternRewriter &rewriter) const {
335340 MLIRContext *ctx = op.getContext ();
336341 auto loc = op.getLoc ();
337342 StringAttr kRegister = str_attr (" register" );
338343 assert (!cvtNeedsSharedMemory (op.getSrc ().getType (), op.getType ()));
339344
340345 auto inVals = unpackLLElements (loc, adaptor.getSrc (), rewriter);
341- SmallVector<Value> outVals;
342- outVals.resize (conversion.getInDimSize (kRegister ));
343- for (int i = 0 ; i < conversion.getInDimSize (kRegister ); i++) {
344- auto srcIdx = conversion.apply ({{kRegister , i}}).begin ()->second ;
346+ SmallVector<Value> outVals (numRegs);
347+ for (int i = 0 ; i < outVals.size (); i++) {
348+ // Remove free masks from the register index
349+ // For example, if idx = 0b00111, and masks = 0b00100, then we get
350+ // 0b00011. It means that register 7 (0b111) has the same value as
351+ // register 3 (0b011).
352+ auto idx = i & (~regMasks);
353+ auto srcIdx = conversion.hasInDim (kRegister )
354+ ? conversion.apply ({{kRegister , idx}}).begin ()->second
355+ : idx;
345356 outVals[i] = inVals[srcIdx];
346357 }
347358 Value result = packLLElements (loc, getTypeConverter (), outVals, rewriter,
0 commit comments