@@ -288,71 +288,60 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
288288 return rewriter.notifyMatchFailure (
289289 op, " NYI. srcTy and/or dstTy don't implement LLs yet" );
290290 }
291- LinearLayout srcLayout =
292- *toLinearLayout (srcTy.getShape (), srcTy.getEncoding ());
293- LinearLayout dstLayout =
294- *toLinearLayout (dstTy.getShape (), dstTy.getEncoding ());
295-
296- StringAttr kBlock = str_attr (" block" );
297- StringAttr kWarp = str_attr (" warp" );
298- StringAttr kLane = str_attr (" lane" );
299- StringAttr kRegister = str_attr (" register" );
300291
301292 assert (to_vector (conversion->getInDimNames ()) ==
302293 to_vector (conversion->getOutDimNames ()));
303294 auto dims = conversion->getInDimNames ();
304- if (llvm::is_contained (dims, kBlock )) {
295+ if (llvm::is_contained (dims, str_attr ( " block " ) )) {
305296 // Case 1: Transfer between values in different CTAs.
306297 // This requires moving values through distributed shared memory.
307298 return rewriter.notifyMatchFailure (
308299 op, " NYI: Transfer between different CTAs" );
309- } else if (llvm::is_contained (dims, kWarp )) {
300+ } else if (llvm::is_contained (dims, str_attr ( " warp " ) )) {
310301 // Case 2: Transfer between values in the same CTA, in which case we move
311302 // values through shared memory.
303+ LinearLayout srcLayout =
304+ *toLinearLayout (srcTy.getShape (), srcTy.getEncoding ());
305+ LinearLayout dstLayout =
306+ *toLinearLayout (dstTy.getShape (), dstTy.getEncoding ());
312307 return transferWithinBlock (op, srcLayout, dstLayout, adaptor, rewriter);
313- } else if (llvm::is_contained (dims, kLane )) {
308+ } else if (llvm::is_contained (dims, str_attr ( " lane " ) )) {
314309 // Case 3. Transfer between values in the same warp, in which case we try
315310 // to move values using warp shuffles, though if the pattern is
316311 // complicated enough we may fall back to using shared memory
317312 // TODO(Keren): implement warp shuffle instead of using the general
318313 // approach that uses shared memory
314+ LinearLayout srcLayout =
315+ *toLinearLayout (srcTy.getShape (), srcTy.getEncoding ());
316+ LinearLayout dstLayout =
317+ *toLinearLayout (dstTy.getShape (), dstTy.getEncoding ());
319318 return transferWithinBlock (op, srcLayout, dstLayout, adaptor, rewriter);
320- } else if (llvm::is_contained (dims, kRegister ) ||
321- dstLayout.getInDimSize (kRegister ) !=
322- srcLayout.getInDimSize (kRegister )) {
319+ } else if (llvm::is_contained (dims, str_attr (" register" ))) {
323320 // Case 4. Transfer between values in the same thread, in which case we
324321 // simply reorder the elements of adaptor.getSrc().
325- return transferWithinThread (
326- op, dstLayout.getFreeVariableMasks ()[kRegister ],
327- dstLayout.getInDimSize (kRegister ), *conversion, adaptor, rewriter);
322+ return transferWithinThread (op, *conversion, adaptor, rewriter);
328323 } else {
329- // Cast 5. The two layouts are equivalent. We should probably remove
330- // these in RemoveLayoutConversion.
324+ // The two layouts are equivalent. We should probably remove these in
325+ // RemoveLayoutConversion.
331326 rewriter.replaceOp (op, adaptor.getSrc ());
332327 return success ();
333328 }
334329 }
335330
336331 LogicalResult
337- transferWithinThread (ConvertLayoutOp op, int32_t regMasks, int32_t numRegs ,
338- const LinearLayout &conversion, OpAdaptor adaptor,
332+ transferWithinThread (ConvertLayoutOp op, const LinearLayout &conversion ,
333+ OpAdaptor adaptor,
339334 ConversionPatternRewriter &rewriter) const {
340335 MLIRContext *ctx = op.getContext ();
341336 auto loc = op.getLoc ();
342337 StringAttr kRegister = str_attr (" register" );
343338 assert (!cvtNeedsSharedMemory (op.getSrc ().getType (), op.getType ()));
344339
345340 auto inVals = unpackLLElements (loc, adaptor.getSrc (), rewriter);
346- SmallVector<Value> outVals (numRegs);
347- for (int i = 0 ; i < outVals.size (); i++) {
348- // Remove free masks from the register index
349- // For example, if idx = 0b00111, and masks = 0b00100, then we get
350- // 0b00011. It means that register 7 (0b111) has the same value as
351- // register 3 (0b011).
352- auto idx = i & (~regMasks);
353- auto srcIdx = conversion.hasInDim (kRegister )
354- ? conversion.apply ({{kRegister , idx}}).begin ()->second
355- : idx;
341+ SmallVector<Value> outVals;
342+ outVals.resize (conversion.getInDimSize (kRegister ));
343+ for (int i = 0 ; i < conversion.getInDimSize (kRegister ); i++) {
344+ auto srcIdx = conversion.apply ({{kRegister , i}}).begin ()->second ;
356345 outVals[i] = inVals[srcIdx];
357346 }
358347 Value result = packLLElements (loc, getTypeConverter (), outVals, rewriter,
0 commit comments