Skip to content

Commit 1523c6c

Browse files
authored
[AMD] Remove bypass permute optimization for AsyncCopy (#7704)
We can only bypass ds_bpermute to apply the swizzling if lanes loading the same row read a contiguous chunk of memory from HBM, which we cannot infer when lowering to LLVM. The current check only verifies that the elements of each lane are contiguous, which is not strict enough.
1 parent 250b6eb commit 1523c6c

File tree

2 files changed

+6
-40
lines changed

2 files changed

+6
-40
lines changed

test/Conversion/amd/buffer_load_to_local_to_llvm.mlir

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -272,8 +272,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.sha
272272

273273
// Each thread needs to load 8 elements and we load 8 (sizePerThread) per buffer load instruction
274274
// GFX950: rocdl.make.buffer.rsrc
275-
// Src ptrs are contiguous so we do expect to bypass the ds_bpermute (see lowering to LLVM)
276-
// GFX950-NOT: rocdl.ds_bpermute
277275
// GFX950: rocdl.raw.ptr.buffer.load.lds
278276
// GFX950-NOT: rocdl.raw.ptr.buffer.load.lds
279277

third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 6 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -446,21 +446,6 @@ struct DirectToLdsLoadConversionBase : public LoadStoreConversionBase {
446446
auto bitMask = b.lshr(warpMask, b.zext(rewriter.getI64Type(), selectLane));
447447
return b.trunc(i1_ty, bitMask);
448448
}
449-
450-
// For direct-to-lds the order of the shared encoding decides the order we
451-
// load elements from global memory. This function returns true if the fastest
452-
// dim for the sharedEnc is contiguous for the global ptrs/offsets
453-
bool isFastedLoadDimContiguous(Value srcPtrOrOffset,
454-
MemDescType sharedTy) const {
455-
auto fastestDim = triton::gpu::getOrder(sharedTy)[0];
456-
AxisInfo *axisInfo = axisAnalysisPass.getAxisInfo(srcPtrOrOffset);
457-
458-
// This can happen if axis analysis fails (e.g. lit tests).
459-
if (axisInfo->getRank() <= fastestDim)
460-
return false;
461-
462-
return axisInfo->getContiguity(fastestDim) > 1;
463-
}
464449
};
465450

466451
struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
@@ -732,18 +717,9 @@ struct BufferLoadToLocalOpConversion
732717
// laneId + swizzleOffset will always stay inside the warp [0,
733718
// threadsPerWarp) because we only swizzle inside a warp
734719
Value swizzledLaneId = b.add(laneId, swizzleLaneOffset);
735-
if (isFastedLoadDimContiguous(offset, cast<MemDescType>(dstTy))) {
736-
// Because rows are contiguous and we only swizzle inside rows by
737-
// swapping elements between lanes we can add laneOffset * vecSize to
738-
// the offset to apply the swizzling
739-
offsetIn = b.add(offsetIn, b.mul(swizzleLaneOffset,
740-
b.i32_val(vecTy.getNumElements())));
741-
} else {
742-
// If rows are not contiguous in memory we need to shuffle the
743-
// pointers to apply the swizzling to the src pointers
744-
offsetIn =
745-
targetInfo.shuffleIdx(rewriter, loc, offsetIn, swizzledLaneId);
746-
}
720+
// Shuffle based on swizzledLaneId to apply the swizzling
721+
offsetIn =
722+
targetInfo.shuffleIdx(rewriter, loc, offsetIn, swizzledLaneId);
747723

748724
if (mask) {
749725
pred =
@@ -877,17 +853,9 @@ struct AsyncCopyGlobalToLocalOpConversion
877853
Value swizzledLaneId =
878854
b.add(getLaneId(rewriter, loc), swizzledLaneOffsets[i]);
879855

880-
if (isFastedLoadDimContiguous(op.getSrc(), cast<MemDescType>(dstTy))) {
881-
// Because rows are contiguous and we only swizzle inside rows by
882-
// swapping elements between lanes we can move the vecTy typed src
883-
// pointer by laneOffset elements to apply the swizzling.
884-
srcPtr =
885-
b.gep(srcPtr.getType(), vecTy, srcPtr, swizzledLaneOffsets[i]);
886-
} else {
887-
// If rows are not contiguous in memory we need to shuffle the
888-
// pointers to apply the swizzling to the src pointers
889-
srcPtr = targetInfo.shuffleIdx(rewriter, loc, srcPtr, swizzledLaneId);
890-
}
856+
// Shuffle based on swizzledLaneId to apply the swizzling
857+
srcPtr = targetInfo.shuffleIdx(rewriter, loc, srcPtr, swizzledLaneId);
858+
891859
if (!maskElements.empty()) {
892860
pred =
893861
shuffleMask(rewriter, b, loc, targetInfo, swizzledLaneId, pred);

0 commit comments

Comments
 (0)