@@ -63,8 +63,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
6363 } else if (llvm::is_contained (dims, kWarp )) {
6464 // Case 2: Transfer between values in the same CTA, in which case we move
6565 // values through shared memory.
66- transferWithinBlockSwizzling (op, adaptor.getSrc (), rewriter);
67- return success ();
66+ return transferWithinBlock (op, srcLayout, dstLayout, adaptor, rewriter);
6867 } else if (llvm::is_contained (dims, kLane )) {
6968 // Case 3. Transfer between values in the same warp, in which case we try
7069 // to move values using warp shuffles, though if the pattern is
@@ -75,8 +74,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
7574 // TODO: Since data is only transferred within a warp over shared memory,
7675 // we should use `bar.warp.sync` instead of `barrier`, which will improve
7776 // latency when warps issue barriers on different cycles.
78- transferWithinBlockSwizzling (op, adaptor.getSrc (), rewriter);
79- return success ();
77+ return transferWithinBlock (op, srcLayout, dstLayout, adaptor, rewriter);
8078 } else if (llvm::is_contained (dims, kRegister )) {
8179 // Case 4. Transfer between values in the same thread, in which case we
8280 // simply reorder the elements of adaptor.getSrc().
@@ -171,7 +169,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
171169 // At this point we have a type that's at least 8-bit
172170 // and we don't have broadcasting in the registers
173171 auto bitwidth = llvmElemTy.getIntOrFloatBitWidth ();
174- auto smem = optimalSwizzlingLdSt (srcLayout, dstLayout, bitwidth);
172+ auto smem = optimalSwizzling (srcLayout, dstLayout, bitwidth);
175173
176174 // Extract reps from smem
177175 auto kReg = str_attr (" register" );
@@ -203,9 +201,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
203201
204202 assert (permutedInVals.size () == tileSize * nReps);
205203 SmallVector<Value> outVals;
204+ auto noPaddingOffset = [](Value v) { return v; };
206205 auto affineOffset = b.i32_val (0 );
207206 auto maskSpanAffineOffset = 0 ;
208- auto noPaddingOffset = [](Value v) { return v; };
209207 for (int i = 0 ; i < nReps; ++i) {
210208 if (i > 0 )
211209 b.barrier ();
@@ -229,8 +227,20 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
229227 return outVals;
230228 }
231229
232- void transferWithinBlockSwizzling (ConvertLayoutOp op, Value src,
233- ConversionPatternRewriter &rewriter) const {
230+ LogicalResult
231+ transferWithinBlockSwizzling (ConvertLayoutOp op, Value src,
232+ ConversionPatternRewriter &rewriter) const {
233+ // Fallback for now to standard lowering if it can use stmatrix
234+ auto scratchConfig =
235+ getScratchConfigForCvt (op.getSrc ().getType (), op.getType ());
236+ bool isStMatrix = targetInfo.canUseStMatrix (
237+ op.getSrc ().getType (), scratchConfig.repShape ,
238+ scratchConfig.paddedRepShape , scratchConfig.order ,
239+ /* swizzleByteSize=*/ 0 );
240+ if (isStMatrix) {
241+ return failure ();
242+ }
243+
234244 auto loc = op.getLoc ();
235245 auto *ctx = op.getContext ();
236246 auto srcTy = op.getSrc ().getType ();
@@ -258,6 +268,28 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
258268 Value result =
259269 packLLElements (loc, getTypeConverter (), outVals, rewriter, dstTy);
260270 rewriter.replaceOp (op, result);
271+ return success ();
272+ }
273+
// Lowers `op` for conversions where cvtNeedsSharedMemory holds for the
// source and result types (asserted below).
//
// Strategy:
//   1. On CUDA targets only, try transferWithinBlockSwizzling. If it
//      succeeds, this function returns success() without doing anything else.
//      The visible part of that helper returns failure() when the target can
//      use stmatrix.
//   2. Otherwise call transferWithinBlockPadding and replace `op` with the
//      value it returns.
//
// Parameters:
//   op        - the ConvertLayoutOp being lowered.
//   srcLayout - not referenced in this body.
//   dstLayout - not referenced in this body.
//   adaptor   - supplies the source operand via adaptor.getSrc().
//   rewriter  - used to replace `op` on the padding path.
//
// Returns success() on every path through this body.
274+ LogicalResult transferWithinBlock (ConvertLayoutOp op,
275+ const LinearLayout &srcLayout,
276+ const LinearLayout &dstLayout,
277+ OpAdaptor adaptor,
278+ ConversionPatternRewriter &rewriter) const {
279+ assert (cvtNeedsSharedMemory (op.getSrc ().getType (), op.getType ()));
280+
281+ // Try to use swizzling to implement the conversion.
282+ // HACK: gated on isCuda(); remove the gate once AMD tests pass for the swizzling path.
// `&&` short-circuits, so the swizzling helper is never invoked on non-CUDA
// targets.
283+ if (targetInfo.isCuda () && succeeded (transferWithinBlockSwizzling (
284+ op, adaptor.getSrc (), rewriter))) {
285+ return success ();
286+ }
287+
// Fallback: swizzling was skipped (non-CUDA) or reported failure.
288+ Value result = transferWithinBlockPadding (op, adaptor.getSrc (), targetInfo,
289+ getTypeConverter (), rewriter);
290+
291+ rewriter.replaceOp (op, result);
292+ return success ();
293+ }
262294
263295 // Use warp shuffles to implement a layout conversion where data only needs to
0 commit comments