Commit 991cf5f

giuseros authored and liuyunqi20 committed
[AMD] Add pass to convert tt.load/tt.store to buffer operations (#4966)
This PR only introduces a ttgir pass that converts `tt.load`/`tt.store` to `amdgpu.buffer_load`/`amdgpu.buffer_store`, _when this is possible_. This means we need to check three conditions:

1. The pointer arithmetic has been canonicalized (`scalarPtr->splat->addptr->load/store`).
2. The offsets are 32 bits wide.
3. The offsets are non-negative.

We use a mix of analysis and assumptions to verify these conditions. For now the functionality is gated behind the `AMDGCN_USE_BUFFER_OPS` environment variable, which now also gates the pointer canonicalization pass, since that pass is mostly meant to enable this conversion.
1 parent 8f6b089 commit 991cf5f
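
For illustration only (a hedged sketch, not part of the diff): a Triton kernel like the one below produces the canonical `splat -> addptr -> load/store` pattern with non-negative 32-bit offsets, so with the flag set its loads and stores are candidates for buffer operations. It mirrors the `simple` lit test added below; the kernel name, shapes, and launch grid are illustrative.

import os
os.environ["AMDGCN_USE_BUFFER_OPS"] = "1"  # opt-in; set before any kernel is compiled

import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    # Offsets are i32 and non-negative: pid * BLOCK + [0, BLOCK).
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offsets)        # candidate for amdgpu.buffer_load
    y = tl.load(y_ptr + offsets)        # candidate for amdgpu.buffer_load
    tl.store(out_ptr + offsets, x + y)  # candidate for amdgpu.buffer_store

x = torch.randn(1024, device="cuda")  # "cuda" is also the device string on ROCm builds of PyTorch
y = torch.randn(1024, device="cuda")
out = torch.empty_like(x)
add_kernel[(4,)](x, y, out, BLOCK=256)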

13 files changed: +472 -13 lines changed

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 1 addition & 0 deletions

@@ -13,6 +13,7 @@ namespace mlir::triton {
 inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
     // clang-format off
     "AMDGCN_ENABLE_DUMP",
+    "AMDGCN_USE_BUFFER_OPS",
     "DISABLE_FAST_REDUCTION",
     "DISABLE_LLVM_OPT",
     "DISABLE_MMA_V3",

python/test/unit/language/test_core.py

Lines changed: 2 additions & 1 deletion

@@ -3978,10 +3978,11 @@ def _kernel(dst, src, CACHE: tl.constexpr):
     amdgcn = pgm.asm['amdgcn']
     cg_cache_modifier_str = 'nt'
     cv_cache_modifier_str = 'sc0 sc1'
+    buffer_load_line = [line for line in amdgcn.splitlines() if "buffer_load" in line]
     global_load_line = [line for line in amdgcn.splitlines() if "global_load" in line]
     flat_load_line = [line for line in amdgcn.splitlines() if "flat_load" in line]
     if cache == '' or cache == '.ca':
-        assert cg_cache_modifier_str not in global_load_line[0]
+        assert cg_cache_modifier_str not in (global_load_line[0] if global_load_line else buffer_load_line[0])
     if cache == '.cg':
         assert cg_cache_modifier_str in global_load_line[0]
     if cache == '.cv':
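
The guarded assertion above handles the case where buffer ops replace global loads entirely. A minimal sketch of inspecting the generated assembly the same way the test does (reusing the hypothetical `add_kernel` from earlier):

pgm = add_kernel[(4,)](x, y, out, BLOCK=256)
amdgcn = pgm.asm['amdgcn']
buffer_loads = [line for line in amdgcn.splitlines() if "buffer_load" in line]
global_loads = [line for line in amdgcn.splitlines() if "global_load" in line]
# With AMDGCN_USE_BUFFER_OPS=1, loads may show up as buffer_load instead of global_load.
assert buffer_loads or global_loads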

test/TritonGPU/amd/amd-canonicalize-pointers.mlir

Lines changed: 40 additions & 0 deletions

@@ -89,6 +89,46 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war
 
 // -----
 
+#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>
+module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
+  //
+  // This is the same as conversion3, but now the `arith.extsi` operations
+  // disappeared and all the offsets are 32 bits.
+  //
+  // CHECK-LABEL: tt.func @conversion4
+  tt.func @conversion4(%arg0: !tt.ptr<f32> {tt.pointer_range = 32 : i32}) -> tensor<1024xf32, #blocked> {
+    %c1024_i32 = arith.constant 1024 : i32
+    %0 = tt.get_program_id x : i32
+    %1 = arith.muli %0, %c1024_i32 : i32
+    %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
+    %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked>
+    %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked>
+
+    //CHECK: %0 = tt.get_program_id x : i32
+    //CHECK: %[[pid:.*]] = arith.muli %0, {{.*}} : i32
+    //CHECK: %[[makerange:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
+    //CHECK: %[[uniformOffset1:.*]] = arith.addi %[[pid]], {{.*}} : i32
+    //CHECK: %[[tensorOffset1:.*]] = arith.addi %{{.*}}, %[[makerange]] : tensor<1024xi32, #blocked>
+    //CHECK: %[[uniformOffset0:.*]] = arith.addi %[[pid:.*]], %{{.*}} : i32
+    //CHECK: %[[tensorOffset3:.*]] = arith.addi %{{.*}}, %[[makerange]] : tensor<1024xi32, #blocked>
+    //CHECK: %[[zero:.*]] = tt.splat %{{.*}} : i32 -> tensor<1024xi32, #blocked>
+    //CHECK: %[[uniformPtr0:.*]] = tt.addptr %arg0, %[[uniformOffset0:.*]] : !tt.ptr<f32>, i32
+    //CHECK: %[[tensorOffset0:.*]]= arith.addi %[[tensorOffset3]], %[[zero]] : tensor<1024xi32, #blocked>
+    //CHECK: %[[uniformPtr1:.*]] = tt.addptr %[[uniformPtr0]], %[[uniformOffset1]] : !tt.ptr<f32>, i32
+    //CHECK: %[[tensorOffset2:.*]] = arith.addi %[[tensorOffset1]], %[[tensorOffset0]]: tensor<1024xi32, #blocked>
+    //CHECK: %[[scalarPtr:.*]] = tt.splat %[[uniformPtr1]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked>
+    //CHECK: %[[newPtr:.*]] = tt.addptr %[[scalarPtr]], %[[tensorOffset2]] : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked>
+    //CHECK: tt.load %[[newPtr]]
+    %5 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked>
+    %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked>
+    %7 = tt.addptr %6, %4 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked>
+    %8 = tt.load %7 : tensor<1024x!tt.ptr<f32>, #blocked>
+    tt.return %8 : tensor<1024xf32, #blocked>
+  }
+}
+
+// -----
+
 #blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>
 module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
   // CHECK-LABEL: tt.func @forOp
test/TritonGPU/amd/amd-convert-buffer-ops.mlir

Lines changed: 124 additions & 0 deletions
@@ -0,0 +1,124 @@
+// RUN: triton-opt %s -split-input-file --tritonamdgpu-convert-buffer-ops | FileCheck %s
+
+#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
+  // CHECK-LABEL: simple
+  tt.func @simple(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) {
+    %c256_i32 = arith.constant 256 : i32
+    %0 = tt.get_program_id x : i32
+    %1 = arith.muli %0, %c256_i32 : i32
+    %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0>
+    %3 = tt.splat %1 : i32 -> tensor<256xi32, #blocked0>
+    // CHECK: %[[offset:.*]] = arith.addi
+    %4 = arith.addi %3, %2 : tensor<256xi32, #blocked0>
+    %5 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<256x!tt.ptr<f32>, #blocked0>
+    %6 = tt.addptr %5, %4 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
+    %7 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<256x!tt.ptr<f32>, #blocked0>
+    %8 = tt.addptr %7, %4 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
+    // CHECK: buffer_load %arg0[%[[offset]]]
+    %9 = tt.load %6 : tensor<256x!tt.ptr<f32>, #blocked0>
+    // CHECK: buffer_load %arg1[%[[offset]]]
+    %10 = tt.load %8 : tensor<256x!tt.ptr<f32>, #blocked0>
+    // CHECK: %[[data:.*]] = arith.addf
+    %11 = arith.addf %9, %10 : tensor<256xf32, #blocked0>
+    %12 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<256x!tt.ptr<f32>, #blocked0>
+    %13 = tt.addptr %12, %4 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
+    // CHECK: buffer_store %[[data]], %arg2[%[[offset]]]
+    tt.store %13, %11 : tensor<256x!tt.ptr<f32>, #blocked0>
+    tt.return
+  }
+}
+
+// -----
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
+  // CHECK-LABEL: assume_positive_offset
+  tt.func @assume_positive_offset(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}) -> tensor<1024xf32, #blocked> {
+    %c1024_i32 = arith.constant 1024 : i32
+    %c128_i32 = arith.constant 128 : i32
+    %c0_i32 = arith.constant 0 : i32
+    %0 = tt.get_program_id x : i32
+    %1 = arith.muli %0, %c1024_i32 : i32
+    %sub = arith.subi %1, %c128_i32 : i32
+    %cmp = arith.cmpi sgt, %sub, %c0_i32 : i32
+    "llvm.intr.assume"(%cmp) : (i1) -> ()
+    %2 = tt.splat %sub : i32 -> tensor<1024xi32, #blocked>
+    %3 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
+    // CHECK: %[[offset:.*]] = arith.addi
+    %4 = arith.addi %2, %3 : tensor<1024xi32, #blocked>
+    // CHECK: %[[scalar_ptr:.*]] = tt.addptr %arg0
+    %5 = tt.addptr %arg0, %1 : !tt.ptr<f32>, i32
+    %8 = tt.splat %5 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked>
+    %9 = tt.addptr %8, %4 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked>
+    // CHECK: buffer_load %[[scalar_ptr]][%[[offset]]]
+    %10 = tt.load %9 : tensor<1024x!tt.ptr<f32>, #blocked>
+    tt.return %10 : tensor<1024xf32, #blocked>
+  }
+}
+
+// -----
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
+  // CHECK-LABEL: offset_64_bits
+  tt.func @offset_64_bits(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}) -> tensor<1024xf32, #blocked> {
+    %c1024_i32 = arith.constant 1024 : i32
+    %c128_i32 = arith.constant 128 : i32
+    %0 = tt.get_program_id x : i32
+    %1 = arith.muli %0, %c1024_i32 : i32
+    %sub = arith.subi %1, %c128_i32 : i32
+    %2 = tt.splat %sub : i32 -> tensor<1024xi32, #blocked>
+    %3 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
+    %ext2 = arith.extsi %2 : tensor<1024xi32, #blocked> to tensor<1024xi64, #blocked>
+    %ext3 = arith.extsi %3 : tensor<1024xi32, #blocked> to tensor<1024xi64, #blocked>
+    %4 = arith.addi %ext2, %ext3 : tensor<1024xi64, #blocked>
+    %5 = tt.addptr %arg0, %1 : !tt.ptr<f32>, i32
+    %8 = tt.splat %5 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked>
+    %9 = tt.addptr %8, %4 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi64, #blocked>
+    // CHECK: tt.load
+    %10 = tt.load %9 : tensor<1024x!tt.ptr<f32>, #blocked>
+    tt.return %10 : tensor<1024xf32, #blocked>
+  }
+}
+
+// -----
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
+  // CHECK-LABEL: offset_64_bits_narrow
+  tt.func public @offset_64_bits_narrow(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) -> tensor<1024xf32, #blocked> {
+    %c1024_i32 = arith.constant 1024 : i32
+    %c128_i32 = arith.constant 128 : i32
+    %0 = tt.get_program_id x : i32
+    %1 = arith.muli %0, %c1024_i32 : i32
+    %2 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked>
+    %3 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
+    %ext2 = arith.extsi %2 : tensor<1024xi32, #blocked> to tensor<1024xi64, #blocked>
+    %ext3 = arith.extsi %3 : tensor<1024xi32, #blocked> to tensor<1024xi64, #blocked>
+    %4 = arith.addi %ext2, %ext3 : tensor<1024xi64, #blocked>
+    // CHECK: %[[scalar_ptr:.*]] = tt.addptr %arg0
+    %5 = tt.addptr %arg0, %1 : !tt.ptr<f32>, i32
+    %8 = tt.splat %5 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked>
+    // CHECK: %[[offset_32_bit:.*]] = arith.trunci
+    %narrow4 = arith.trunci %4 : tensor<1024xi64, #blocked> to tensor<1024xi32, #blocked>
+    %9 = tt.addptr %8, %narrow4 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked>
+    // CHECK: buffer_load %[[scalar_ptr]][%[[offset_32_bit]]]
+    %10 = tt.load %9 : tensor<1024x!tt.ptr<f32>, #blocked>
+    tt.return %10 : tensor<1024xf32, #blocked>
+  }
+}
+
+// -----
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
+  // CHECK-LABEL: non_canonical_ptr
+  tt.func @non_canonical_ptr(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: tensor<1024xi32, #blocked>) -> tensor<1024xf32, #blocked> {
+    %8 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked>
+    %9 = tt.addptr %8, %arg1 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked>
+    // CHECK: tt.load
+    %10 = tt.load %9 : tensor<1024x!tt.ptr<f32>, #blocked>
+    tt.return %10 : tensor<1024xf32, #blocked>
+  }
+}
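
For intuition, the `offset_64_bits` case above roughly corresponds to a kernel that widens its offsets to 64 bits before the address computation; a hedged Python-level sketch (illustrative, not from the PR):

@triton.jit
def copy_kernel_i64(x_ptr, out_ptr, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    # Widening to i64 violates condition 2 (32-bit offsets), so this load is
    # expected to stay a plain tt.load rather than become amdgpu.buffer_load.
    wide_offsets = offsets.to(tl.int64)
    x = tl.load(x_ptr + wide_offsets)
    tl.store(out_ptr + offsets, x)

Conversely, `offset_64_bits_narrow` shows that truncating back to i32 (via `arith.trunci`) re-enables the conversion, since the final offset feeding `tt.addptr` is 32 bits again.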

third_party/amd/backend/compiler.py

Lines changed: 4 additions & 6 deletions

@@ -229,7 +229,10 @@ def make_ttgir(mod, metadata, options):
     passes.ttgpuir.add_reduce_data_duplication(pm)
     if amd.has_matrix_core_feature(options.arch):
         amd.passes.ttgpuir.add_reorder_instructions(pm)
-    amd.passes.ttgpuir.add_canonicalize_pointers(pm)
+    if os.environ.get("AMDGCN_USE_BUFFER_OPS", "0") == "1":
+        amd.passes.ttgpuir.add_canonicalize_pointers(pm)
+        passes.common.add_canonicalizer(pm)
+        amd.passes.ttgpuir.add_convert_to_buffer_ops(pm)
     passes.common.add_canonicalizer(pm)
     passes.common.add_cse(pm)
     passes.common.add_symbol_dce(pm)
@@ -274,11 +277,6 @@ def make_llir(src, metadata, options):
     amd.passes.ttgpuir.lower_instruction_sched_hints(pm, options.instruction_sched_variant)
     if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0":
         passes.llvmir.add_di_scope(pm)
-    # This pass (`add_builtin_func_to_llvmir`) serves as a temporary workaround to address the issue of excessive basic block
-    # count caused by predicated loads/stores. In certain kernels, the addition of these blocks can cause the MLIR
-    # canonicalizer to never finish when attempting to merge blocks. The permanent solution under consideration
-    # involves using MUBUF instructions that have built-in out-of-bounds checks, which would eliminate the need
-    # for conditional branching around memory accesses.
     amd.passes.ttgpuir.add_builtin_func_to_llvmir(pm, __HIP_FTZ)
     pm.run(mod)

third_party/amd/include/TritonAMDGPUTransforms/Passes.h

Lines changed: 3 additions & 0 deletions

@@ -2,6 +2,7 @@
 #define TRITON_DIALECT_TRITONAMDGPU_TRANSFORMS_PASSES_H_
 
 #include "mlir/Pass/Pass.h"
+#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
 
 namespace mlir {
@@ -23,6 +24,8 @@ std::unique_ptr<Pass> createTritonAMDGPUOptimizeEpiloguePass();
 
 std::unique_ptr<Pass> createTritonAMDGPUCanonicalizePointersPass();
 
+std::unique_ptr<Pass> createTritonAMDGPUConvertToBufferOpsPass();
+
 /// Generate the code for registering passes.
 #define GEN_PASS_REGISTRATION
 #include "TritonAMDGPUTransforms/Passes.h.inc"

third_party/amd/include/TritonAMDGPUTransforms/Passes.td

Lines changed: 10 additions & 0 deletions

@@ -111,4 +111,14 @@ def TritonAMDGPUReorderInstructions: Pass<"tritonamdgpu-reorder-instructions", "
   let dependentDialects = [];
 }
 
+def TritonAMDGPUConvertToBufferOps : Pass<"tritonamdgpu-convert-buffer-ops", "mlir::ModuleOp"> {
+  let summary = "Convert memory operations to buffer operations";
+
+  let description = "This pass converts memory operations (e.g., tt.load/tt.store) to amdgpu buffer operations, if possible";
+
+  let constructor = "mlir::createTritonAMDGPUConvertToBufferOpsPass()";
+
+  let dependentDialects = ["mlir::triton::amdgpu::TritonAMDGPUDialect"];
+}
+
 #endif

third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp

Lines changed: 1 addition & 1 deletion

@@ -133,7 +133,7 @@ Type BufferEmitter::getBufferOpType(Type type) {
   // will be bitcast-able to the original type. So if the types
   // ended up different, we simply have to emit a `bitcastOp` to convert
   Type bufferType = type;
-  if (bufferVecSize != vecSize)
+  if (bufferVecSize != vecSize || bufferElementType != elementType)
     bufferType = VectorType::get(bufferVecSize, bufferElementType);
   if (bufferVecSize == 1)
     bufferType = getElementTypeOrSelf(bufferType);

third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 1 addition & 1 deletion

@@ -165,7 +165,7 @@ struct LoadStoreConversionBase {
     // Get alignment from the pointer. Since this is a scalar pointer
     // we should not take the pointer contiguity to consider alignment
     auto *axisInfo = axisAnalysisPass.getAxisInfo(ptr);
-    auto maxMultipleBytes = axisInfo->getDivisibility(order[0]);
+    auto maxMultipleBytes = axisInfo->getDivisibility(0);
     auto elemNumBits = triton::getPointeeBitWidth(tensorTy);
     auto elemNumBytes = std::max<unsigned>(elemNumBits / 8, 1);
     auto align = std::max<int64_t>(maxMultipleBytes / elemNumBytes, 1);

third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -1,6 +1,7 @@
 add_triton_library(TritonAMDGPUTransforms
   AccelerateAMDMatmul.cpp
   CanonicalizePointers.cpp
+  ConvertToBufferOps.cpp
   OptimizeEpilogue.cpp
   ReorderInstructions.cpp
   StreamPipelineV2.cpp
