
Commit d0abc51

[AMD] Introduce specialized Allocation pass (#7328)
This PR introduces an AMD-specific allocation pass and a new attribute that selects the conversion method: padded or swizzled. For now, the OptimizeLDSUsage pass puts all convert-layout operations into padded mode. --------- Co-authored-by: Alexander Efimov <[email protected]>
1 parent 355dc47 commit d0abc51
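
As a rough illustration of what the commit message describes, a pass such as OptimizeLDSUsage could mark every ttg.convert_layout op with the new attribute. The snippet below is a minimal sketch and not the actual pass body from this PR; the attribute string matches AttrSharedMemPadded introduced by this commit, and the test added below attaches it as a unit attribute, so presence is what matters.

#include "mlir/IR/BuiltinOps.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

// Minimal sketch: walk a module and put every ttg.convert_layout into padded
// mode by attaching the attribute introduced in this commit.
static void markAllConvertsPadded(mlir::ModuleOp mod) {
  mlir::MLIRContext *ctx = mod.getContext();
  mod.walk([&](mlir::triton::gpu::ConvertLayoutOp cvt) {
    // Unit attribute: the allocation analysis only needs to see that it is set.
    cvt->setAttr("amdgpu.use_padded_scratch_shmem", mlir::UnitAttr::get(ctx));
  });
}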

File tree

19 files changed, +247 -60 lines changed


bin/RegisterTritonDialects.h

Lines changed: 1 addition & 0 deletions

@@ -69,6 +69,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
   mlir::registerLLVMDIScope();
 
   // TritonAMDGPUToLLVM passes
+  mlir::triton::registerAllocateAMDGPUSharedMemory();
   mlir::triton::registerConvertTritonAMDGPUToLLVM();
   mlir::triton::registerConvertBuiltinFuncToLLVM();
   mlir::triton::registerOptimizeAMDLDSUsage();

include/triton/Conversion/TritonGPUToLLVM/AllocateSharedMemoryUtility.h

Lines changed: 17 additions & 0 deletions

@@ -0,0 +1,17 @@
+#ifndef TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ALLOCATE_UTILITY_H_
+#define TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ALLOCATE_UTILITY_H_
+
+#include "mlir/IR/BuiltinOps.h"
+#include "triton/Analysis/Allocation.h"
+
+namespace mlir::triton::gpu {
+
+/// Attach shared memory related attributes to module and operations inside it.
+/// This includes total shared memory consumption in module and shared memory
+/// offsets of buffers associated with operations.
+void attachAllocationSizeAndOffsetAttr(ModuleOp mod,
+                                       ModuleAllocation &allocation);
+
+} // namespace mlir::triton::gpu
+
+#endif // TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ALLOCATE_UTILITY_H_

lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp

Lines changed: 2 additions & 24 deletions

@@ -1,5 +1,6 @@
 #include "triton/Analysis/Allocation.h"
 #include "triton/Analysis/Utility.h"
+#include "triton/Conversion/TritonGPUToLLVM/AllocateSharedMemoryUtility.h"
 #include "triton/Conversion/TritonGPUToLLVM/Passes.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
@@ -18,32 +19,9 @@ struct AllocateSharedMemory
                            AllocateSharedMemory> {
   void runOnOperation() override {
     ModuleOp mod = getOperation();
-    MLIRContext *ctx = &getContext();
     ModuleAllocation allocation(mod);
 
-    mod.walk<mlir::WalkOrder::PreOrder>([&](FunctionOpInterface funcOp) {
-      auto *funcAllocation = allocation.getFuncData(funcOp);
-      funcOp.walk([&](Operation *op) {
-        auto oBufferId = funcAllocation->getBufferId(op);
-        int offset = -1;
-        if (oBufferId != Allocation::InvalidBufferId)
-          offset = funcAllocation->getOffset(oBufferId);
-        else if (op->getNumResults() == 1) {
-          Value value = op->getResult(0);
-          auto vBufferId = funcAllocation->getBufferId(value);
-          if (vBufferId != Allocation::InvalidBufferId)
-            offset = funcAllocation->getOffset(vBufferId);
-        }
-        if (offset == -1)
-          return;
-        op->setAttr("allocation.offset",
-                    IntegerAttr::get(IntegerType::get(ctx, 32), offset));
-      });
-      return WalkResult::skip();
-    });
-    mod->setAttr("ttg.shared",
-                 mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 32),
-                                        allocation.getSharedMemorySize()));
+    mlir::triton::gpu::attachAllocationSizeAndOffsetAttr(mod, allocation);
   }
 };
 } // namespace
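
The point of this extraction is reuse: a backend-specific allocation pass can build its own ModuleAllocation and then call the same utility. Below is a rough, hedged sketch of how the new AMD pass could do that; it is not the AllocateAMDGPUSharedMemory pass from this PR, and passing a custom scratch-size callback to ModuleAllocation is an assumption inferred from the AMDAllocationAnalysisScratchSizeFn declaration added later in this commit.

#include "triton/Analysis/Allocation.h"
#include "triton/Conversion/TritonGPUToLLVM/AllocateSharedMemoryUtility.h"
// ... plus the new AMD allocation header (its path is not shown in this view).

using namespace mlir;

// Hedged sketch only: size scratch buffers with the AMD-specific estimate,
// then reuse the shared utility to write ttg.shared and allocation.offset.
static void allocateAMDSharedMemory(ModuleOp mod) {
  ModuleAllocation allocation(mod,
                              triton::AMD::AMDAllocationAnalysisScratchSizeFn);
  triton::gpu::attachAllocationSizeAndOffsetAttr(mod, allocation);
}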

lib/Conversion/TritonGPUToLLVM/AllocateSharedMemoryUtility.cpp

Lines changed: 34 additions & 0 deletions

@@ -0,0 +1,34 @@
+#include "triton/Conversion/TritonGPUToLLVM/AllocateSharedMemoryUtility.h"
+
+namespace mlir::triton::gpu {
+
+void attachAllocationSizeAndOffsetAttr(ModuleOp mod,
+                                       ModuleAllocation &allocation) {
+  MLIRContext *ctx = mod.getContext();
+
+  mod.walk<mlir::WalkOrder::PreOrder>([&](FunctionOpInterface funcOp) {
+    auto *funcAllocation = allocation.getFuncData(funcOp);
+    funcOp.walk([&](Operation *op) {
+      auto oBufferId = funcAllocation->getBufferId(op);
+      int offset = -1;
+      if (oBufferId != Allocation::InvalidBufferId)
+        offset = funcAllocation->getOffset(oBufferId);
+      else if (op->getNumResults() == 1) {
+        Value value = op->getResult(0);
+        auto vBufferId = funcAllocation->getBufferId(value);
+        if (vBufferId != Allocation::InvalidBufferId)
+          offset = funcAllocation->getOffset(vBufferId);
+      }
+      if (offset == -1)
+        return;
+      op->setAttr("allocation.offset",
+                  IntegerAttr::get(IntegerType::get(ctx, 32), offset));
+    });
+    return WalkResult::skip();
+  });
+  mod->setAttr("ttg.shared",
+               mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 32),
+                                      allocation.getSharedMemorySize()));
+}
+
+} // namespace mlir::triton::gpu
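
The attributes written here are plain integer attributes, so any later lowering stage can read them back with standard MLIR accessors. The snippet below is an illustrative sketch, not code from this PR, of how a consumer might query the per-op offset and the module-level total:

#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"

// Hedged sketch: read back the attributes attached by
// attachAllocationSizeAndOffsetAttr.
static int64_t sharedMemOffset(mlir::Operation *op) {
  if (auto attr = op->getAttrOfType<mlir::IntegerAttr>("allocation.offset"))
    return attr.getInt();
  return -1; // the op owns no shared-memory buffer
}

static int64_t totalSharedMem(mlir::ModuleOp mod) {
  if (auto attr = mod->getAttrOfType<mlir::IntegerAttr>("ttg.shared"))
    return attr.getInt();
  return 0; // the pass has not run, or the module uses no shared memory
}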

lib/Conversion/TritonGPUToLLVM/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -2,6 +2,7 @@ add_triton_library(TritonGPUToLLVM
   DotOpToLLVM/FMA.cpp
   DotOpToLLVM/FMADotUtility.cpp
   AllocateSharedMemory.cpp
+  AllocateSharedMemoryUtility.cpp
   AllocateWarpGroups.cpp
   AssertOpToLLVM.cpp
   ControlFlowOpToLLVM.cpp

lib/Tools/GenericSwizzling.cpp

Lines changed: 1 addition & 0 deletions

@@ -382,4 +382,5 @@ LinearLayout optimalSwizzling(const LinearLayout &src, const LinearLayout &dst,
 
   return basis1D.reshapeOuts(src.getOutDims());
 }
+
 } // namespace mlir::triton::gpu

Lines changed: 54 additions & 0 deletions

@@ -0,0 +1,54 @@
+// RUN: triton-opt %s -split-input-file --allocate-amdgpu-shared-memory | FileCheck %s
+
+#blocked1 = #ttg.blocked<{sizePerThread = [8, 4], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+
+// This test checks the padding-based converter.
+//
+// The converter allocates a temporary buffer, then stores and reads parts of the tensor in a few transactions, which are named repeats.
+// The size of the temporary buffer is computed using the following algorithm:
+// - get the CTA tile shape of the blocked1 layout: [8*8*4, 4*8*1] = [256, 32]
+// - get the CTA tile shape of the blocked2 layout: [1*8*4, 1*8*1] = [32, 8]
+// - compute the common tile shape: [max(256, 32), max(32, 8)] = [256, 32]
+// - pad the fastest dimension (same as the output layout, dimension 1 in this case) by the memory access size to reduce bank conflicts: 16 bytes in this case.
+//
+// Therefore the total memory consumption of the scratch buffer is 256*(32 * 4 (size of one element) + 16 (padding)) = 36864 bytes.
+//
+// For the implementation see the mlir::triton::getNumScratchElemsPaddedCvt function.
+
+// CHECK: ttg.shared = 36864 : i32
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
+
+  // CHECK-LABEL: @convert_layout_padded
+  tt.func @convert_layout_padded(%arg0: tensor<256x256xi32, #blocked1>) {
+    // CHECK-NEXT: allocation.offset = 0 : i32
+    %0 = ttg.convert_layout %arg0 {amdgpu.use_padded_scratch_shmem} : tensor<256x256xi32, #blocked1> -> tensor<256x256xi32, #blocked2>
+    tt.return
+  }
+
+}
+
+// -----
+
+#blocked1 = #ttg.blocked<{sizePerThread = [8, 4], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+
+// This test checks the swizzling-based converter.
+//
+// The swizzling converter tries to find a swizzling pattern that provides the widest load and store instructions and avoids as many bank conflicts as possible.
+// The current converter implementation decides that the best swizzling pattern requires allocating a tile of shape [256, 128], which takes 256*128*4 (size of one element) = 131072 bytes.
+//
+// For the implementation see the mlir::triton::getNumScratchElemsSwizzledCvt function,
+// in particular mlir::triton::gpu::optimalSwizzling, which determines the shape of the repeat tile.
+
+// CHECK: ttg.shared = 131072 : i32
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
+
+  // CHECK-LABEL: @convert_layout_swizzled
+  tt.func @convert_layout_swizzled(%arg0: tensor<256x256xi32, #blocked1>) {
+    // CHECK-NEXT: allocation.offset = 0 : i32
+    %0 = ttg.convert_layout %arg0 : tensor<256x256xi32, #blocked1> -> tensor<256x256xi32, #blocked2>
+    tt.return
+  }
+
+}
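
To make the padded-size arithmetic in the first test easy to check, here is a small self-contained sketch that reproduces the computation described in the comment above. It is not the real getNumScratchElemsPaddedCvt; the tile shapes, element size, and padding value are hard-coded from that example.

#include <algorithm>
#include <cstdio>

// Hedged sketch: reproduces the scratch-size arithmetic of the padded test.
int main() {
  // CTA tile shapes of the two blocked layouts.
  const int srcTile[2] = {8 * 8 * 4, 4 * 8 * 1}; // [256, 32]
  const int dstTile[2] = {1 * 8 * 4, 1 * 8 * 1}; // [32, 8]
  // Common tile: element-wise max of the two tiles.
  const int tile0 = std::max(srcTile[0], dstTile[0]); // 256
  const int tile1 = std::max(srcTile[1], dstTile[1]); // 32
  const int elemBytes = 4; // i32 elements
  const int padBytes = 16; // padding of the fastest dimension per row
  const int totalBytes = tile0 * (tile1 * elemBytes + padBytes);
  std::printf("scratch bytes = %d\n", totalBytes); // prints 36864
  return 0;
}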

test/TritonGPU/amd/optimize-lds-usage.mlir

Lines changed: 11 additions & 11 deletions

@@ -5,8 +5,8 @@
 // CHECK-LABEL: alloc_convert_load
 // CHECK-32KLIMIT-LABEL: alloc_convert_load
 // CHECK: %0 = ttg.local_alloc %arg0 : {{.*}}#blocked{{.*}}#shared
-// CHECK: %1 = ttg.convert_layout %arg1 : {{.*}}#blocked{{.*}}#blocked1
-// CHECK: %2 = ttg.convert_layout %1 : {{.*}}#blocked1{{.*}}#mma
+// CHECK: %1 = ttg.convert_layout %arg1 {{.*}}: {{.*}}#blocked{{.*}}#blocked1
+// CHECK: %2 = ttg.convert_layout %1 {{.*}}: {{.*}}#blocked1{{.*}}#mma
 // CHECK: %3 = ttg.local_load %0 : {{.*}}#shared{{.*}}#ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
 #blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
 #mma = #ttg.amd_mfma<{version = 2, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}>
@@ -28,8 +28,8 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32}
 // CHECK-LABEL: alloc_convert_small_load
 // CHECK-32KLIMIT-LABEL: alloc_convert_small_load
 // CHECK: %0 = ttg.local_alloc %arg0 : {{.*}}#blocked{{.*}}#shared
-// CHECK: %1 = ttg.convert_layout %arg1 : {{.*}}#blocked{{.*}}#blocked1
-// CHECK: %2 = ttg.convert_layout %1 : {{.*}}#blocked1{{.*}}#mma
+// CHECK: %1 = ttg.convert_layout %arg1 {{.*}}: {{.*}}#blocked{{.*}}#blocked1
+// CHECK: %2 = ttg.convert_layout %1 {{.*}}: {{.*}}#blocked1{{.*}}#mma
 // CHECK: %3 = ttg.local_load %0 : {{.*}}#shared{{.*}}#ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
 #blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
 #mma = #ttg.amd_mfma<{version = 2, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}>
@@ -55,7 +55,7 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32}
 // CHECK-32KLIMIT-LABEL: alloc_convert_3d_load
 // CHECK: [[V0:%.*]] = ttg.local_alloc {{.*}}[[$BLOCKED1]]{{.*}}
 // CHECK: [[V1:%.*]] = ttg.convert_layout {{.*}}[[$BLOCKED1]]{{.*}}[[$BLOCKED2]]
-// CHECK: [[V2:%.*]] = ttg.convert_layout [[V1]] : {{.*}}[[$BLOCKED2]]{{.*}}[[$MMA]]
+// CHECK: [[V2:%.*]] = ttg.convert_layout [[V1]] {{.*}}: {{.*}}[[$BLOCKED2]]{{.*}}[[$MMA]]
 // CHECK: [[V3:%.*]] = ttg.local_load [[V0]] : {{.*}}#ttg.dot_op<{opIdx = 0, parent = [[$MMA]], kWidth = 4}>>
 #blocked = #ttg.blocked<{sizePerThread = [1, 8, 1], threadsPerWarp = [1, 16, 4], warpsPerCTA = [1, 1, 8], order = [0, 1, 2]}>
 #mma = #ttg.amd_mfma<{version = 2, warpsPerCTA = [1, 1, 8], instrShape = [32, 32], isTransposed = false}>
@@ -75,12 +75,12 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32}
 // Check that the optimization triggers with a custom LDS limit and does not trigger with the default one
 // CHECK-LABEL: alloc_convert_32k_limit
 // CHECK: %0 = ttg.local_alloc %arg0 : {{.*}}#blocked{{.*}}#shared
-// CHECK: %1 = ttg.convert_layout %arg1 : {{.*}}#blocked{{.*}}#mma
+// CHECK: %1 = ttg.convert_layout %arg1 {{.*}}: {{.*}}#blocked{{.*}}#mma
 // CHECK: %2 = ttg.local_load %0 : {{.*}}#shared{{.*}}#ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
 // CHECK-32KLIMIT-LABEL: alloc_convert_32k_limit
 // CHECK-32KLIMIT: %0 = ttg.local_alloc %arg0 : {{.*}}#blocked{{.*}}#shared
-// CHECK-32KLIMIT: %1 = ttg.convert_layout %arg1 : {{.*}}#blocked{{.*}}#blocked1
-// CHECK-32KLIMIT: %2 = ttg.convert_layout %1 : {{.*}}#blocked1{{.*}}#mma
+// CHECK-32KLIMIT: %1 = ttg.convert_layout %arg1 {{.*}}: {{.*}}#blocked{{.*}}#blocked1
+// CHECK-32KLIMIT: %2 = ttg.convert_layout %1 {{.*}}: {{.*}}#blocked1{{.*}}#mma
 // CHECK-32KLIMIT: %3 = ttg.local_load %0 : {{.*}}#shared{{.*}}#ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
 #blocked = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
 #mma = #ttg.amd_mfma<{version = 2, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}>
@@ -106,9 +106,9 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32}
 
 // CHECK: tt.func public @mfma_dot_shortcut([[ARG_0:%[a-z0-9]*]]: {{.*}}, [[ARG_1:%[a-z0-9]*]]: {{.*}}, [[ARG_2:%[a-z0-9]*]]: {{.*}})
 // CHECK: [[ALLOC:%[0-9]+]] = ttg.local_alloc [[ARG_0]] : (tensor<128x128xf16, [[BLOCKED_1]]>) -> !ttg.memdesc<128x128xf16, [[SHARED]], #smem>
-// CHECK: [[INTERMEDIATE_CONV:%[0-9]+]] = ttg.convert_layout [[ARG_1]] : tensor<128x128xf32, [[BLOCKED_1]]> -> tensor<128x128xf32, [[BLOCKED_2]]>
-// CHECK: [[CONVERT_1:%[0-9]+]] = ttg.convert_layout [[INTERMEDIATE_CONV]] : tensor<128x128xf32, [[BLOCKED_2]]> -> tensor<128x128xf32, [[MMA_2]]>
-// CHECK: [[CONVERT_2:%[0-9]+]] = ttg.convert_layout [[ARG_2]] : tensor<256x128xf16, [[MMA_1]]> -> tensor<256x128xf16, #ttg.dot_op<{opIdx = 0, parent = [[MMA_1]], kWidth = 4}>>
+// CHECK: [[INTERMEDIATE_CONV:%[0-9]+]] = ttg.convert_layout [[ARG_1]] {{.*}}: tensor<128x128xf32, [[BLOCKED_1]]> -> tensor<128x128xf32, [[BLOCKED_2]]>
+// CHECK: [[CONVERT_1:%[0-9]+]] = ttg.convert_layout [[INTERMEDIATE_CONV]] {{.*}}: tensor<128x128xf32, [[BLOCKED_2]]> -> tensor<128x128xf32, [[MMA_2]]>
+// CHECK: [[CONVERT_2:%[0-9]+]] = ttg.convert_layout [[ARG_2]] {{.*}}: tensor<256x128xf16, [[MMA_1]]> -> tensor<256x128xf16, #ttg.dot_op<{opIdx = 0, parent = [[MMA_1]], kWidth = 4}>>
 // CHECK: [[LOAD:%[0-9]+]] = ttg.local_load [[ALLOC]] : !ttg.memdesc<128x128xf16, [[SHARED]], #smem> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = [[MMA_2]], kWidth = 4}>>
 #blocked = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
 #mma1 = #ttg.amd_mfma<{version = 2, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}>

third_party/amd/backend/compiler.py

Lines changed: 1 addition & 1 deletion

@@ -303,7 +303,7 @@ def make_llir(src, metadata, options):
     passes.convert.add_scf_to_cf(pm)
     passes.convert.add_index_to_llvmir(pm)
 
-    passes.ttgpuir.add_allocate_shared_memory(pm)
+    amd.passes.ttgpuir.add_allocate_shared_memory(pm)
     ## __HIP_FTZ is used to control the denorm flushing behavior of exp2 op as follows:
     ## 1. If __HIP_FTZ = 1, exp2 flushes denorms in input and output regardless
     ##    of the value of kernel arg `allow_flush_denorm`.

Lines changed: 19 additions & 0 deletions

@@ -0,0 +1,19 @@
+#ifndef TRITONAMD_ANALYSIS_AMDGPU_ALLOCATION_H
+#define TRITONAMD_ANALYSIS_AMDGPU_ALLOCATION_H
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Operation.h"
+
+namespace mlir::triton::AMD {
+
+constexpr char AttrSharedMemPadded[] = "amdgpu.use_padded_scratch_shmem";
+
+unsigned getConvertLayoutScratchInBytes(RankedTensorType srcTy,
+                                        RankedTensorType dstTy,
+                                        bool usePadding);
+
+unsigned AMDAllocationAnalysisScratchSizeFn(Operation *op);
+
+} // namespace mlir::triton::AMD
+
+#endif // TRITONAMD_ANALYSIS_AMDGPU_ALLOCATION_H
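
This header only declares the AMD hooks; their definitions are not shown in this excerpt. Below is a plausible sketch of how a scratch-size callback could combine the two declarations, dispatching on the padded-scratch attribute and deferring to getConvertLayoutScratchInBytes. The op casts and accessors are standard Triton/MLIR APIs, but the body is an assumption, not the code from this PR.

#include "triton/Dialect/TritonGPU/IR/Dialect.h"
// ... plus the AMD allocation header above (its path is not shown in this view).

using namespace mlir;

// Hedged sketch only: a possible shape for an AMD scratch-size callback.
// The real AMDAllocationAnalysisScratchSizeFn in this PR may differ.
static unsigned exampleConvertLayoutScratchSize(Operation *op) {
  if (auto cvt = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
    auto srcTy = cast<RankedTensorType>(cvt.getSrc().getType());
    auto dstTy = cast<RankedTensorType>(cvt.getType());
    // Padded mode is requested by attaching the unit attribute
    // "amdgpu.use_padded_scratch_shmem" (AttrSharedMemPadded) to the op.
    bool usePadding = op->hasAttr(triton::AMD::AttrSharedMemPadded);
    return triton::AMD::getConvertLayoutScratchInBytes(srcTy, dstTy, usePadding);
  }
  return 0; // non-convert ops: no AMD-specific scratch in this sketch
}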
