Commit 5201154

[AMD] Use lowerLdSt for local_load to ds_read_tr path (#8344)
Transition the last use of emitTransferBetweenRegistersAndShared to the new lowering path. Also do some general cleanup of lowerInst(), including adding aliasing information for packed loads. Since emitTransferBetweenRegistersAndShared is now unused, remove it as well.
1 parent 4c388af commit 5201154

5 files changed, +92 −241 lines

include/triton/Conversion/TritonGPUToLLVM/Utility.h
Lines changed: 0 additions & 26 deletions

@@ -528,32 +528,6 @@ Value emitPadding(Location loc, RewriterBase &rewriter,
                   triton::gpu::PaddedSharedEncodingAttr layout,
                   unsigned bitwidth, Value smemOffset, bool offsetInBytes);
 
-// Emits IR to load data from shared memory into registers, or to store data
-// from registers into shared memory.
-//
-// You supply perVectorCallback, which is called once per group of register
-// elements to transfer. You can use this callback to emit IR to load or store
-// data from or to shared memory.
-//
-// elemLlvmTy should be dstTy's element type converted to an LLVM-dialect type.
-//
-// If maxVecElems is provided, we won't vectorize more than this many elements.
-//
-// Returns true on success.
-[[nodiscard]] bool emitTransferBetweenRegistersAndShared(
-    RankedTensorType registerTy, triton::gpu::MemDescType sharedTy,
-    Type elemLlvmTy, std::optional<int32_t> maxVecElems,
-    const SharedMemoryObject &smemObj, Location loc, RewriterBase &rewriter,
-    const TargetInfoBase &target,
-    std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback);
-
-[[nodiscard]] bool emitTransferBetweenRegistersAndShared(
-    LinearLayout &regLayout, triton::gpu::MemDescType sharedTy, Type elemLlvmTy,
-    std::optional<int32_t> maxVecElems, const SharedMemoryObject &smemObj,
-    Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
-    Value laneId, Value warpId,
-    std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback);
-
 // Close cousin of lowerLdStMatrix in MemoryOpToLLVM.cpp
 // We might want to merge them at some point, but having to support
 // ldmatrix.trans makes the code in lowerLdStMatrix a bit specific
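
For reference, the removed helper's contract was callback-per-vector-group: it computed one shared-memory address for each contiguous group of register elements and called perVectorCallback once per group with the vector type and that address. A minimal, self-contained sketch of that calling shape in plain C++ (forEachVectorGroup and its parameters are hypothetical stand-ins, not Triton APIs):

#include <cstdint>
#include <functional>
#include <iostream>

// Hypothetical distillation of the removed helper's core loop: split
// numElems register elements into groups of vecElems and invoke the
// callback once per group with the group's starting register index.
// The real helper passed a VectorType and a shared-memory address instead.
void forEachVectorGroup(int numElems, int vecElems,
                        const std::function<void(uint32_t)> &perVectorCallback) {
  for (int i = 0; i < numElems / vecElems; ++i)
    perVectorCallback(i * vecElems);
}

int main() {
  // 16 register elements moved as 4-wide vectors -> 4 callback invocations.
  forEachVectorGroup(/*numElems=*/16, /*vecElems=*/4, [](uint32_t firstRegId) {
    std::cout << "emit a load/store for registers starting at " << firstRegId
              << "\n";
  });
}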

lib/Conversion/TritonGPUToLLVM/Utility.cpp
Lines changed: 0 additions & 104 deletions

@@ -706,110 +706,6 @@ lowerLocalLdSt(Location loc, MLIRContext *ctx,
                         maybeMaxVecElems, localLoadOp);
 }
 
-bool emitTransferBetweenRegistersAndShared(
-    LinearLayout &regLayout, triton::gpu::MemDescType sharedTy, Type elemLlvmTy,
-    std::optional<int32_t> maxVecElems, const SharedMemoryObject &smemObj,
-    Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
-    Value laneId, Value warpId,
-    std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback) {
-  MLIRContext *ctx = rewriter.getContext();
-  auto b = TritonLLVMOpBuilder(loc, rewriter);
-
-  StringAttr kBlock = str_attr("block");
-  StringAttr kRegister = str_attr("register");
-  StringAttr kLane = str_attr("lane");
-  StringAttr kWarp = str_attr("warp");
-  StringAttr kOffset = str_attr("offset");
-
-  auto shape = sharedTy.getShape();
-  auto paddedEnc =
-      dyn_cast<triton::gpu::PaddedSharedEncodingAttr>(sharedTy.getEncoding());
-  LinearLayout regToSharedLayout = LinearLayout::empty();
-  if (paddedEnc) {
-    const auto &sharedLL = paddedEnc.getLinearComponent();
-    regToSharedLayout = regLayout.invertAndCompose(sharedLL);
-  } else {
-    auto sharedLL = triton::gpu::toLinearLayout(sharedTy);
-    regToSharedLayout = regLayout.invertAndCompose(sharedLL);
-  }
-
-  // TODO(jlebar): We don't currently support loading from shared memory in a
-  // different CTA. We'd need to emit `mapa.shared::cluster` instructions.
-  if (regToSharedLayout.hasInDim(kBlock) &&
-      regToSharedLayout.hasOutDim(kBlock) &&
-      !regToSharedLayout.isTrivialOver({kBlock})) {
-    return false;
-  }
-
-  // Determine how many consecutive registers map to consecutive shmem elements
-  // in out-dimension offsetN. This is our load instruction's vector width.
-  //
-  // It's OK if the vector width we choose here is wider than the hardware
-  // supports; LLVM will legalize it.
-  int vecElems =
-      std::min({regToSharedLayout.getNumConsecutiveInOut(),
-                maxVecElems.value_or(std::numeric_limits<int>::max())});
-  if (paddedEnc) {
-    vecElems = std::min(vecElems, int(paddedEnc.getMinInterval()));
-  }
-
-  auto withCTAOffset = triton::gpu::getNumCTAs(sharedTy.getEncoding()) > 1;
-  Value blockId =
-      withCTAOffset ? target.getClusterCTAId(rewriter, loc) : b.i32_val(0);
-
-  int numElems = regToSharedLayout.getInDimSize(kRegister);
-  auto vecTy = vec_ty(elemLlvmTy, vecElems);
-  SmallVector<uint32_t> regIds;
-  for (int i = 0; i < numElems / vecElems; i++) {
-    regIds.push_back(i * vecElems);
-  }
-
-  auto smemBase = smemObj.getBase();
-
-  auto indicesVec = applyLinearLayoutVec(loc, rewriter, regToSharedLayout,
-                                         {{kRegister, b.i32_val(0)},
-                                          {kLane, laneId},
-                                          {kWarp, warpId},
-                                          {kBlock, blockId}},
-                                         regIds);
-
-  // Compute affine offset given by memdesc_subslice
-  auto offset = smemObj.getShmemOffset(loc, rewriter, sharedTy);
-  SmallVector<Value> vecAddrVec;
-  for (auto &indices : indicesVec) {
-    Value smemOffset = indices[0].second;
-    smemOffset = b.xor_(smemOffset, offset);
-    if (paddedEnc) {
-      // Apply the offset needed for padding.
-      auto bitwidth = elemLlvmTy.getIntOrFloatBitWidth();
-      Value padOffset = emitPadding(loc, rewriter, paddedEnc, bitwidth,
-                                    smemOffset, /*offsetInBytes=*/false);
-      smemOffset = b.add(smemOffset, padOffset);
-    }
-    auto vecAddr = b.gep(smemBase.getType(), elemLlvmTy, smemBase, smemOffset,
-                         LLVM::GEPNoWrapFlags::inbounds);
-    vecAddrVec.push_back(vecAddr);
-  }
-
-  for (Value &vecAddr : vecAddrVec) {
-    perVectorCallback(vecTy, vecAddr);
-  }
-  return true;
-}
-
-bool emitTransferBetweenRegistersAndShared(
-    RankedTensorType registerTy, triton::gpu::MemDescType sharedTy,
-    Type elemLlvmTy, std::optional<int32_t> maxVecElems,
-    const SharedMemoryObject &smemObj, Location loc, RewriterBase &rewriter,
-    const TargetInfoBase &target,
-    std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback) {
-  auto regLayout = triton::gpu::toLinearLayout(registerTy);
-  auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc);
-  return emitTransferBetweenRegistersAndShared(
-      regLayout, sharedTy, elemLlvmTy, maxVecElems, smemObj, loc, rewriter,
-      target, laneId, warpId, perVectorCallback);
-}
-
 SmallVector<Value> unpackLLElements(Location loc, Value llvmStruct,
                                     RewriterBase &rewriter) {
   assert(bool(llvmStruct) && "can not unpack null values");
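
One detail of the removed implementation worth keeping in mind when reading the new lowerLdSt-based path: the vector width was the number of consecutive register-to-shared elements, capped by the caller's maxVecElems and, for padded shared encodings, by the minimum padding interval (over-vectorizing relative to the hardware is fine, since LLVM legalizes wide vectors). A standalone sketch of just that arithmetic; pickVecElems and its parameters are hypothetical names used only for illustration:

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <limits>
#include <optional>

// Hypothetical distillation of the removed vector-width choice: start from
// the number of consecutive register->shared elements, cap it by the
// caller-provided maxVecElems, and, for padded shared encodings, never
// vectorize across the minimum padding interval.
int pickVecElems(int numConsecutiveInOut, std::optional<int32_t> maxVecElems,
                 std::optional<int> paddingMinInterval) {
  int vecElems = std::min(
      numConsecutiveInOut,
      (int)maxVecElems.value_or(std::numeric_limits<int32_t>::max()));
  if (paddingMinInterval)
    vecElems = std::min(vecElems, *paddingMinInterval);
  return vecElems;
}

int main() {
  // 8 consecutive elements, caller cap of 4, padding every 2 elements -> 2.
  assert(pickVecElems(8, 4, 2) == 2);
  // No caps at all: the whole run of 8 consecutive elements is used.
  assert(pickVecElems(8, std::nullopt, std::nullopt) == 8);
  return 0;
}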

third_party/amd/lib/TritonAMDGPUToLLVM/AsyncUtility.cpp
Lines changed: 17 additions & 15 deletions

@@ -3,6 +3,7 @@
 #include "Dialect/TritonAMDGPU/IR/Dialect.h"
 #include "TargetInfo.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 namespace mlir::triton::AMD {
 namespace {
@@ -50,21 +51,20 @@ bool comesFromAsyncWait(Value token) {
 } // namespace
 
 void annotateLocalLoadsSyncedViaAsyncWait(ModuleOp mod) {
-  SmallVector<triton::gpu::LocalLoadOp> localLoads;
-  mod->walk([&](triton::gpu::LocalLoadOp localLoadOp) {
-    localLoads.emplace_back(localLoadOp);
-  });
-
   auto *ctx = mod->getContext();
-  for (auto &loadOp : localLoads) {
-    auto token = loadOp.getToken();
-    if (loadOp->hasAttr(syncedViaAsyncWaitAttrName))
-      continue;
-
-    bool isSyncedViaAsyncWait = token && comesFromAsyncWait(token);
-    loadOp->setAttr(syncedViaAsyncWaitAttrName,
-                    BoolAttr::get(ctx, isSyncedViaAsyncWait));
-  }
+
+  mod->walk([&](Operation *op) {
+    TypeSwitch<Operation *, void>(op)
+        .Case<triton::gpu::LocalLoadOp,
+              triton::amdgpu::LocalLoadPackedTransposedOp>([&](auto loadOp) {
+          if (loadOp->hasAttr(syncedViaAsyncWaitAttrName))
+            return;
+          Value token = loadOp.getToken();
+          bool isSyncedViaAsyncWait = token && comesFromAsyncWait(token);
+          loadOp->setAttr(syncedViaAsyncWaitAttrName,
+                          BoolAttr::get(ctx, isSyncedViaAsyncWait));
+        });
+  });
 }
 
 bool isSyncedViaAsyncWait(Operation *op) {
@@ -112,8 +112,10 @@ void addAsyncCopyAliasScope(LLVM::AliasAnalysisOpInterface directToLdsOp) {
   directToLdsOp.setAliasScopes(b.getArrayAttr(getAsyncCopyScope(ctx)));
 }
 
-void addLocalLoadNoAliasScope(triton::gpu::LocalLoadOp localLoadOp,
+void addLocalLoadNoAliasScope(Operation *localLoadOp,
                               LLVM::AliasAnalysisOpInterface llLoadOp) {
+  if (!localLoadOp->hasTrait<OpTrait::LocalLoadTrait>())
+    return;
   if (!isSyncedViaAsyncWait(localLoadOp))
     return;
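
The annotateLocalLoadsSyncedViaAsyncWait rewrite above replaces a LocalLoadOp-only walk with a single walk that dispatches on both load-like ops through llvm::TypeSwitch. The sketch below shows that dispatch pattern in isolation, using stand-in classes with LLVM-style RTTI instead of real MLIR ops (it assumes only the LLVM ADT header already included by this file):

#include "llvm/ADT/TypeSwitch.h"

#include <iostream>

// Stand-in op hierarchy with LLVM-style RTTI (classof), mimicking how the
// walk above handles LocalLoadOp and LocalLoadPackedTransposedOp with one
// shared lambda. These classes are illustrative only.
struct FakeOp {
  enum Kind { LocalLoad, PackedTransposedLoad, Other };
  Kind kind;
  explicit FakeOp(Kind k) : kind(k) {}
};
struct FakeLocalLoadOp : FakeOp {
  FakeLocalLoadOp() : FakeOp(LocalLoad) {}
  static bool classof(const FakeOp *op) { return op->kind == LocalLoad; }
};
struct FakePackedLoadOp : FakeOp {
  FakePackedLoadOp() : FakeOp(PackedTransposedLoad) {}
  static bool classof(const FakeOp *op) {
    return op->kind == PackedTransposedLoad;
  }
};

// One Case<> covers both load-like types; the generic lambda runs for
// whichever type matched. Anything else falls through to Default untouched.
void annotateIfLoadLike(FakeOp *op) {
  llvm::TypeSwitch<FakeOp *, void>(op)
      .Case<FakeLocalLoadOp, FakePackedLoadOp>(
          [](auto *loadOp) { std::cout << "annotate load-like op\n"; })
      .Default([](FakeOp *) { /* not a load we care about */ });
}

int main() {
  FakeLocalLoadOp a;
  FakePackedLoadOp b;
  FakeOp c(FakeOp::Other);
  annotateIfLoadLike(&a); // annotated
  annotateIfLoadLike(&b); // annotated
  annotateIfLoadLike(&c); // ignored
}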

third_party/amd/lib/TritonAMDGPUToLLVM/AsyncUtility.h
Lines changed: 1 addition & 1 deletion

@@ -34,7 +34,7 @@ bool isSyncedViaAsyncWait(Operation *localLoadOp);
 // If localLoadOp has a token from an AsyncWait:
 //  - Attaches "amdgpu.LocalLoad" alias scope to llLoadOp
 //  - Attaches "amdgpu.AsyncCopies" as *non* alias scope to llLoadOp
-void addLocalLoadNoAliasScope(triton::gpu::LocalLoadOp localLoadOp,
+void addLocalLoadNoAliasScope(Operation *localLoadOp,
                               LLVM::AliasAnalysisOpInterface llLoadOp);
 // Overload from above without checking the AsyncToken
 void addLocalLoadNoAliasScope(LLVM::AliasAnalysisOpInterface llLoadOp);
