Skip to content

Commit f7f5b3a

Browse files
zkxvbkaixuan.zhang
andauthored
typeConverter to llvm support addressSpace attribute (#5951)
MemDescType support attribute memorySpace, but TypeConverter to LLVM not used, this pull request try to fix it. Co-authored-by: kaixuan.zhang <[email protected]>
1 parent 013453f commit f7f5b3a

File tree

7 files changed

+33
-3
lines changed

7 files changed

+33
-3
lines changed

include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ class TargetInfoBase {
8989

9090
virtual int getSharedAddressSpace() const = 0;
9191

92+
virtual int getAddressSpace(Attribute addressSpace) const = 0;
93+
9294
virtual bool supportVectorizedAtomics() const = 0;
9395

9496
// Helper used by targets to annotate store operations during lowering to

lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ Type TritonGPUToLLVMTypeConverter::convertMemDescType(
5454
MemDescType type, const TargetInfoBase &targetInfo) {
5555
auto ctx = type.getContext();
5656
// base ptr
57-
auto ptrType =
58-
LLVM::LLVMPointerType::get(ctx, targetInfo.getSharedAddressSpace());
57+
auto ptrType = LLVM::LLVMPointerType::get(
58+
ctx, targetInfo.getAddressSpace(type.getMemorySpace()));
5959

6060
if (isa<triton::nvidia_gpu::TensorMemoryEncodingAttr,
6161
triton::nvidia_gpu::TensorMemoryScalesEncodingAttr>(

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -882,7 +882,7 @@ void MakeTensorPtrOp::build(OpBuilder &builder, OperationState &state,
882882
auto tensorType = RankedTensorType::get(
883883
SmallVector<int64_t>(tensorShape.begin(), tensorShape.end()),
884884
pointerType.getPointeeType());
885-
auto result = PointerType::get(tensorType, 1);
885+
auto result = PointerType::get(tensorType, pointerType.getAddressSpace());
886886

887887
return build(builder, state, result, base, shape, strides, offsets,
888888
builder.getDenseI32ArrayAttr(order));

third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,16 @@ void TargetInfo::assertFail(RewriterBase &rewriter, Location loc,
447447

448448
int TargetInfo::getSharedAddressSpace() const { return 3; }
449449

450+
int TargetInfo::getAddressSpace(Attribute addressSpace) const {
451+
int spaceId = 0;
452+
if (isa<triton::gpu::SharedMemorySpaceAttr>(addressSpace)) {
453+
spaceId = 3;
454+
} else {
455+
llvm::report_fatal_error("Only support SharedMemorySpace for now");
456+
}
457+
return spaceId;
458+
}
459+
450460
bool TargetInfo::supportVectorizedAtomics() const {
451461
// Note: not currently tested or used, but AMD generally supports vectorized
452462
// atomics.

third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,11 @@ class TargetInfo : public mlir::triton::TargetInfoBase {
6565

6666
void assertFail(RewriterBase &rewriter, Location loc, StringRef message,
6767
StringRef file, StringRef func, int line) const override;
68+
6869
int getSharedAddressSpace() const override;
6970

71+
int getAddressSpace(Attribute addressSpace) const override;
72+
7073
bool supportVectorizedAtomics() const override;
7174

7275
void storeOpAnnotation(triton::gpu::LocalStoreOp op, size_t localStoreOpCount,

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -617,6 +617,18 @@ void TargetInfo::assertFail(RewriterBase &rewriter, Location loc,
617617

618618
int TargetInfo::getSharedAddressSpace() const { return 3; }
619619

620+
int TargetInfo::getAddressSpace(Attribute addressSpace) const {
621+
int spaceId = 0;
622+
if (isa<triton::gpu::SharedMemorySpaceAttr,
623+
triton::nvidia_gpu::TensorMemorySpaceAttr>(addressSpace)) {
624+
spaceId = 3;
625+
} else {
626+
llvm::report_fatal_error(
627+
"Only support SharedMemorySpace, TensorMemorySpace for now");
628+
}
629+
return spaceId;
630+
}
631+
620632
bool TargetInfo::supportVectorizedAtomics() const {
621633
return computeCapability >= 90 && ptxVersion >= 81;
622634
}

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,11 @@ class TargetInfo : public mlir::triton::TargetInfoBase {
5858

5959
void assertFail(RewriterBase &rewriter, Location loc, StringRef message,
6060
StringRef file, StringRef func, int line) const override;
61+
6162
int getSharedAddressSpace() const override;
6263

64+
int getAddressSpace(Attribute addressSpace) const override;
65+
6366
bool supportVectorizedAtomics() const override;
6467

6568
int getPtxVersion() const { return ptxVersion; }

0 commit comments

Comments
 (0)