
Commit 1bc283c

Merge commit '152ef2deb8852d5c84f9ffba217b3a7c8f398c5f'
2 parents e6df65e + 152ef2d commit 1bc283c

43 files changed: 804 additions, 90 deletions

Note: this is a large commit, so only part of its content is shown below.


bin/RegisterTritonDialects.h

Lines changed: 1 addition & 0 deletions
@@ -87,6 +87,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
   mlir::registerTritonAMDGPUReorderInstructions();
   mlir::registerTritonAMDGPUStreamPipelineV2();
   mlir::registerTritonAMDGPUCanonicalizePointers();
+  mlir::registerTritonAMDGPUConvertToBufferOps();

   // TODO: register Triton & TritonGPU passes
   registry.insert<mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,

include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h

Lines changed: 2 additions & 0 deletions
@@ -82,6 +82,8 @@ class TargetInfoBase {

   virtual int getSharedAddressSpace() const = 0;

+  virtual bool supportVectorizedAtomics() const = 0;
+
   virtual ~TargetInfoBase() {}
 };
 } // namespace mlir::triton
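The new pure-virtual hook lets lowering code ask the target whether packed atomics (for example f16x2 adds) may be emitted. As a minimal standalone sketch of how a backend might override it: the ExampleTargetInfo class, its return values, and the main() driver below are assumptions for illustration only; just the TargetInfoBase interface and the two method names come from the diff above.

#include <cstdio>

// Mirror of the interface shown in the diff (not the Triton header itself).
struct TargetInfoBase {
  virtual int getSharedAddressSpace() const = 0;
  virtual bool supportVectorizedAtomics() const = 0;
  virtual ~TargetInfoBase() {}
};

// Hypothetical backend: names and return values are illustrative assumptions.
struct ExampleTargetInfo : TargetInfoBase {
  // addrspace(3) is the usual LLVM address space for shared/LDS memory.
  int getSharedAddressSpace() const override { return 3; }
  // Report whether packed atomic RMW lowering may be used on this target.
  bool supportVectorizedAtomics() const override { return true; }
};

int main() {
  ExampleTargetInfo info;
  std::printf("vectorized atomics supported: %s\n",
              info.supportVectorizedAtomics() ? "yes" : "no");
}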

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ namespace mlir::triton {
 inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
     // clang-format off
     "AMDGCN_ENABLE_DUMP",
+    "AMDGCN_USE_BUFFER_OPS",
     "DISABLE_FAST_REDUCTION",
     "DISABLE_LLVM_OPT",
     "DISABLE_MMA_V3",

lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 44 additions & 36 deletions
@@ -109,27 +109,30 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
       : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) {
   }

+  // FIXME [Dot LL]
+  // Do for all DotOperandEncodingAttr once we have LLs for all of them
+  static bool isSupportedDotOpLayout(Attribute layout) {
+    if (auto dot = dyn_cast<DotOperandEncodingAttr>(layout)) {
+      if (auto mma = dyn_cast<NvidiaMmaEncodingAttr>(dot.getParent())) {
+        return mma.isAmpere() && dot.getKWidth() == 8;
+      }
+      if (isa<AMDMfmaEncodingAttr>(dot.getParent()))
+        return true;
+    }
+    return false;
+  };
+
   LogicalResult
   matchAndRewrite(LocalLoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     MemDescType srcTy = op.getSrc().getType();
     RankedTensorType dstTy = op.getType();
     Attribute srcLayout = srcTy.getEncoding();
     Attribute dstLayout = dstTy.getEncoding();
-    // FIXME [Dot LL]
-    // Do for all DotOperandEncodingAttr once we have LLs for all of them
-    auto isAmpereLargeKWidth = [](Attribute layout) {
-      if (auto dot = dyn_cast<DotOperandEncodingAttr>(layout)) {
-        if (auto mma = dyn_cast<NvidiaMmaEncodingAttr>(dot.getParent())) {
-          return mma.isAmpere() && dot.getKWidth() == 8;
-        }
-      }
-      return false;
-    };
     if (isa<SharedEncodingAttr>(srcLayout) &&
         (isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr>(
             dstLayout) ||
-         isAmpereLargeKWidth(dstLayout))) {
+         isSupportedDotOpLayout(dstLayout))) {
       return lowerSharedToDistributed(op, adaptor, getTypeConverter(),
                                       rewriter);
     }

@@ -167,10 +170,10 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
     auto srcTy = op.getSrc().getType();
     auto dstTy = op.getResult().getType();
     auto dstShape = dstTy.getShape();
-    assert(dstShape.size() <= 2 &&
-           "Unexpected rank of ConvertLayout(shared->blocked)");
     auto srcSharedLayout = cast<SharedEncodingAttr>(srcTy.getEncoding());
     auto dstLayout = dstTy.getEncoding();
+    assert((dstShape.size() <= 2 || isSupportedDotOpLayout(dstLayout)) &&
+           "Unexpected rank of ConvertLayout(shared->distributed)");
     auto inOrd = getOrder(srcSharedLayout);

     auto smemObj = LLVM::getSharedMemoryObjectFromStruct(

@@ -184,31 +187,36 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
     // FIXME [Dot LL]
     // Ampere case
     // In this case, we need to pack the outputs into i32
-    if (isa<DotOperandEncodingAttr>(dstTy.getEncoding())) {
-      if (elemLlvmTy.isInteger(8)) {
-        auto concat = [&](Value a1, Value a2, Value a3, Value a4) {
-          return or_(or_(zext(i32_ty, a1), shl(zext(i32_ty, a2), i32_val(8))),
-                     or_(shl(zext(i32_ty, a3), i32_val(16)),
-                         shl(zext(i32_ty, a4), i32_val(24))));
-        };
-        SmallVector<Value> outVals32(outVals.size() / 4);
-        for (int i = 0; i < outVals32.size(); ++i) {
-          outVals32[i] = concat(outVals[4 * i], outVals[4 * i + 1],
-                                outVals[4 * i + 2], outVals[4 * i + 3]);
-        }
-        outVals = outVals32;
-      } else {
-        assert(elemLlvmTy.isBF16() && "Unexpected element type");
-        auto concat = [&](Value a, Value b) {
-          return or_(zext(i32_ty, bitcast(a, i16_ty)),
-                     shl(zext(i32_ty, bitcast(b, i16_ty)), i32_val(16)));
-        };
+    if (auto dotOp = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding())) {
+      if (auto parent = dyn_cast<NvidiaMmaEncodingAttr>(dotOp.getParent())) {
+        if (parent.isAmpere()) {
+          if (elemLlvmTy.isInteger(8)) {
+            auto concat = [&](Value a1, Value a2, Value a3, Value a4) {
+              return or_(
+                  or_(zext(i32_ty, a1), shl(zext(i32_ty, a2), i32_val(8))),
+                  or_(shl(zext(i32_ty, a3), i32_val(16)),
+                      shl(zext(i32_ty, a4), i32_val(24))));
+            };
+            SmallVector<Value> outVals32(outVals.size() / 4);
+            for (int i = 0; i < outVals32.size(); ++i) {
+              outVals32[i] = concat(outVals[4 * i], outVals[4 * i + 1],
+                                    outVals[4 * i + 2], outVals[4 * i + 3]);
+            }
+            outVals = outVals32;
+          } else {
+            assert(elemLlvmTy.isBF16() && "Unexpected element type");
+            auto concat = [&](Value a, Value b) {
+              return or_(zext(i32_ty, bitcast(a, i16_ty)),
+                         shl(zext(i32_ty, bitcast(b, i16_ty)), i32_val(16)));
+            };

-        SmallVector<Value> outVals32(outVals.size() / 2);
-        for (int i = 0; i < outVals32.size(); ++i) {
-          outVals32[i] = concat(outVals[2 * i], outVals[2 * i + 1]);
+            SmallVector<Value> outVals32(outVals.size() / 2);
+            for (int i = 0; i < outVals32.size(); ++i) {
+              outVals32[i] = concat(outVals[2 * i], outVals[2 * i + 1]);
+            }
+            outVals = outVals32;
+          }
         }
-        outVals = outVals32;
       }
     }
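For reference, the concat lambdas above pack four i8 values (or two bf16 values reinterpreted as i16) into a single i32 purely with zero-extend, shift, and OR. The standalone sketch below shows the same arithmetic on plain integers; the function names are illustrative and are not Triton APIs.

#include <cstdint>
#include <cstdio>

// Pack four bytes into one 32-bit word, lowest-index value in the low byte,
// mirroring the zext/shl/or_ sequence used for the i8 case above.
uint32_t pack4xi8(uint8_t a1, uint8_t a2, uint8_t a3, uint8_t a4) {
  return uint32_t(a1) | (uint32_t(a2) << 8) | (uint32_t(a3) << 16) |
         (uint32_t(a4) << 24);
}

// Pack two 16-bit values (e.g. bf16 bit patterns) into one 32-bit word,
// mirroring the bf16 branch above.
uint32_t pack2xi16(uint16_t a, uint16_t b) {
  return uint32_t(a) | (uint32_t(b) << 16);
}

int main() {
  std::printf("%08x\n", pack4xi8(0x11, 0x22, 0x33, 0x44)); // prints 44332211
  std::printf("%08x\n", pack2xi16(0xbeef, 0xdead));        // prints deadbeef
}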

python/test/unit/language/test_core.py

Lines changed: 2 additions & 1 deletion
@@ -4026,10 +4026,11 @@ def _kernel(dst, src, CACHE: tl.constexpr):
     amdgcn = pgm.asm['amdgcn']
     cg_cache_modifier_str = 'nt'
     cv_cache_modifier_str = 'sc0 sc1'
+    buffer_load_line = [line for line in amdgcn.splitlines() if "buffer_load" in line]
     global_load_line = [line for line in amdgcn.splitlines() if "global_load" in line]
     flat_load_line = [line for line in amdgcn.splitlines() if "flat_load" in line]
     if cache == '' or cache == '.ca':
-        assert cg_cache_modifier_str not in global_load_line[0]
+        assert cg_cache_modifier_str not in (global_load_line[0] if global_load_line else buffer_load_line[0])
     if cache == '.cg':
         assert cg_cache_modifier_str in global_load_line[0]
     if cache == '.cv':
Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+// RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm="arch=gfx942 ftz=True" --convert-builtin-func-to-llvm="ftz=True" | FileCheck %s --check-prefix=LLVM_FTZ
+// RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm="arch=gfx942 ftz=False" --convert-builtin-func-to-llvm="ftz=False" | FileCheck %s --check-prefix=LLVM_NO_FTZ
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} {
+  tt.func public @test_fast_expf(%arg0: tensor<64xf32, #blocked>) attributes {noinline = false} {
+    // LLVM_FTZ: llvm.amdgcn.exp2.f32
+    // LLVM_NO_FTZ: llvm.exp2.f32
+    %0 = tt.extern_elementwise %arg0 {libname = "libdevice", libpath = "", pure = true, symbol = "__triton_hip_fast_expf"} : (tensor<64xf32, #blocked>) -> tensor<64xf32, #blocked>
+    tt.return
+  }
+}

Lines changed: 8 additions & 7 deletions
@@ -1,18 +1,19 @@
-// RUN: triton-opt %s --split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx942 | FileCheck %s
+// RUN: triton-opt %s --split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx942 --mlir-print-debuginfo --mlir-pretty-debuginfo| FileCheck %s

 #blocked = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
 #mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = false}>
 #shared = #triton_gpu.shared<{vec = 16, perPhase = 4, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.shared = 544 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
   // CHECK-LABEL: @local_load_offset
   tt.func @local_load_offset(%arg0: tensor<16x16xf16, #mma>) {
-    %0 = triton_gpu.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #blocked>
-    %1 = triton_gpu.local_alloc %0 {allocation.offset = 0 : i32} : (tensor<16x16xf16, #blocked>) -> !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory>
+    %0 = triton_gpu.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #blocked> loc(#loc1)
+    %1 = triton_gpu.local_alloc %0 {allocation.offset = 0 : i32} : (tensor<16x16xf16, #blocked>) -> !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> loc(#loc2)
     // This catches base ptr calculation in the computeBasePtr, checks if the gep has correct element type.
-    // CHECK: llvm.sub
-    // CHECK-NEXT: llvm.getelementptr
-    // CHECK-SAME: (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16
-    %2 = triton_gpu.local_load %1 : !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>
+    // CHECK: llvm.getelementptr {{.*}} (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16 local_load:3:0
+    %2 = triton_gpu.local_load %1 : !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> loc(#loc3)
     tt.return
   }
 }
+#loc1 = loc("conert_layout":1:0)
+#loc2 = loc("local_alloc":2:0)
+#loc3 = loc("local_load":3:0)


test/Conversion/amd/tritongpu_to_llvm.mlir

Lines changed: 28 additions & 0 deletions
@@ -34,3 +34,31 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
     tt.return
   }
 }
+
+// -----
+
+// Smoke test to check that mfma 32 and dot operand layouts can work with small tensors, for example with shape 16x16
+#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [32, 32], isTransposed = true}>
+#dotop0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=4}>
+#dotop1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth=4}>
+#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
+  // CHECK-LABEL: small_mfma_tensor_conversions
+  tt.func public @small_mfma_tensor_conversions(%arg0: tensor<16x16xf16, #mfma>, %arg1: tensor<16x16x!tt.ptr<f32>, #mfma>) {
+    // CHECK-NOT: triton_gpu.convert_layout
+    %0 = triton_gpu.local_alloc %arg0 : (tensor<16x16xf16, #mfma>) -> !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory>
+    // CHECK-4: store {{.*}} vector<4xf16>
+    %1 = triton_gpu.local_load %0 : !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> -> tensor<16x16xf16, #dotop0>
+    // CHECK-2: load {{.*}} vector<4xf16>
+    %2 = triton_gpu.local_load %0 : !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> -> tensor<16x16xf16, #dotop1>
+    // CHECK-8: load {{.*}} vector<1xf16>
+    %3 = triton_gpu.local_load %0 : !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> -> tensor<16x16xf16, #mfma>
+    // CHECK-4: load {{.*}} vector<4xf16>
+    %4 = tt.fp_to_fp %3 : tensor<16x16xf16, #mfma> -> tensor<16x16xf32, #mfma>
+
+    %5 = tt.dot %1, %2, %4 : tensor<16x16xf16, #dotop0> * tensor<16x16xf16, #dotop1> -> tensor<16x16xf32, #mfma>
+    // Store result to prevent DCE from removing all conversion related code
+    %6 = triton_gpu.local_alloc %5 : (tensor<16x16xf32, #mfma>) -> !tt.memdesc<16x16xf32, #shared, #triton_gpu.shared_memory>
+    tt.return
+  }
+}

test/Conversion/tritongpu_to_llvm_hopper.mlir

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-gpu-to-llvm=compute-capability=90 2>&1 | FileCheck %s
+// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-gpu-to-llvm='compute-capability=90 ptx-version=81' 2>&1 | FileCheck %s

 #mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}>
 #shared = #triton_gpu.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-gpu-to-llvm='compute-capability=90 ptx-version=80' 2>&1 | FileCheck %s
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} {
+  tt.func public @atomic_add_f32_nomask(%dest_ptrs: tensor<256x!tt.ptr<f32>, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf32, #blocked>) attributes {noinline = false} {
+    // CHECK-LABEL: atomic_add_f32_nomask
+    // CHECK: atom.global.gpu.acq_rel.add.f32
+    // CHECK: atom.global.gpu.acq_rel.add.f32
+    // CHECK: atom.global.gpu.acq_rel.add.f32
+    // CHECK: atom.global.gpu.acq_rel.add.f32
+    %0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data : (tensor<256x!tt.ptr<f32>, #blocked>, tensor<256xf32, #blocked>) -> tensor<256xf32, #blocked>
+    tt.return
+  }
+}
+
+// -----
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} {
+  tt.func public @atomic_add_f32_withmask(%dest_ptrs: tensor<256x!tt.ptr<f32>, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf32, #blocked>, %mask: tensor<256xi1, #blocked> {tt.constancy = 2 : i32}) attributes {noinline = false} {
+    // CHECK-LABEL: atomic_add_f32_withmask
+    // CHECK: atom.global.gpu.acq_rel.add.f32
+    // CHECK: atom.global.gpu.acq_rel.add.f32
+    // CHECK: atom.global.gpu.acq_rel.add.f32
+    // CHECK: atom.global.gpu.acq_rel.add.f32
+    %0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data, %mask : (tensor<256x!tt.ptr<f32>, #blocked>, tensor<256xf32, #blocked>, tensor<256xi1, #blocked>) -> tensor<256xf32, #blocked>
+    tt.return
+  }
+}
+
+// -----
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} {
+  tt.func public @atomic_add_f16_withmask(%dest_ptrs: tensor<256x!tt.ptr<f16>, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf16, #blocked>, %mask: tensor<256xi1, #blocked> {tt.constancy = 4 : i32}) attributes {noinline = false} {
+    // CHECK-LABEL: atomic_add_f16_withmask
+    // CHECK: atom.global.gpu.acq_rel.add.noftz.f16x2
+    // CHECK: atom.global.gpu.acq_rel.add.noftz.f16x2
+    // CHECK: atom.global.gpu.acq_rel.add.noftz.f16x2
+    // CHECK: atom.global.gpu.acq_rel.add.noftz.f16x2
+    %0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data, %mask : (tensor<256x!tt.ptr<f16>, #blocked>, tensor<256xf16, #blocked>, tensor<256xi1, #blocked>) -> tensor<256xf16, #blocked>
+    tt.return
+  }
+}
