
Commit 88b6833

[AMD] Add pattern to enforce coalesced write for async load (#6255)
Adds `TritonAMDGPUCoalesceAsyncCopyPass` to convert the blocked layout of `AsyncCopy` ops if they would produce non-coalesced writes on `GFX9`, which is a hardware requirement. The pass ensures the `sizePerThread` of the blocked layout is not greater than the contiguity of the source and mask elements or than the supported load size. Support for swizzled shared encodings will be added in a separate PR, so the pass skips those `AsyncCopies` for now. This pass will be required when we add `AsyncCopy` pipelining support to the AMD backend in a later PR.
1 parent dee0846 commit 88b6833
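At its core, the new pass clips the number of elements each thread writes to the largest direct-to-LDS load width the target supports, starting from the source/mask contiguity reported by AxisAnalysis. A minimal standalone sketch of that clipping rule (the helper and the set of supported widths are assumptions for illustration, mirroring the behaviour exercised by the lit tests below, which run with arch-generation-name=gfx950):

#include <cstdio>

// Hypothetical stand-in for TargetInfo::supportsDirectToLdsLoadBitWidth:
// the lit tests below rely on 32-bit and 128-bit direct-to-LDS loads being
// supported while 64-bit is not.
static bool supportsDirectToLdsLoadBitWidth(unsigned bits) {
  return bits == 32 || bits == 128;
}

// Halve the per-thread element count until the resulting load width is
// supported, as the pattern in CoalesceAsyncCopy.cpp does.
static unsigned clipLoadContig(unsigned loadContig, unsigned elemBitWidth) {
  while (loadContig > 0 &&
         !supportsDirectToLdsLoadBitWidth(loadContig * elemBitWidth))
    loadContig /= 2;
  return loadContig;
}

int main() {
  printf("%u\n", clipLoadContig(16, 16)); // f16, contiguity 16: 256 bit -> 8 (128 bit)
  printf("%u\n", clipLoadContig(4, 16));  // f16, contiguity 4: 64 bit -> 2 (32 bit)
}

The clipped value becomes sizePerThread along the fastest-varying dimension of the rewritten blocked encoding, which is what the lit tests below check.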

7 files changed, +367 -0 lines changed

bin/RegisterTritonDialects.h

Lines changed: 1 addition & 0 deletions
@@ -70,6 +70,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
   mlir::registerTritonAMDGPUCanonicalizePointers();
   mlir::registerTritonAMDGPUConvertToBufferOps();
   mlir::registerTritonAMDGPUInThreadTranspose();
+  mlir::registerTritonAMDGPUCoalesceAsyncCopy();
   mlir::triton::registerTritonAMDGPUInsertInstructionSchedHints();
   mlir::triton::registerTritonAMDGPULowerInstructionSchedHints();

Lines changed: 145 additions & 0 deletions
@@ -0,0 +1,145 @@
// RUN: triton-opt %s -split-input-file --tritonamdgpu-coalesce-async-copy=arch-generation-name=gfx950 | FileCheck %s

#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
  // sizePerThread = [1] because we have no information about contiguity of src pointers
  // CHECK: #[[NEW_BLOCKED:.*]] = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
  tt.func @async_copy_1d(%input: tensor<1024x!tt.ptr<f16>, #blocked>,
                         %view: !ttg.memdesc<1024xf16, #shared, #smem, mutable>) {
    // CHECK: %{{.*}} = ttg.convert_layout %{{.*}} : {{.*}} -> tensor<1024x!tt.ptr<f16>, #[[NEW_BLOCKED]]>
    // CHECK: %{{.*}} = ttg.async_copy_global_to_local %{{.*}}: tensor<1024x!tt.ptr<f16>, #[[NEW_BLOCKED]]>
    %token = ttg.async_copy_global_to_local %input, %view: tensor<1024x!tt.ptr<f16>, #blocked> -> <1024xf16, #shared, #smem, mutable>
    tt.return
  }
}

// -----

#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
  // sizePerThread = [1, 1] because we have no information about contiguity of src pointers
  // CHECK: #[[NEW_BLOCKED:.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
  tt.func @async_copy_2d(%input: tensor<64x64x!tt.ptr<f16>, #blocked>,
                         %view: !ttg.memdesc<64x64xf16, #shared, #smem, mutable>) {
    // CHECK: %{{.*}} = ttg.convert_layout %{{.*}} : {{.*}} -> tensor<64x64x!tt.ptr<f16>, #[[NEW_BLOCKED]]>
    // CHECK: %{{.*}} = ttg.async_copy_global_to_local %{{.*}}: tensor<64x64x!tt.ptr<f16>, #[[NEW_BLOCKED]]>
    %token = ttg.async_copy_global_to_local %input, %view: tensor<64x64x!tt.ptr<f16>, #blocked> -> <64x64xf16, #shared, #smem, mutable>
    tt.return
  }
}

// -----

#blocked = #ttg.blocked<{sizePerThread = [8, 1, 1], threadsPerWarp = [32, 1, 1], warpsPerCTA = [1,2,2], order = [0,1,2]}>
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0,1,2]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
  // sizePerThread = [1, 1, 1] because we have no information about contiguity of src pointers
  // CHECK: #[[NEW_BLOCKED:.*]] = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [32, 1, 1], warpsPerCTA = [4, 1, 1], order = [0, 1, 2]}>
  tt.func @async_copy_3d(%input: tensor<1024x1024x1024x!tt.ptr<f16>, #blocked>,
                         %view: !ttg.memdesc<1024x1024x1024xf16, #shared, #smem, mutable>) {
    // CHECK: %{{.*}} = ttg.convert_layout %{{.*}} : {{.*}} -> tensor<1024x1024x1024x!tt.ptr<f16>, #[[NEW_BLOCKED]]>
    // CHECK: %{{.*}} = ttg.async_copy_global_to_local %{{.*}}: tensor<1024x1024x1024x!tt.ptr<f16>, #[[NEW_BLOCKED]]>
    %token = ttg.async_copy_global_to_local %input, %view: tensor<1024x1024x1024x!tt.ptr<f16>, #blocked> -> <1024x1024x1024xf16, #shared, #smem, mutable>
    tt.return
  }
}

// -----

#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
  // CHECK: #[[NEW_BLOCKED:.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
  tt.func @async_copy_with_mask_and_other(%input: tensor<64x64x!tt.ptr<f16>, #blocked>,
                                          %view: !ttg.memdesc<64x64xf16, #shared, #smem, mutable>,
                                          %mask: tensor<64x64xi1, #blocked>,
                                          %other: tensor<64x64xf16, #blocked>) {
    // CHECK: %{{.*}} = ttg.convert_layout %{{.*}} : {{.*}} -> tensor<64x64x!tt.ptr<f16>, #[[NEW_BLOCKED]]>
    // CHECK: %{{.*}} = ttg.convert_layout %{{.*}} : {{.*}} -> tensor<64x64xi1, #[[NEW_BLOCKED]]>
    // CHECK: %{{.*}} = ttg.convert_layout %{{.*}} : {{.*}} -> tensor<64x64xf16, #[[NEW_BLOCKED]]>
    // CHECK: %{{.*}} = ttg.async_copy_global_to_local %{{.*}}: tensor<64x64x!tt.ptr<f16>, #[[NEW_BLOCKED]]>
    %token = ttg.async_copy_global_to_local %input, %view mask %mask other %other: tensor<64x64x!tt.ptr<f16>, #blocked> -> <64x64xf16, #shared, #smem, mutable>
    tt.return
  }
}

// -----

#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
  // Clip to vector size 2 (32bit) because we do not support 64 bit loads to lds
  // CHECK: #[[NEW_BLOCKED:.*]] = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
  tt.func public @async_copy_vector_size_2(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
                                           %arg1: i32 {tt.divisibility = 16 : i32},
                                           %arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) {
    // We need the index calculation so AxisAnalysis sees that we can vectorize the load
    %1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
    %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked>
    %3 = tt.broadcast %2 : tensor<1x64xi32, #blocked> -> tensor<32x64xi32, #blocked>
    %4 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<32x64x!tt.ptr<f16>, #blocked>
    %5 = tt.addptr %4, %3 : tensor<32x64x!tt.ptr<f16>, #blocked>, tensor<32x64xi32, #blocked>

    // CHECK: %{{.*}} = ttg.convert_layout %{{.*}} : {{.*}} -> tensor<32x64x!tt.ptr<f16>, #[[NEW_BLOCKED]]>
    // CHECK: %{{.*}} = ttg.async_copy_global_to_local %{{.*}}: tensor<32x64x!tt.ptr<f16>, #[[NEW_BLOCKED]]>
    %6 = ttg.async_copy_global_to_local %5, %arg2 : tensor<32x64x!tt.ptr<f16>, #blocked> -> <32x64xf16, #shared, #smem, mutable>
    tt.return
  }
}

// -----

#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [32, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
  // Clip to vector size 8 (128bit) which is the largest supported load width
  // CHECK: #[[NEW_BLOCKED:.*]] = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
  tt.func public @async_copy_vector_size_8(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
                                           %arg1: i32 {tt.divisibility = 16 : i32},
                                           %arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) {
    // We need the index calculation so AxisAnalysis sees that we can vectorize the load based on the src contiguity
    %1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
    %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked>
    %3 = tt.broadcast %2 : tensor<1x64xi32, #blocked> -> tensor<32x64xi32, #blocked>
    %4 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<32x64x!tt.ptr<f16>, #blocked>
    %5 = tt.addptr %4, %3 : tensor<32x64x!tt.ptr<f16>, #blocked>, tensor<32x64xi32, #blocked>

    // CHECK: %{{.*}} = ttg.convert_layout %{{.*}} : {{.*}} -> tensor<32x64x!tt.ptr<f16>, #[[NEW_BLOCKED]]>
    // CHECK: %{{.*}} = ttg.async_copy_global_to_local %{{.*}}: tensor<32x64x!tt.ptr<f16>, #[[NEW_BLOCKED]]>
    %6 = ttg.async_copy_global_to_local %5, %arg2 : tensor<32x64x!tt.ptr<f16>, #blocked> -> <32x64xf16, #shared, #smem, mutable>
    tt.return
  }
}

// -----

#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [32, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
  // The orders of #blocked and #shared differ, so we need to clip to 1 element
  // CHECK: #[[NEW_BLOCKED:.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}>
  tt.func public @async_copy_different_order(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
                                             %arg1: i32 {tt.divisibility = 16 : i32},
                                             %arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) {
    // We need the index calculation so AxisAnalysis sees that we can vectorize the load based on the src contiguity
    %1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
    %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked>
    %3 = tt.broadcast %2 : tensor<1x64xi32, #blocked> -> tensor<32x64xi32, #blocked>
    %4 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<32x64x!tt.ptr<f16>, #blocked>
    %5 = tt.addptr %4, %3 : tensor<32x64x!tt.ptr<f16>, #blocked>, tensor<32x64xi32, #blocked>

    // CHECK: %{{.*}} = ttg.convert_layout %{{.*}} : {{.*}} -> tensor<32x64x!tt.ptr<f16>, #[[NEW_BLOCKED]]>
    // CHECK: %{{.*}} = ttg.async_copy_global_to_local %{{.*}}: tensor<32x64x!tt.ptr<f16>, #[[NEW_BLOCKED]]>
    %6 = ttg.async_copy_global_to_local %5, %arg2 : tensor<32x64x!tt.ptr<f16>, #blocked> -> <32x64xf16, #shared, #smem, mutable>
    tt.return
  }
}

third_party/amd/include/TritonAMDGPUTransforms/Passes.h

Lines changed: 3 additions & 0 deletions
@@ -37,6 +37,9 @@ createTritonAMDGPUBlockPingpongPass(int32_t numStages = 2);
 
 std::unique_ptr<Pass> createTritonAMDGPUInThreadTransposePass();
 
+std::unique_ptr<Pass>
+createTritonAMDGPUCoalesceAsyncCopyPass(std::string archGenName = {});
+
 /// Generate the code for registering passes.
 #define GEN_PASS_REGISTRATION
 #include "TritonAMDGPUTransforms/Passes.h.inc"

third_party/amd/include/TritonAMDGPUTransforms/Passes.td

Lines changed: 20 additions & 0 deletions
@@ -224,6 +224,26 @@ def TritonAMDGPUInThreadTranspose: Pass<"tritonamdgpu-in-thread-transpose", "mli
   let dependentDialects = ["mlir::triton::amdgpu::TritonAMDGPUDialect", "mlir::triton::gpu::TritonGPUDialect"];
 }
 
+def TritonAMDGPUCoalesceAsyncCopy: Pass<"tritonamdgpu-coalesce-async-copy", "mlir::ModuleOp"> {
+  let summary = "Improve coalescing for async global to local copies";
+
+  let description = [{
+    GFX9:
+    For AsyncCopyGlobalToLocal ops where the blocked encoding's sizePerThread is larger than the contiguity of the
+    source or the supported load vector size, clip it to the largest supported size. This ensures we get coalesced
+    writes to shared memory, as required by the hardware. This only works for non-swizzled shared memory layouts.
+  }];
+
+  let constructor = "mlir::createTritonAMDGPUCoalesceAsyncCopyPass()";
+
+  let dependentDialects = [];
+
+  let options = [
+    Option<"archGenerationName", "arch-generation-name",
+           "std::string", /*default=*/"std::string{}",
+           "GFX generation name of target device.">,
+  ];
+}
 
 
 #endif
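For reference, the arch-generation-name option declared above is the same flag the lit test drives through triton-opt; a standalone invocation could look like this (the input file name is illustrative):

triton-opt --tritonamdgpu-coalesce-async-copy=arch-generation-name=gfx950 input.mlir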

third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -2,6 +2,7 @@ add_triton_library(TritonAMDGPUTransforms
   AccelerateAMDMatmul.cpp
   BlockPingpong.cpp
   CanonicalizePointers.cpp
+  CoalesceAsyncCopy.cpp
   ConvertToBufferOps.cpp
   OptimizeEpilogue.cpp
   HoistLayoutConversions.cpp
third_party/amd/lib/TritonAMDGPUTransforms/CoalesceAsyncCopy.cpp

Lines changed: 194 additions & 0 deletions
@@ -0,0 +1,194 @@
#include "TritonAMDGPUToLLVM/TargetUtils.h"
#include "amd/lib/TritonAMDGPUToLLVM/Utility.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "triton/Analysis/AxisInfo.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"

#define GEN_PASS_CLASSES
#include "TritonAMDGPUTransforms/Passes.h"

#undef DEBUG_TYPE
#define DEBUG_TYPE "tritonamdgpu-coalesce-async-copy"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

using namespace mlir;
namespace ttg = triton::gpu;

// On gfx9, global and buffer loads directly to shared memory need to write
// coalesced. This pattern converts the layout of the src, mask and other to
// ensure the data owned per thread is contiguous and does not exceed the
// supported load vector size, so the writes are coalesced.
struct CoalesceAsyncCopyWrites
    : public OpRewritePattern<ttg::AsyncCopyGlobalToLocalOp> {
  CoalesceAsyncCopyWrites(const triton::AMD::TargetInfo &targetInfo,
                          const DenseMap<ttg::AsyncCopyGlobalToLocalOp,
                                         unsigned> &asyncCopyContiguity,
                          MLIRContext *ctx)
      : OpRewritePattern(ctx), targetInfo{targetInfo},
        asyncCopyContiguity{std::move(asyncCopyContiguity)} {}

  LogicalResult matchAndRewrite(ttg::AsyncCopyGlobalToLocalOp copyOp,
                                PatternRewriter &rewriter) const override {
    auto src = copyOp.getSrc();
    auto dst = copyOp.getResult();
    Value mask = copyOp.getMask();
    Value other = copyOp.getOther();

    auto srcTy = cast<RankedTensorType>(src.getType());
    auto dstTy = cast<ttg::MemDescType>(dst.getType());

    auto blockedEnc = dyn_cast<ttg::BlockedEncodingAttr>(srcTy.getEncoding());
    if (!blockedEnc)
      return rewriter.notifyMatchFailure(copyOp,
                                         "src encoding must be #blocked");

    auto sharedEnc =
        dyn_cast<ttg::SwizzledSharedEncodingAttr>(dstTy.getEncoding());
    if (!sharedEnc)
      return rewriter.notifyMatchFailure(
          copyOp, "destination encoding must be #SwizzledShared");
    if (sharedEnc.getMaxPhase() > 1)
      return rewriter.notifyMatchFailure(
          copyOp, "swizzled shared encoding not supported");

    // We start from the precomputed contiguity we got from AxisAnalysis.
    unsigned loadContig = 0;
    if (auto it = asyncCopyContiguity.find(copyOp);
        it != asyncCopyContiguity.end())
      loadContig = it->second;
    else
      return copyOp->emitError()
             << "No contiguity information about the copy op";
    assert(loadContig > 0);

    // Further restrict the contiguity based on the contiguity of the src to
    // dst layout, e.g. if the order of the blocked and shared encoding is
    // different we can only load one element at a time, or if the shared
    // encoding is swizzled we cannot exceed the vector size of the swizzling
    // pattern.
    LinearLayout regLayout =
        triton::gpu::toLinearLayout(srcTy.getShape(), blockedEnc);
    LinearLayout sharedLayout =
        triton::gpu::toLinearLayout(srcTy.getShape(), sharedEnc);
    auto regToSharedLayout = regLayout.invertAndCompose(sharedLayout);
    loadContig = std::min<unsigned>(loadContig,
                                    regToSharedLayout.getNumConsecutiveInOut());

    // Select the largest supported load width equal or smaller than loadContig.
    auto elemBitWidth = dstTy.getElementTypeBitWidth();
    while (loadContig > 0 && !targetInfo.supportsDirectToLdsLoadBitWidth(
                                 loadContig * elemBitWidth)) {
      loadContig /= 2;
    }

    if (loadContig == 0) {
      return rewriter.notifyMatchFailure(
          copyOp, "could not find layout config to create coalesced writes");
    }

    // Do not rewrite if we already use the correct contiguity (could be from a
    // previous rewrite).
    auto contigPerThread = ttg::getContigPerThread(srcTy);
    auto blockedContig = contigPerThread[blockedEnc.getOrder()[0]];
    if (blockedContig == loadContig) {
      return rewriter.notifyMatchFailure(copyOp,
                                         "already using the correct layout");
    }

    // Get a new blocked encoding with loadContig as sizePerThread in the
    // fastest dim.
    assert(blockedContig >= loadContig);
    contigPerThread[blockedEnc.getOrder()[0]] = loadContig;
    int numWarps = triton::gpu::lookupNumWarps(copyOp);
    auto mod = copyOp->getParentOfType<ModuleOp>();
    int threadsPerWarp = ttg::TritonGPUDialect::getThreadsPerWarp(mod);
    auto newBlockEnc = BlockedEncodingAttr::get(
        copyOp.getContext(), srcTy.getShape(), contigPerThread,
        blockedEnc.getOrder(), numWarps, threadsPerWarp,
        blockedEnc.getCTALayout());

    // Convert the layout of src, mask and other to the new encoding.
    auto convertLayout = [&rewriter](auto loc, Value old, auto newEnc) {
      auto oldTy = cast<RankedTensorType>(old.getType());
      RankedTensorType newSrcTy = RankedTensorType::get(
          oldTy.getShape(), oldTy.getElementType(), newEnc);
      return rewriter.create<ttg::ConvertLayoutOp>(loc, newSrcTy, old);
    };

    auto loc = copyOp->getLoc();
    Value cvtSrc = convertLayout(loc, src, newBlockEnc);

    if (mask)
      mask = convertLayout(loc, mask, newBlockEnc);
    if (other)
      other = convertLayout(loc, other, newBlockEnc);

    rewriter.modifyOpInPlace(copyOp, [&]() {
      copyOp.getSrcMutable().assign(cvtSrc);
      if (mask)
        copyOp.getMaskMutable().assign(mask);
      if (other)
        copyOp.getOtherMutable().assign(other);
    });
    return success();
  }

private:
  const triton::AMD::TargetInfo &targetInfo;
  const DenseMap<ttg::AsyncCopyGlobalToLocalOp, unsigned> &asyncCopyContiguity;
};

class TritonAMDGPUCoalesceAsyncCopyPass
    : public TritonAMDGPUCoalesceAsyncCopyBase<
          TritonAMDGPUCoalesceAsyncCopyPass> {
public:
  TritonAMDGPUCoalesceAsyncCopyPass(StringRef archGenName) {
    this->archGenerationName = archGenName.str();
  }

  void runOnOperation() override {
    ModuleOp m = getOperation();
    MLIRContext *context = &getContext();

    triton::AMD::TargetInfo targetInfo(archGenerationName);

    mlir::RewritePatternSet patterns(context);

    switch (targetInfo.getISAFamily()) {
    case triton::AMD::ISAFamily::CDNA1:
    case triton::AMD::ISAFamily::CDNA2:
    case triton::AMD::ISAFamily::CDNA3:
    case triton::AMD::ISAFamily::CDNA4: {
      break;
    }
    default:
      return;
    }

    // Precompute the contiguity of all AsyncCopy ops based on the src and
    // mask contiguity/alignment to avoid rebuilding ModuleAxisInfoAnalysis
    // after every IR change.
    triton::ModuleAxisInfoAnalysis axisAnalysis(m);
    DenseMap<ttg::AsyncCopyGlobalToLocalOp, unsigned> asyncCopyContiguity;
    m->walk([&](ttg::AsyncCopyGlobalToLocalOp copyOp) {
      unsigned contiguity =
          mlir::LLVM::AMD::getContiguity(copyOp.getSrc(), axisAnalysis);
      if (auto mask = copyOp.getMask()) {
        contiguity =
            std::min<unsigned>(contiguity, axisAnalysis.getMaskAlignment(mask));
      }
      asyncCopyContiguity.insert({copyOp, contiguity});
    });
    patterns.add<CoalesceAsyncCopyWrites>(targetInfo, asyncCopyContiguity,
                                          context);

    if (applyPatternsGreedily(m, std::move(patterns)).failed())
      signalPassFailure();
  }
};

std::unique_ptr<Pass>
mlir::createTritonAMDGPUCoalesceAsyncCopyPass(std::string archGenName) {
  return std::make_unique<TritonAMDGPUCoalesceAsyncCopyPass>(
      std::move(archGenName));
}
