
Commit 65a80e4

AlexAUT and antiagainst authored
[AMD] Fix vmcnt(0) for LocalLoads with loop-carried AsyncToken (#7052)
Moves all async-related LLVM workaround functions into a separate utility file. Reuses the `LocalLoad` annotations introduced by triton-lang/triton#7047 to handle loop-carried tokens in the alias computations. The only functional change is better handling of the vmcnt(0) case: before this PR we emitted a vmcnt(0) before the ds_read in such cases.

Co-authored-by: Lei Zhang <[email protected]>
1 parent 24863d6 · commit 65a80e4

File tree: 15 files changed, +247 −181 lines


test/Conversion/amd/async-ops-alias-scopes.mlir

Lines changed: 56 additions & 15 deletions
@@ -1,12 +1,13 @@
// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx950 --convert-builtin-func-to-llvm | FileCheck %s --check-prefixes=COMMON,GFX950
// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx942 --convert-builtin-func-to-llvm | FileCheck %s --check-prefixes=COMMON,GFX942

-// COMMON: [[ASYNC_COPY_SCOPE:#.*]] = #llvm.alias_scope<id = "amdgpu.AsyncCopies"
-// COMMON: [[LOCAL_LOAD_SCOPE:#.*]] = #llvm.alias_scope<id = "amdgpu.LocalLoads"
+// COMMON: [[$ASYNC_COPY_SCOPE:#.*]] = #llvm.alias_scope<id = "amdgpu.AsyncCopies"
+// COMMON: [[$LOCAL_LOAD_SCOPE:#.*]] = #llvm.alias_scope<id = "amdgpu.LocalLoads"
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [64, 1], warpsPerCTA = [1, 1], order = [0, 1]}>
#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 8, maxPhase = 2, order = [0, 1]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
+// COMMON-LABEL: @async_copy_alias
tt.func public @async_copy_alias(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
%arg1: !ttg.memdesc<64x1xf32, #shared, #smem, mutable>,
%maskVal: i1) {
@@ -15,9 +16,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.targ
%ptr = tt.splat %arg0 : !tt.ptr<f32> -> tensor<64x1x!tt.ptr<f32>, #blocked>
%mask = tt.splat %maskVal : i1 -> tensor<64x1xi1, #blocked>

-// COMMON: rocdl.global.load.lds {{.*}} {alias_scopes = [[[ASYNC_COPY_SCOPE]]]
+// COMMON: rocdl.global.load.lds {{.*}} {alias_scopes = [[[$ASYNC_COPY_SCOPE]]]
// Check that store for 'other' has alias information set
-// COMMON: llvm.store {{.*}} {alias_scopes = [[[LOCAL_LOAD_SCOPE]]], noalias_scopes = [[[ASYNC_COPY_SCOPE]]]
+// COMMON: llvm.store {{.*}} {alias_scopes = [[[$LOCAL_LOAD_SCOPE]]], noalias_scopes = [[[$ASYNC_COPY_SCOPE]]]
%0 = ttg.async_copy_global_to_local %ptr, %arg1 mask %mask other %other : tensor<64x1x!tt.ptr<f32>, #blocked> -> <64x1xf32, #shared, #smem, mutable>

// COMMON: llvm.return
@@ -27,21 +28,22 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.targ

// -----

-// COMMON: [[ASYNC_COPY_SCOPE:#.*]] = #llvm.alias_scope<id = "amdgpu.AsyncCopies"
+// COMMON: [[$ASYNC_COPY_SCOPE:#.*]] = #llvm.alias_scope<id = "amdgpu.AsyncCopies"
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
+// COMMON-LABEL: @buffer_load_to_local_alias
tt.func public @buffer_load_to_local_alias(%maskVal: i1,
%arg1: !tt.ptr<f32>,
%arg2: tensor<8x64xi32, #blocked>,
%arg3: !ttg.memdesc<8x64xf32, #shared, #smem, mutable>) {
%mask = tt.splat %maskVal : i1 -> tensor<8x64xi1, #blocked>
%other = arith.constant dense<1.000000e+00> : tensor<8x64xf32, #blocked>

-// COMMON: rocdl.raw.ptr.buffer.load.lds {{.*}} {alias_scopes = [[[ASYNC_COPY_SCOPE]]]
+// COMMON: rocdl.raw.ptr.buffer.load.lds {{.*}} {alias_scopes = [[[$ASYNC_COPY_SCOPE]]]
// Check that store for 'other' has alias information set
-// COMMON: llvm.store {{.*}} {alias_scopes = [[[LOCAL_LOAD_SCOPE]]], noalias_scopes = [[[ASYNC_COPY_SCOPE]]]
+// COMMON: llvm.store {{.*}} {alias_scopes = [[[$LOCAL_LOAD_SCOPE]]], noalias_scopes = [[[$ASYNC_COPY_SCOPE]]]
%65 = amdgpu.buffer_load_to_local %arg1[%arg2] mask=%mask other=%other into %arg3 {OpIdx = #amdgpu.OpIdx<1>} : <f32>[tensor<8x64xi32, #blocked>] tensor<8x64xf32, #blocked> -> <8x64xf32, #shared, #smem, mutable>

// COMMON: llvm.return
@@ -51,14 +53,15 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar

// -----

-// COMMON: [[LOCAL_LOAD_SCOPE:#.*]] = #llvm.alias_scope<id = "amdgpu.LocalLoads"
-// COMMON: [[ASYNC_COPY_SCOPE:#.*]] = #llvm.alias_scope<id = "amdgpu.AsyncCopies"
+// COMMON: [[$LOCAL_LOAD_SCOPE:#.*]] = #llvm.alias_scope<id = "amdgpu.LocalLoads"
+// COMMON: [[$ASYNC_COPY_SCOPE:#.*]] = #llvm.alias_scope<id = "amdgpu.AsyncCopies"
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [64, 1], warpsPerCTA = [1, 1], order = [0, 1]}>
#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 8, maxPhase = 2, order = [0, 1]}>
#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 16, order = [1, 0]}>
#smem = #ttg.shared_memory
#mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [32, 32], isTransposed = true}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
+// COMMON-LABEL: @local_loads_with_token_from_async_wait
tt.func public @local_loads_with_token_from_async_wait(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
%arg1: !ttg.memdesc<64x1xf16, #shared, #smem, mutable>,
%arg2: !ttg.memdesc<16x16xf16, #shared, #smem, mutable>) {
@@ -67,12 +70,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.targ
// Check alias information is added for different lowering paths

// Test lowering path in common MemoryOpToLLVM pattern
-// COMMON: llvm.load {{.*}} {alias_scopes = [[[LOCAL_LOAD_SCOPE]]], {{.*}}, noalias_scopes = [[[ASYNC_COPY_SCOPE]]]
+// COMMON: llvm.load {{.*}} {alias_scopes = [[[$LOCAL_LOAD_SCOPE]]], {{.*}}, noalias_scopes = [[[$ASYNC_COPY_SCOPE]]]
%4 = ttg.local_load %arg1 token %3 : !ttg.memdesc<64x1xf16, #shared, #smem, mutable> -> tensor<64x1xf16, #blocked>

// Test lowering path in AMD's MemoryOpToLLVM pattern
-// GFX942: llvm.load {{.*}} {alias_scopes = [[[LOCAL_LOAD_SCOPE]]], {{.*}}, noalias_scopes = [[[ASYNC_COPY_SCOPE]]]
-// GFX950: rocdl.ds.read.tr16.b64 {{.*}} {alias_scopes = [[[LOCAL_LOAD_SCOPE]]], noalias_scopes = [[[ASYNC_COPY_SCOPE]]]
+// GFX942: llvm.load {{.*}} {alias_scopes = [[[$LOCAL_LOAD_SCOPE]]], {{.*}}, noalias_scopes = [[[$ASYNC_COPY_SCOPE]]]
+// GFX950: rocdl.ds.read.tr16.b64 {{.*}} {alias_scopes = [[[$LOCAL_LOAD_SCOPE]]], noalias_scopes = [[[$ASYNC_COPY_SCOPE]]]
%5 = ttg.local_load %arg2 token %3 : !ttg.memdesc<16x16xf16, #shared, #smem, mutable> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>

// Stores to keep the local_loads
@@ -90,27 +93,28 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.targ

// Same as above but LocalLoad does not use the token from AsyncWait

-// COMMON: [[ASYNC_COPY_SCOPE:#.*]] = #llvm.alias_scope<id = "amdgpu.AsyncCopies"
+// COMMON: [[$ASYNC_COPY_SCOPE:#.*]] = #llvm.alias_scope<id = "amdgpu.AsyncCopies"
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [64, 1], warpsPerCTA = [1, 1], order = [0, 1]}>
#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 8, maxPhase = 2, order = [0, 1]}>
#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 16, order = [1, 0]}>
#smem = #ttg.shared_memory
#mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [32, 32], isTransposed = true}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
+// COMMON-LABEL: @local_loads_without_token_from_async_wait
tt.func public @local_loads_without_token_from_async_wait(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
%arg1: !ttg.memdesc<64x1xf32, #shared, #smem, mutable>,
%arg4: !ttg.memdesc<16x16xf32, #shared, #smem, mutable>) {
// We need the splat to allow the AxisAnalysis to work during lowering
%ptr = tt.splat %arg0 : !tt.ptr<f32> -> tensor<64x1x!tt.ptr<f32>, #blocked>

-// COMMON: rocdl.global.load.lds {{.*}} {alias_scopes = [[[ASYNC_COPY_SCOPE]]]
+// COMMON: rocdl.global.load.lds {{.*}} {alias_scopes = [[[$ASYNC_COPY_SCOPE]]]
%0 = ttg.async_copy_global_to_local %ptr, %arg1 : tensor<64x1x!tt.ptr<f32>, #blocked> -> <64x1xf32, #shared, #smem, mutable>
%1 = ttg.async_commit_group %0

%3 = ttg.async_wait %1 {num = 1 : i32}

// Check alias information is not used at all for different lowering paths
-// COMMON-NOT: [[ASYNC_COPY_SCOPE]]
+// COMMON-NOT: [[$ASYNC_COPY_SCOPE]]

// Test lowering path in common MemoryOpToLLVM pattern
%4 = ttg.local_load %arg1 token %0 : !ttg.memdesc<64x1xf32, #shared, #smem, mutable> -> tensor<64x1xf32, #blocked>
@@ -124,3 +128,40 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.targ
tt.return
}
}
+
+// -----
+
+// COMMON: [[$LOCAL_LOAD_SCOPE:#.*]] = #llvm.alias_scope<id = "amdgpu.LocalLoads"
+// COMMON: [[$ASYNC_COPY_SCOPE:#.*]] = #llvm.alias_scope<id = "amdgpu.AsyncCopies"
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [64, 1], warpsPerCTA = [1, 1], order = [0, 1]}>
+#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 8, maxPhase = 2, order = [0, 1]}>
+#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 16, order = [1, 0]}>
+#smem = #ttg.shared_memory
+#mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [32, 32], isTransposed = true}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
+// COMMON-LABEL: @local_loads_with_loop_carried_token
+tt.func public @local_loads_with_loop_carried_token(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
+%arg1: !ttg.memdesc<64x1xf16, #shared, #smem, mutable>,
+%loopIterCount: i32) {
+%c0_i32 = arith.constant 0 : i32
+%c1_i32 = arith.constant 1 : i32
+
+%1 = ttg.async_wait {num = 1 : i32}
+// COMMON: llvm.load
+%2 = ttg.local_load %arg1 token %1 : !ttg.memdesc<64x1xf16, #shared, #smem, mutable> -> tensor<64x1xf16, #blocked>
+
+%loop_result:2 = scf.for %arg14 = %c0_i32 to %loopIterCount step %c1_i32 iter_args(%arg10 = %1, %arg11 = %2) -> (!ttg.async.token, tensor<64x1xf16, #blocked>) : i32 {
+// COMMON: llvm.load {{.*}} {alias_scopes = [[[$LOCAL_LOAD_SCOPE]]], {{.*}}, noalias_scopes = [[[$ASYNC_COPY_SCOPE]]]
+%3 = ttg.local_load %arg1 token %arg10 : !ttg.memdesc<64x1xf16, #shared, #smem, mutable> -> tensor<64x1xf16, #blocked>
+%4 = ttg.async_wait {num = 1 : i32}
+scf.yield %4, %3: !ttg.async.token, tensor<64x1xf16, #blocked>
+}
+
+// Stores to keep the local_loads
+%ptr = tt.splat %arg0 : !tt.ptr<f16> -> tensor<64x1x!tt.ptr<f16>, #blocked>
+tt.store %ptr, %loop_result#1 : tensor<64x1x!tt.ptr<f16>, #blocked>
+
+// COMMON: llvm.return
+tt.return
+}
+}

third_party/amd/include/TritonAMDGPUToLLVM/MembarUtility.h

Lines changed: 0 additions & 9 deletions
@@ -1,19 +1,10 @@
#ifndef TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTOLLVM_MEMBARUTILITY_H_
#define TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTOLLVM_MEMBARUTILITY_H_

-#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
-#include "triton/Dialect/TritonGPU/IR/Dialect.h"

namespace mlir::triton::AMD {

-// Annotates LocalLoadOps with ttg.amdgpu.syncedByAsyncWait=true if they are
-// synced by an AsyncWait.
-void annotateLocalLoadsSyncedViaAsyncWait(ModuleOp mod);
-
-// Getter for the annotation applied by annotateLocalLoadsSyncedViaAsyncWait
-bool isSyncedViaAsyncWait(triton::gpu::LocalLoadOp localLoadOp);
-
// Filter function used in the AMDGPU backend to filter unnecessary barriers
// during Membar Analysis. Filters applied by this function:
// 1) Do not create barriers between AsyncCopyGlobalToLocal and LocalLoad if the

third_party/amd/lib/TritonAMDGPUToLLVM/AsyncUtility.cpp

Lines changed: 129 additions & 0 deletions
@@ -0,0 +1,129 @@
+#include "AsyncUtility.h"
+
+#include "Dialect/TritonAMDGPU/IR/Dialect.h"
+#include "triton/Dialect/TritonGPU/IR/Dialect.h"
+
+namespace mlir::triton::AMD {
+namespace {
+constexpr const char *syncedViaAsyncWaitAttrName =
+    "ttg.amdgpu.syncedViaAsyncWait";
+// Traverses the def-chain of the token, including control flow, and returns
+// true if all defining operations are AsyncWaits
+bool comesFromAsyncWait(Value token) {
+  if (auto defOp = token.getDefiningOp()) {
+    return isa<triton::gpu::AsyncWaitOp>(defOp);
+  }
+
+  auto blockArg = dyn_cast<BlockArgument>(token);
+  // If the token has no defining op and is not a BlockArgument, bail out
+  if (!blockArg) {
+    return false;
+  }
+
+  auto block = blockArg.getOwner();
+  auto argId = blockArg.getArgNumber();
+
+  auto destOperandFromAsyncWait = [argId](auto &&operands) {
+    assert(argId < operands.size());
+    return comesFromAsyncWait(operands[argId]);
+  };
+
+  // Check all predecessor blocks' terminators and follow the value passed at
+  // argId to see if it immediately comes from an AsyncWait.
+  for (auto *pred : block->getPredecessors()) {
+    auto terminator = pred->getTerminator();
+    if (auto br = dyn_cast<BranchOpInterface>(terminator)) {
+      for (auto successor : llvm::enumerate(br->getSuccessors())) {
+        if (block != successor.value())
+          continue;
+        auto operands = br.getSuccessorOperands(successor.index());
+        if (!destOperandFromAsyncWait(operands))
+          return false;
+      }
+    } else {
+      return false;
+    }
+  }
+  return true;
+}
+} // namespace
+
+void annotateLocalLoadsSyncedViaAsyncWait(ModuleOp mod) {
+  SmallVector<triton::gpu::LocalLoadOp> localLoads;
+  mod->walk([&](triton::gpu::LocalLoadOp localLoadOp) {
+    localLoads.emplace_back(localLoadOp);
+  });
+
+  auto *ctx = mod->getContext();
+  for (auto &loadOp : localLoads) {
+    auto token = loadOp.getToken();
+    bool isSyncedViaAsyncWait = token && comesFromAsyncWait(token);
+    loadOp->setAttr(syncedViaAsyncWaitAttrName,
+                    BoolAttr::get(ctx, isSyncedViaAsyncWait));
+  }
+}
+
+bool isSyncedViaAsyncWait(triton::gpu::LocalLoadOp localLoadOp) {
+  auto attr = localLoadOp->getAttr(syncedViaAsyncWaitAttrName);
+  if (!attr) {
+    localLoadOp.emitRemark("has no async sync information attached to it which "
+                           "might negatively affect performance. Run "
+                           "annotateLocalLoadsSyncedViaAsyncWait first");
+    return false;
+  }
+  return cast<BoolAttr>(attr).getValue();
+}
+
+namespace {
+LLVM::AliasScopeDomainAttr getLoadScopeDomain(MLIRContext *ctx) {
+  Builder b(ctx);
+  return b.getAttr<LLVM::AliasScopeDomainAttr>(
+      b.getStringAttr("amdgpu.AsyncOps"),
+      b.getStringAttr(
+          "Domain to hold alias scopes to specify aliasing information between "
+          "AsyncCopyGlobalToLocal, BufferLoadToLocal and LocalLoad ops"));
+}
+
+LLVM::AliasScopeAttr getAsyncCopyScope(MLIRContext *ctx) {
+  Builder b(ctx);
+  auto name = b.getStringAttr("amdgpu.AsyncCopies");
+  auto desc = b.getStringAttr(
+      "Scope containing all AsyncCopyGlobalToLocal and BufferLoadToLocal ops");
+  return b.getAttr<LLVM::AliasScopeAttr>(name, getLoadScopeDomain(ctx), desc);
+}
+
+LLVM::AliasScopeAttr getLoadCopyScope(MLIRContext *ctx) {
+  Builder b(ctx);
+  auto name = b.getStringAttr("amdgpu.LocalLoads");
+  auto desc = b.getStringAttr("Scope containing all LocalLoad ops");
+  return b.getAttr<LLVM::AliasScopeAttr>(name, getLoadScopeDomain(ctx), desc);
+}
+} // namespace
+
+void addAsyncCopyAliasScope(LLVM::AliasAnalysisOpInterface directToLdsOp) {
+  auto ctx = directToLdsOp->getContext();
+  Builder b(ctx);
+  directToLdsOp.setAliasScopes(b.getArrayAttr(getAsyncCopyScope(ctx)));
+}
+
+void addLocalLoadNoAliasScope(triton::gpu::LocalLoadOp localLoadOp,
+                              LLVM::AliasAnalysisOpInterface llLoadOp) {
+  if (!isSyncedViaAsyncWait(localLoadOp))
+    return;
+
+  return addLocalLoadNoAliasScope(llLoadOp);
+}
+
+void addLocalLoadNoAliasScope(LLVM::AliasAnalysisOpInterface llLoadOp) {
+  auto ctx = llLoadOp->getContext();
+
+  // Do not alias with AsyncCopies
+  auto noAliasScopes = ArrayAttr::get(ctx, getAsyncCopyScope(ctx));
+  llLoadOp.setNoAliasScopes(noAliasScopes);
+
+  // Add to a different scope, as ops without any scope alias with everything
+  auto aliasScopes = ArrayAttr::get(ctx, getLoadCopyScope(ctx));
+  llLoadOp.setAliasScopes(aliasScopes);
+}
+
+} // namespace mlir::triton::AMD

third_party/amd/lib/TritonAMDGPUToLLVM/AsyncUtility.h

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+#ifndef TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_ASYNCUTILITY_H_
+#define TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_ASYNCUTILITY_H_
+
+#include "mlir/Dialect/LLVMIR/LLVMInterfaces.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "triton/Conversion/MLIRTypes.h"
+#include "triton/Dialect/TritonGPU/IR/Dialect.h"
+
+namespace mlir::triton::AMD {
+// Annotates LocalLoadOps with ttg.amdgpu.syncedByAsyncWait=true if they are
+// synced by an AsyncWait.
+void annotateLocalLoadsSyncedViaAsyncWait(ModuleOp mod);
+
+// Getter for the annotation applied by annotateLocalLoadsSyncedViaAsyncWait
+bool isSyncedViaAsyncWait(triton::gpu::LocalLoadOp localLoadOp);
+
+// LLVM is unable to deduce dependencies across warps and loop iterations for
+// AsyncCopy and LocalLoad and will emit conservative wait counts. In Triton the
+// dependency is modeled via AsyncWait, e.g.
+//   %token1 = ttg.async_copy_global_to_local/amdgpu.buffer_load_to_local
+//   %token2 = ttg.async_wait %token1
+//   %1 = ttg.local_load .. token %token2
+// For such cases AsyncWait will emit the correct wait, so the conservative
+// waits are redundant and hinder performance/interleaving.
+// To disable the conservative waits two alias scopes are created:
+// 1) "amdgpu.AsyncCopies" will contain all AsyncCopy ops
+// 2) "amdgpu.LocalLoads" will contain all LocalLoads manually synchronized via
+//    AsyncWait
+// All manually synchronized LocalLoads will additionally have "AsyncCopies" as
+// a noalias scope to disable the implicit waits from the LLVM backend
+
+// If localLoadOp has a token from an AsyncWait:
+// - Attaches the "amdgpu.LocalLoads" alias scope to llLoadOp
+// - Attaches "amdgpu.AsyncCopies" as a *noalias* scope to llLoadOp
+void addLocalLoadNoAliasScope(triton::gpu::LocalLoadOp localLoadOp,
+                              LLVM::AliasAnalysisOpInterface llLoadOp);
+// Overload of the above without checking the AsyncToken
+void addLocalLoadNoAliasScope(LLVM::AliasAnalysisOpInterface llLoadOp);
+// Attaches the "AsyncCopies" alias scope to llLoadDirectToLdsOp
+void addAsyncCopyAliasScope(LLVM::AliasAnalysisOpInterface llLoadDirectToLdsOp);
+
+} // namespace mlir::triton::AMD
+
+#endif
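
The header comment above describes the intended flow end to end; as a minimal usage sketch (the wrapper function and its parameter names below are illustrative assumptions, not code from this PR), a lowering could wire the helpers together roughly like this:

// Hedged usage sketch of the AsyncUtility API declared above. The wrapper
// tagAsyncAliasScopes and the loweredLoad/loweredCopy parameters are
// hypothetical; only the triton::AMD::* calls come from this change.
#include "AsyncUtility.h"

using namespace mlir;

void tagAsyncAliasScopes(ModuleOp mod, triton::gpu::LocalLoadOp localLoad,
                         LLVM::AliasAnalysisOpInterface loweredLoad,
                         LLVM::AliasAnalysisOpInterface loweredCopy) {
  // Run once up front so every LocalLoadOp carries the
  // ttg.amdgpu.syncedViaAsyncWait annotation, including loads whose token
  // reaches them as a loop-carried block argument.
  triton::AMD::annotateLocalLoadsSyncedViaAsyncWait(mod);

  // Each direct-to-LDS copy joins the "amdgpu.AsyncCopies" alias scope.
  triton::AMD::addAsyncCopyAliasScope(loweredCopy);

  // A LocalLoad synchronized via ttg.async_wait gets the "amdgpu.LocalLoads"
  // scope plus "amdgpu.AsyncCopies" as a noalias scope, so the LLVM backend
  // does not emit a conservative vmcnt(0) before the resulting ds_read.
  triton::AMD::addLocalLoadNoAliasScope(localLoad, loweredLoad);
}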

third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp

Lines changed: 3 additions & 2 deletions
@@ -1,5 +1,6 @@
#include "TritonAMDGPUToLLVM/Passes.h"

+#include "AsyncUtility.h"
#include "Utility.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Pass/Pass.h"
@@ -80,7 +81,7 @@ class CallOpConversion : public OpRewritePattern<LLVM::CallOp> {
bool addAsyncAliasScopes =
    callOp.getCallee().value().contains(mlir::LLVM::AMD::noAliasAsyncLoads);
if (addAsyncAliasScopes) {
-  LLVM::AMD::addLocalLoadNoAliasScope(storeOp);
+  AMD::addLocalLoadNoAliasScope(storeOp);
}
rewriter.create<LLVM::BrOp>(loc, afterStore);
rewriter.setInsertionPointToStart(afterStore);
@@ -120,7 +121,7 @@ class CallOpConversion : public OpRewritePattern<LLVM::CallOp> {
bool addAsyncNoAliasInfo =
    callOp.getCallee().value().contains(mlir::LLVM::AMD::noAliasAsyncLoads);
if (addAsyncNoAliasInfo) {
-  LLVM::AMD::addLocalLoadNoAliasScope(loadOp);
+  AMD::addLocalLoadNoAliasScope(loadOp);
}
rewriter.create<LLVM::BrOp>(loc, loadOp->getResult(0), afterLoad);
rewriter.setInsertionPointToStart(falseBlock);
