Commit 72b2d9b
[AMD] Optimize to bypass ds_bpermute for direct-to-lds loads (#7064)

If the fastest dim along which we load elements from HBM is contiguous, we can apply the laneOffset directly to the source pointers/buffer offsets to obtain the swizzled addresses. This works only because we swap elements between lanes and swizzle only in the fastest dim. In general this performs better than using ds_bpermute.
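The idea, as a rough host-side sketch (hypothetical names; loops emulate what the lowering emits once per lane): when consecutive lanes load contiguous vectors, the cross-lane exchange that ds_bpermute performs can be replaced by a purely lane-local adjustment.

#include <cstdint>
#include <vector>

// Slow path: emulate the cross-lane exchange that ds_bpermute performs.
// Each lane reads the source offset held by lane (lane + laneOffset).
std::vector<int32_t> shufflePath(const std::vector<int32_t> &offsets,
                                 const std::vector<int32_t> &laneOffsets) {
  std::vector<int32_t> out(offsets.size());
  for (int lane = 0; lane < (int)offsets.size(); ++lane)
    out[lane] = offsets[lane + laneOffsets[lane]]; // index stays in the warp
  return out;
}

// Fast path: if lanes load contiguous vectors along the fastest dim, then
// offsets[lane + d] == offsets[lane] + d * vecSize, so each lane computes
// the exchanged value locally and ds_bpermute is bypassed.
std::vector<int32_t> adjustPath(const std::vector<int32_t> &offsets,
                                const std::vector<int32_t> &laneOffsets,
                                int32_t vecSize) {
  std::vector<int32_t> out(offsets.size());
  for (int lane = 0; lane < (int)offsets.size(); ++lane)
    out[lane] = offsets[lane] + laneOffsets[lane] * vecSize;
  return out;
}

Both paths agree exactly when offsets[lane + d] == offsets[lane] + d * vecSize, which is what the contiguity check added below is after; the fast path merely skips the round-trip through the LDS crossbar.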
1 parent 65a416a commit 72b2d9b

File tree

2 files changed: +59 -17 lines
test/Conversion/amd/buffer_load_to_local_to_llvm.mlir

Lines changed: 2 additions & 1 deletion

@@ -271,7 +271,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.sha
 
 // Each thread needs to load 8 elements and we load 8 (sizePerThread) per buffer load instruction
 // GFX950: rocdl.make.buffer.rsrc
-// GFX950: rocdl.ds_bpermute
+// Src ptrs are contiguous so we do expect to bypass the ds_bpermute (see lowering to LLVM)
+// GFX950-NOT: rocdl.ds_bpermute
 // GFX950: rocdl.raw.ptr.buffer.load.lds
 // GFX950-NOT: rocdl.raw.ptr.buffer.load.lds
 

third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 57 additions & 16 deletions
@@ -239,21 +239,19 @@ struct DirectToLdsLoadConversionBase : public LoadStoreConversionBase {
     }
   }
 
-  // Emits the computation to get the lane index which holds the source
+  // Emits the computation to get the lane id offset which holds the source
   // pointers/offsets we need to store to shared memory
-  Value emitSwizzledLaneIndex(RewriterBase &rewriter, TritonLLVMOpBuilder &b,
-                              Location loc, Value coalescedShmem,
-                              Value swizzledShmem, Value vecBytes) const {
+  Value emitSwizzledLaneOffset(RewriterBase &rewriter, TritonLLVMOpBuilder &b,
+                               Location loc, Value coalescedShmem,
+                               Value swizzledShmem, Value vecBytes) const {
    // Compute the laneOffset based on the difference in elements between
    // the two shmem addresses. laneOffset will be negative for half the
    // lanes because a smaller laneId might hold our global_ptr.
    auto coalescedAddr = b.ptrtoint(i64_ty, coalescedShmem);
    auto swizzledAddr = b.ptrtoint(i64_ty, swizzledShmem);
    auto diff = b.trunc(i32_ty, b.sub(swizzledAddr, coalescedAddr));
    Value laneOffset = b.sdiv(diff, vecBytes);
-    // laneId + laneOffset will always stay inside the warp [0,
-    // threadsPerWarp) because we only swizzle inside a warp
-    return b.add(getLaneId(rewriter, loc), laneOffset);
+    return laneOffset;
  }
 
  // Swizzle the mask (1bit) based on selectLane via ballot
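For concreteness, a worked instance of the laneOffset arithmetic above, under assumed values (16-byte per-lane vectors):

#include <cassert>
#include <cstdint>

int main() {
  // Assumed values: this lane's coalesced shmem slot vs. the slot the
  // swizzled layout assigns to it.
  int64_t coalescedAddr = 0x1000;
  int64_t swizzledAddr = 0x0FC0;
  int32_t vecBytes = 16;

  // Same arithmetic as emitSwizzledLaneOffset: the byte difference divided
  // by the vector size gives a signed distance in lanes.
  int32_t laneOffset =
      static_cast<int32_t>(swizzledAddr - coalescedAddr) / vecBytes;
  assert(laneOffset == -4); // a smaller laneId holds our source pointer
}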
@@ -266,6 +264,21 @@ struct DirectToLdsLoadConversionBase : public LoadStoreConversionBase {
    auto bitMask = b.lshr(warpMask, b.zext(rewriter.getI64Type(), selectLane));
    return b.trunc(i1_ty, bitMask);
  }
+
+  // For direct-to-lds the order of the shared encoding decides the order we
+  // load elements from global memory. This function returns true if the fastest
+  // dim for the sharedEnc is contiguous for the global ptrs/offsets
+  bool isFastedLoadDimContiguous(Value srcPtrOrOffset,
+                                 MemDescType sharedTy) const {
+    auto fastestDim = triton::gpu::getOrder(sharedTy)[0];
+    AxisInfo *axisInfo = axisAnalysisPass.getAxisInfo(srcPtrOrOffset);
+
+    // This can happen if axis analysis fails (e.g. lit tests).
+    if (axisInfo->getRank() <= fastestDim)
+      return false;
+
+    return axisInfo->getContiguity(fastestDim) > 1;
+  }
 };
 
 struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
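A self-contained sketch of what this predicate is getting at (hypothetical helper; the real code queries Triton's AxisInfo analysis rather than inspecting raw offsets):

#include <cstdint>
#include <vector>

// True if elements along the fastest shared-memory dim sit at consecutive
// addresses, mirroring the AxisInfo contiguity > 1 test above.
bool fastestDimLooksContiguous(const std::vector<int64_t> &elemOffsets) {
  for (size_t i = 1; i < elemOffsets.size(); ++i)
    if (elemOffsets[i] != elemOffsets[i - 1] + 1)
      return false;
  return elemOffsets.size() > 1;
}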
@@ -542,11 +555,26 @@ struct BufferLoadToLocalOpConversion
 
       if (hasSwizzling) {
         // Apply swizzling to the src offsets
-        Value swizzledLaneId =
-            emitSwizzledLaneIndex(rewriter, b, loc, coalescedShmemAddr[i],
-                                  swizzledShmemAddr[i], vecBytesVal);
-        offsetIn =
-            targetInfo.shuffleIdx(rewriter, loc, offsetIn, swizzledLaneId);
+        Value laneOffset =
+            emitSwizzledLaneOffset(rewriter, b, loc, coalescedShmemAddr[i],
+                                   swizzledShmemAddr[i], vecBytesVal);
+        // laneId + laneOffset will always stay inside the warp [0,
+        // threadsPerWarp) because we only swizzle inside a warp
+        Value swizzledLaneId = b.add(getLaneId(rewriter, loc), laneOffset);
+
+        if (isFastedLoadDimContiguous(offset, cast<MemDescType>(dstTy))) {
+          // Because rows are contiguous and we only swizzle inside rows by
+          // swapping elements between lanes we can add laneOffset * vecSize to
+          // the offset to apply the swizzling
+          offsetIn = b.add(
+              offsetIn, b.mul(laneOffset, b.i32_val(vecTy.getNumElements())));
+        } else {
+          // If rows are not contiguous in memory we need to shuffle the
+          // pointers to apply the swizzling to the src pointers
+          offsetIn =
+              targetInfo.shuffleIdx(rewriter, loc, offsetIn, swizzledLaneId);
+        }
+
        if (mask) {
          pred =
              shuffleMask(rewriter, b, loc, targetInfo, swizzledLaneId, pred);
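Note that even on the fast path the predicate still has to be exchanged across lanes. A host-side sketch of the ballot-style bit extraction that shuffleMask lowers to (assuming a warp-wide mask with one bit per lane, as the lshr/trunc above suggests):

#include <cstdint>

// warpMask packs one predicate bit per lane (what ballot produces on the
// GPU); extract the bit belonging to swizzledLaneId, i.e. lshr + trunc to i1.
bool shuffledMaskBit(uint64_t warpMask, int swizzledLaneId) {
  return (warpMask >> swizzledLaneId) & 1u;
}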
@@ -666,10 +694,23 @@ struct AsyncCopyGlobalToLocalOpConversion
 
       if (hasSwizzling) {
         // Apply swizzling to the src pointers
-        Value swizzledLaneId =
-            emitSwizzledLaneIndex(rewriter, b, loc, coalescedShmemAddr[i],
-                                  swizzledShmemAddr[i], vecBytesVal);
-        srcPtr = targetInfo.shuffleIdx(rewriter, loc, srcPtr, swizzledLaneId);
+        Value laneOffset =
+            emitSwizzledLaneOffset(rewriter, b, loc, coalescedShmemAddr[i],
+                                   swizzledShmemAddr[i], vecBytesVal);
+        // laneId + laneOffset will always stay inside the warp [0,
+        // threadsPerWarp) because we only swizzle inside a warp
+        Value swizzledLaneId = b.add(getLaneId(rewriter, loc), laneOffset);
+
+        if (isFastedLoadDimContiguous(op.getSrc(), cast<MemDescType>(dstTy))) {
+          // Because rows are contiguous and we only swizzle inside rows by
+          // swapping elements between lanes we can move the vecTy typed src
+          // pointer by laneOffset elements to apply the swizzling.
+          srcPtr = b.gep(srcPtr.getType(), vecTy, srcPtr, laneOffset);
+        } else {
+          // If rows are not contiguous in memory we need to shuffle the
+          // pointers to apply the swizzling to the src pointers
+          srcPtr = targetInfo.shuffleIdx(rewriter, loc, srcPtr, swizzledLaneId);
+        }
        if (!maskElements.empty()) {
          pred =
              shuffleMask(rewriter, b, loc, targetInfo, swizzledLaneId, pred);
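The pointer flavor of the fast path is the same trick in disguise; a minimal sketch (hypothetical Vec stand-in for the per-lane vecTy):

#include <cstdint>

using Vec = int16_t[8]; // stand-in for an 8 x f16 per-lane vector

// Analogue of b.gep(srcPtr.getType(), vecTy, srcPtr, laneOffset): C++
// pointer arithmetic on Vec* scales by sizeof(Vec), just as the GEP scales
// the (possibly negative) laneOffset by the vecTy element size.
const Vec *swizzledSrcPtr(const Vec *srcPtr, int32_t laneOffset) {
  return srcPtr + laneOffset;
}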
