@@ -311,14 +311,10 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
311311 // TODO(Keren): implement warp shuffle instead of using the general
312312 // approach that uses shared memory
313313 return transferWithinBlock (op, srcLayout, dstLayout, adaptor, rewriter);
314- } else if (llvm::is_contained (dims, kRegister ) ||
315- dstLayout.getInDimSize (kRegister ) !=
316- srcLayout.getInDimSize (kRegister )) {
314+ } else if (llvm::is_contained (dims, kRegister )) {
317315 // Case 4. Transfer between values in the same thread, in which case we
318316 // simply reorder the elements of adaptor.getSrc().
319- return transferWithinThread (
320- op, dstLayout.getFreeVariableMasks ()[kRegister ],
321- dstLayout.getInDimSize (kRegister ), *conversion, adaptor, rewriter);
317+ return transferWithinThread (op, *conversion, adaptor, rewriter);
322318 } else {
323319 // Cast 5. The two layouts are equivalent. We should probably remove
324320 // these in RemoveLayoutConversion.
@@ -328,8 +324,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
328324 }
329325
330326 LogicalResult
331- transferWithinThread (ConvertLayoutOp op, int32_t regMasks, int32_t numRegs ,
332- const LinearLayout &conversion, OpAdaptor adaptor,
327+ transferWithinThread (ConvertLayoutOp op, const LinearLayout &conversion ,
328+ OpAdaptor adaptor,
333329 ConversionPatternRewriter &rewriter) const {
334330 MLIRContext *ctx = op.getContext ();
335331 auto loc = op.getLoc ();
@@ -339,16 +335,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
339335 auto srcTy = op.getSrc ().getType ();
340336 auto dstTy = op.getType ();
341337 auto inVals = unpackLLElements (loc, adaptor.getSrc (), rewriter);
342- SmallVector<Value> outVals (numRegs);
343- for (int i = 0 ; i < numRegs; i++) {
344- // Remove free masks from the register index
345- // For example, if idx = 0b00111, and masks = 0b00100, then we get
346- // 0b00011. It means that register 7 (0b111) has the same value as
347- // register 3 (0b011).
348- auto idx = i & (~regMasks);
349- auto srcIdx = conversion.hasInDim (kRegister )
350- ? conversion.apply ({{kRegister , idx}}).begin ()->second
351- : idx;
338+ SmallVector<Value> outVals (conversion.getInDimSize (kRegister ));
339+ for (int i = 0 ; i < outVals.size (); i++) {
340+ auto srcIdx = conversion.apply ({{kRegister , i}}).begin ()->second ;
352341 outVals[i] = inVals[srcIdx];
353342 }
354343 Value result = packLLElements (loc, getTypeConverter (), outVals, rewriter,
0 commit comments