
Commit 219433e

AlexAUT and antiagainst authored
[AMD] Optimize shared address calculation for async load (#7153)
On GFX9, direct-to-LDS loads write coalesced to LDS and therefore require the start LDS address as a scalar. This PR refactors the address calculation to compute the start address uniformly once instead of computing per-lane addresses. This improves the final codegen and reduces register usage. The swizzling computations are now based on the offset instead of the final addresses, which further helps codegen.

Note that the lowering can still produce incorrect loads if we store into a sub-view that slices along the two minor dimensions; pipelining only slices along the outermost dimension, so it is unaffected. This was already the case before the refactoring and will be converted to an error in a follow-up PR.

---------

Co-authored-by: Lei Zhang <[email protected]>
1 parent 343bd8e commit 219433e

File tree: 3 files changed, +201 −108 lines changed

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 7 additions & 0 deletions
```diff
@@ -543,6 +543,13 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
     Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
     std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback);
 
+[[nodiscard]] bool emitTransferBetweenRegistersAndShared(
+    LinearLayout &regLayout, triton::gpu::MemDescType sharedTy, Type elemLlvmTy,
+    std::optional<int32_t> maxVecElems, const SharedMemoryObject &smemObj,
+    Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
+    Value laneId, Value warpId,
+    std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback);
+
 SmallVector<Value> loadSharedToDistributed(triton::gpu::LocalLoadOp localLoadOp,
                                            Type elemLlvmTy,
                                            const SharedMemoryObject &smemObj,
```

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 12 additions & 1 deletion
```diff
@@ -417,6 +417,18 @@ bool emitTransferBetweenRegistersAndShared(
     std::optional<int32_t> maxVecElems, const SharedMemoryObject &smemObj,
     Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
     std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback) {
+  auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc);
+  return emitTransferBetweenRegistersAndShared(
+      regLayout, sharedTy, elemLlvmTy, maxVecElems, smemObj, loc, rewriter,
+      target, laneId, warpId, perVectorCallback);
+}
+
+bool emitTransferBetweenRegistersAndShared(
+    LinearLayout &regLayout, triton::gpu::MemDescType sharedTy, Type elemLlvmTy,
+    std::optional<int32_t> maxVecElems, const SharedMemoryObject &smemObj,
+    Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
+    Value laneId, Value warpId,
+    std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback) {
   MLIRContext *ctx = rewriter.getContext();
   auto b = TritonLLVMOpBuilder(loc, rewriter);
 
@@ -458,7 +470,6 @@ bool emitTransferBetweenRegistersAndShared(
       maxVecElems.value_or(std::numeric_limits<int>::max()));
 
   auto withCTAOffset = triton::gpu::getNumCTAs(sharedTy.getEncoding()) > 1;
-  auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc);
   Value blockId =
       withCTAOffset ? target.getClusterCTAId(rewriter, loc) : b.i32_val(0);
```