
Commit e625b79

Merge commit 'f7f5b3af26d283348e78f3593b337a39267f7ff9'
2 parents: cd4f49b + f7f5b3a

9 files changed: 66 additions, 3 deletions

include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h

Lines changed: 2 additions & 0 deletions
@@ -89,6 +89,8 @@ class TargetInfoBase {
 
   virtual int getSharedAddressSpace() const = 0;
 
+  virtual int getAddressSpace(Attribute addressSpace) const = 0;
+
   virtual bool supportVectorizedAtomics() const = 0;
 
   // Helper used by targets to annotate store operations during lowering to
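The new pure-virtual hook lets each backend map a dialect-level memory-space attribute to the numeric LLVM address space it lowers to, instead of every caller assuming shared memory. Below is a minimal, self-contained sketch of the pattern, with a plain enum standing in for the mlir::Attribute the real signature takes (the names here are illustrative, not Triton's; the override mirrors the NVIDIA implementation later in this commit):

#include <cstdio>
#include <cstdlib>

// Stand-in for the memory-space attributes; the real hook receives an
// mlir::Attribute and dispatches on it with isa<...>.
enum class MemorySpace { Shared, Tensor };

struct TargetInfoBase {
  virtual ~TargetInfoBase() = default;
  // Maps a dialect-level memory space to an LLVM address space number.
  virtual int getAddressSpace(MemorySpace space) const = 0;
};

// Both supported spaces map to address space 3; anything else is a hard
// error, matching the report_fatal_error calls in the real overrides.
struct NVIDIATargetInfo final : TargetInfoBase {
  int getAddressSpace(MemorySpace space) const override {
    switch (space) {
    case MemorySpace::Shared:
    case MemorySpace::Tensor:
      return 3;
    }
    std::abort(); // stands in for llvm::report_fatal_error
  }
};

int main() {
  NVIDIATargetInfo info;
  std::printf("shared -> addrspace(%d)\n",
              info.getAddressSpace(MemorySpace::Shared));
  return 0;
}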

lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp

Lines changed: 2 additions & 2 deletions
@@ -54,8 +54,8 @@ Type TritonGPUToLLVMTypeConverter::convertMemDescType(
     MemDescType type, const TargetInfoBase &targetInfo) {
   auto ctx = type.getContext();
   // base ptr
-  auto ptrType =
-      LLVM::LLVMPointerType::get(ctx, targetInfo.getSharedAddressSpace());
+  auto ptrType = LLVM::LLVMPointerType::get(
+      ctx, targetInfo.getAddressSpace(type.getMemorySpace()));
 
   if (isa<triton::nvidia_gpu::TensorMemoryEncodingAttr,
           triton::nvidia_gpu::TensorMemoryScalesEncodingAttr>(

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 1 addition & 1 deletion
@@ -882,7 +882,7 @@ void MakeTensorPtrOp::build(OpBuilder &builder, OperationState &state,
   auto tensorType = RankedTensorType::get(
       SmallVector<int64_t>(tensorShape.begin(), tensorShape.end()),
       pointerType.getPointeeType());
-  auto result = PointerType::get(tensorType, 1);
+  auto result = PointerType::get(tensorType, pointerType.getAddressSpace());
 
   return build(builder, state, result, base, shape, strides, offsets,
                builder.getDenseI32ArrayAttr(order));
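The hard-coded second argument 1 appears to be Triton's default global address space for tt.ptr; propagating pointerType.getAddressSpace() instead keeps the block pointer built by MakeTensorPtrOp::build consistent with a base pointer that lives in a non-default address space.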

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 14 additions & 0 deletions
@@ -153,6 +153,20 @@ SmallVector<unsigned> getShapePerCTATile(Attribute layout) {
 }
 
 bool isExpensiveView(Type srcType, Type dstType) {
+  auto tensorSrcType = cast<RankedTensorType>(srcType);
+  auto tensorDstType = cast<RankedTensorType>(dstType);
+  auto llSrc =
+      toLinearLayout(tensorSrcType.getShape(), tensorSrcType.getEncoding());
+  auto llDst =
+      toLinearLayout(tensorDstType.getShape(), tensorDstType.getEncoding());
+  // In case there are replicated values, we need to make sure the new and old
+  // layouts have matching free-variable masks.
+  for (auto [srcMask, dstMask] :
+       llvm::zip(llSrc.getFreeVariableMasks(), llDst.getFreeVariableMasks())) {
+    assert(srcMask.first == dstMask.first);
+    if (srcMask.second != dstMask.second)
+      return true;
+  }
   return getTotalElemsPerThread(srcType) != getTotalElemsPerThread(dstType);
 }
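Here getFreeVariableMasks() reports, per input dimension, a bitmask of basis bits that don't affect which output element is addressed, i.e. hardware locations that hold replicated copies of a value. If the source and destination layouts replicate differently, a reshape has to move data between threads even when the per-thread element counts match, so the view is expensive. Below is a self-contained sketch of just that comparison, using a std::map as a stand-in for the LinearLayout API (illustrative types, not Triton's):

#include <cassert>
#include <cstdint>
#include <map>
#include <string>

// Free-variable masks per input dimension ("register", "lane", "warp", ...).
// A set bit means that basis bit doesn't change the addressed element, i.e.
// the value is replicated across it.
using FreeMasks = std::map<std::string, int32_t>;

// Returns true when the two layouts replicate values differently, which is
// what the loop in the hunk above detects.
bool masksMismatch(const FreeMasks &src, const FreeMasks &dst) {
  assert(src.size() == dst.size());
  auto d = dst.begin();
  for (auto s = src.begin(); s != src.end(); ++s, ++d) {
    assert(s->first == d->first); // dimensions must line up
    if (s->second != d->second)
      return true;
  }
  return false;
}

int main() {
  // src replicates each value across 8 lanes (3 free lane bits); dst does not.
  FreeMasks src{{"lane", 0b111}, {"register", 0}, {"warp", 0}};
  FreeMasks dst{{"lane", 0b000}, {"register", 0}, {"warp", 0}};
  assert(masksMismatch(src, dst)); // data must cross lanes -> expensive view
  return 0;
}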

test/TritonGPU/canonicalize.mlir

Lines changed: 19 additions & 0 deletions
@@ -40,6 +40,25 @@ tt.func @test_canonicalize_convert_expensive_view(%arg0: tensor<256x16xf32, #blo
 
 // -----
 
+// test that the convert doesn't get combined with the view if the resulting
+// operation is an expensive view which would require moving data across threads.
+// CHECK-LABEL: @test_canonicalize_convert_expensive_view2
+// CHECK-SAME: (%[[ARG:.+]]: tensor<2xf32
+// CHECK: %[[C:.+]] = ttg.convert_layout %[[ARG]]
+// CHECK: %[[V:.+]] = tt.reshape %[[C]] allow_reorder
+// CHECK: tt.return %[[V]]
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:80"} {
+tt.func @test_canonicalize_convert_expensive_view2(%arg0: tensor<2xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) -> tensor<2xf32, #blocked1> {
+  %c = ttg.convert_layout %arg0 : tensor<2xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<2xf32, #blocked1>
+  %r = tt.reshape %c allow_reorder : tensor<2xf32, #blocked1> -> tensor<2xf32, #blocked1>
+  tt.return %r : tensor<2xf32, #blocked1>
+}
+}
+
+// -----
+
 // test that the convert does get combined with the view even if the resulting operation
 // is an efficient view.
 // CHECK-LABEL: @test_canonicalize_convert_view
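As a usage note: each "// -----" separator marks a split input, so the new function runs as its own test case. Assuming the file keeps its usual RUN line (not shown in this hunk), the test is exercised with something like triton-opt %s -split-input-file -canonicalize | FileCheck %s.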

third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp

Lines changed: 10 additions & 0 deletions
@@ -460,6 +460,16 @@ Value TargetInfo::getStackPointer(RewriterBase &rewriter,
   return rewriter.create<LLVM::AddressOfOp>(funcOp.getLoc(), globalBase);
 }
 
+int TargetInfo::getAddressSpace(Attribute addressSpace) const {
+  int spaceId = 0;
+  if (isa<triton::gpu::SharedMemorySpaceAttr>(addressSpace)) {
+    spaceId = 3;
+  } else {
+    llvm::report_fatal_error("Only support SharedMemorySpace for now");
+  }
+  return spaceId;
+}
+
 bool TargetInfo::supportVectorizedAtomics() const {
   // Note: not currently tested or used, but AMD generally supports vectorized
   // atomics.
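For context, LLVM address space 3 on the AMDGPU backend is local (LDS) memory, which is what Triton's shared memory lowers to; unknown spaces fail fast via report_fatal_error rather than silently defaulting to the generic space 0.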

third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h

Lines changed: 3 additions & 0 deletions
@@ -65,8 +65,11 @@ class TargetInfo : public mlir::triton::TargetInfoBase {
 
   void assertFail(RewriterBase &rewriter, Location loc, StringRef message,
                   StringRef file, StringRef func, int line) const override;
+
   int getSharedAddressSpace() const override;
 
+  int getAddressSpace(Attribute addressSpace) const override;
+
   bool supportVectorizedAtomics() const override;
 
   void storeOpAnnotation(triton::gpu::LocalStoreOp op, size_t localStoreOpCount,

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp

Lines changed: 12 additions & 0 deletions
@@ -630,6 +630,18 @@ Value TargetInfo::getStackPointer(RewriterBase &rewriter,
   return rewriter.create<LLVM::AddressOfOp>(funcOp.getLoc(), globalBase);
 }
 
+int TargetInfo::getAddressSpace(Attribute addressSpace) const {
+  int spaceId = 0;
+  if (isa<triton::gpu::SharedMemorySpaceAttr,
+          triton::nvidia_gpu::TensorMemorySpaceAttr>(addressSpace)) {
+    spaceId = 3;
+  } else {
+    llvm::report_fatal_error(
+        "Only support SharedMemorySpace, TensorMemorySpace for now");
+  }
+  return spaceId;
+}
+
 bool TargetInfo::supportVectorizedAtomics() const {
   return computeCapability >= 90 && ptxVersion >= 81;
 }
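For context, address space 3 on NVPTX is shared memory; this override maps TensorMemorySpaceAttr to the same numeric space, so tensor-memory pointers receive the same LLVM pointer type as shared-memory pointers for now.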

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h

Lines changed: 3 additions & 0 deletions
@@ -58,11 +58,14 @@ class TargetInfo : public mlir::triton::TargetInfoBase {
 
   void assertFail(RewriterBase &rewriter, Location loc, StringRef message,
                   StringRef file, StringRef func, int line) const override;
+
   int getSharedAddressSpace() const override;
 
   Value getStackPointer(RewriterBase &rewriter,
                         FunctionOpInterface funcOp) const override;
 
+  int getAddressSpace(Attribute addressSpace) const override;
+
   bool supportVectorizedAtomics() const override;
 
   int getPtxVersion() const { return ptxVersion; }
