Skip to content

Commit 27fae2b

Browse files
committed
[AMD] Remove bypass permute optimization for AsyncCopy (triton-lang#7704)
We can only bypass ds_bpermute to apply the swizzling if lanes loading the same row read a contiguous chunk of memory from HBM, which we cannot infer when lowering to LLVM. The current selection only checks whether the elements for each lane are contiguous, which is not strict enough.
1 parent b9f0fdf commit 27fae2b

File tree

2 files changed

+6
-40
lines changed

2 files changed

+6
-40
lines changed

test/Conversion/amd/buffer_load_to_local_to_llvm.mlir

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -271,8 +271,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.sha
271271

272272
// Each thread needs to load 8 elements and we load 8 (sizePerThread) per buffer load instruction
273273
// GFX950: rocdl.make.buffer.rsrc
274-
// Src ptrs are contiguous so we do expect to bypass the ds_bpermute (see lowering to LLVM)
275-
// GFX950-NOT: rocdl.ds_bpermute
276274
// GFX950: rocdl.raw.ptr.buffer.load.lds
277275
// GFX950-NOT: rocdl.raw.ptr.buffer.load.lds
278276

third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 6 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -326,21 +326,6 @@ struct DirectToLdsLoadConversionBase : public LoadStoreConversionBase {
326326
auto bitMask = b.lshr(warpMask, b.zext(rewriter.getI64Type(), selectLane));
327327
return b.trunc(i1_ty, bitMask);
328328
}
329-
330-
// For direct-to-lds the order of the shared encoding decides the order we
331-
// load elements from global memory. This function returns true if the fastest
332-
// dim for the sharedEnc is contiguous for the global ptrs/offsets
333-
bool isFastedLoadDimContiguous(Value srcPtrOrOffset,
334-
MemDescType sharedTy) const {
335-
auto fastestDim = triton::gpu::getOrder(sharedTy)[0];
336-
AxisInfo *axisInfo = axisAnalysisPass.getAxisInfo(srcPtrOrOffset);
337-
338-
// This can happen if axis analysis fails (e.g. lit tests).
339-
if (axisInfo->getRank() <= fastestDim)
340-
return false;
341-
342-
return axisInfo->getContiguity(fastestDim) > 1;
343-
}
344329
};
345330

346331
struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
@@ -618,18 +603,9 @@ struct BufferLoadToLocalOpConversion
618603
// laneId + swizzleOffset will always stay inside the warp [0,
619604
// threadsPerWarp) because we only swizzle inside a warp
620605
Value swizzledLaneId = b.add(laneId, swizzleLaneOffset);
621-
if (isFastedLoadDimContiguous(offset, cast<MemDescType>(dstTy))) {
622-
// Because rows are contiguous and we only swizzle inside rows by
623-
// swapping elements between lanes we can add laneOffset * vecSize to
624-
// the offset to apply the swizzling
625-
offsetIn = b.add(offsetIn, b.mul(swizzleLaneOffset,
626-
b.i32_val(vecTy.getNumElements())));
627-
} else {
628-
// If rows are not contiguous in memory we need to shuffle the
629-
// pointers to apply the swizzling to the src pointers
630-
offsetIn =
631-
targetInfo.shuffleIdx(rewriter, loc, offsetIn, swizzledLaneId);
632-
}
606+
// Shuffle based on swizzleLaneId to apply the swizzling
607+
offsetIn =
608+
targetInfo.shuffleIdx(rewriter, loc, offsetIn, swizzledLaneId);
633609

634610
if (mask) {
635611
pred =
@@ -763,17 +739,9 @@ struct AsyncCopyGlobalToLocalOpConversion
763739
Value swizzledLaneId =
764740
b.add(getLaneId(rewriter, loc), swizzledLaneOffsets[i]);
765741

766-
if (isFastedLoadDimContiguous(op.getSrc(), cast<MemDescType>(dstTy))) {
767-
// Because rows are contiguous and we only swizzle inside rows by
768-
// swapping elements between lanes we can move the vecTy typed src
769-
// pointer by laneOffset elements to apply the swizzling.
770-
srcPtr =
771-
b.gep(srcPtr.getType(), vecTy, srcPtr, swizzledLaneOffsets[i]);
772-
} else {
773-
// If rows are not contiguous in memory we need to shuffle the
774-
// pointers to apply the swizzling to the src pointers
775-
srcPtr = targetInfo.shuffleIdx(rewriter, loc, srcPtr, swizzledLaneId);
776-
}
742+
// Shuffle based on swizzleLaneId to apply the swizzling
743+
srcPtr = targetInfo.shuffleIdx(rewriter, loc, srcPtr, swizzledLaneId);
744+
777745
if (!maskElements.empty()) {
778746
pred =
779747
shuffleMask(rewriter, b, loc, targetInfo, swizzledLaneId, pred);

0 commit comments

Comments
 (0)