[AMD] Add amdgpu.async_wait to explicitly represent number of async transactions (#8575)
`ttg.async_wait` counts the number of outstanding `ttg.async_commit_group`s. However, when lowering to LLVM on AMD we need the number of outstanding async intrinsics/final assembly instructions. This conversion is already done by `UpdateAsyncWaitCount`, which modifies the `num` of `ttg.async_wait` in place. This PR introduces a new op, `amdgpu.async_wait`, to make the change in semantics explicit in the IR. `UpdateAsyncWaitCount` is moved to the TTGIR->LLVM pipeline, primarily so it also covers `Gluon` kernels; we always run it, since it only has an effect if `ttg.async_wait` ops are present in the kernel. To avoid changes to Membar, this PR also adds a `ttg.local_barrier` after each `amdgpu.async_wait`. Membar respects the newly added barrier and behaves the same as it does for `ttg.async_wait`.
1 parent 33f077b commit a295e60

11 files changed: +98 −43 lines
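As a minimal sketch of the semantic shift described above (hand-written for illustration; the tensor types and the `num_inst` value are assumptions, since the number of instructions per copy depends on the shapes and the target):

// Before: `num` counts outstanding ttg.async_commit_groups.
%t0 = ttg.async_copy_global_to_local %ptrs, %buf : tensor<64x64x!tt.ptr<f16>, #blocked> -> <64x64xf16, #shared, #smem, mutable>
%t1 = ttg.async_commit_group tokens %t0
%t2 = ttg.async_wait %t1 {num = 1 : i32}

// After UpdateAsyncWaitCount and barrier insertion: `num_inst` counts
// outstanding async instructions (2 is an assumed expansion factor for the
// copy), and the explicit barrier keeps Membar's behaviour unchanged.
%t2 = amdgpu.async_wait %t1 {num_inst = 2 : i32}
ttg.local_barrier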

test/Conversion/amd/async-ops-alias-scopes.mlir

Lines changed: 4 additions & 4 deletions

@@ -65,7 +65,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.targ
 tt.func public @local_loads_with_token_from_async_wait(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
 %arg1: !ttg.memdesc<64x1xf16, #shared, #smem, mutable>,
 %arg2: !ttg.memdesc<16x16xf16, #shared, #smem, mutable>) {
-%3 = ttg.async_wait {num = 1 : i32}
+%3 = amdgpu.async_wait {num_inst = 1 : i32}

 // Check alias information is added for different lowering paths

@@ -111,7 +111,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.targ
 %0 = ttg.async_copy_global_to_local %ptr, %arg1 : tensor<64x1x!tt.ptr<f32>, #blocked> -> <64x1xf32, #shared, #smem, mutable>
 %1 = ttg.async_commit_group tokens %0

-%3 = ttg.async_wait %1 {num = 1 : i32}
+%3 = amdgpu.async_wait %1 {num_inst = 1 : i32}

 // Check alias information is not used at all for different lowering paths
 // COMMON-NOT: [[$ASYNC_COPY_SCOPE]]

@@ -146,14 +146,14 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.targ
 %c0_i32 = arith.constant 0 : i32
 %c1_i32 = arith.constant 1 : i32

-%1 = ttg.async_wait {num = 1 : i32}
+%1 = amdgpu.async_wait {num_inst = 1 : i32}
 // COMMON: llvm.load
 %2 = ttg.local_load %arg1 token %1 : !ttg.memdesc<64x1xf16, #shared, #smem, mutable> -> tensor<64x1xf16, #blocked>

 %loop_result:2 = scf.for %arg14 = %c0_i32 to %loopIterCount step %c1_i32 iter_args(%arg10 = %1, %arg11 = %2) -> (!ttg.async.token, tensor<64x1xf16, #blocked>) : i32 {
 // COMMON: llvm.load {{.*}} {alias_scopes = [[[$LOCAL_LOAD_SCOPE]]], noalias_scopes = [[[$ASYNC_COPY_SCOPE]]]
 %3 = ttg.local_load %arg1 token %arg10 : !ttg.memdesc<64x1xf16, #shared, #smem, mutable> -> tensor<64x1xf16, #blocked>
-%4 = ttg.async_wait {num = 1 : i32}
+%4 = amdgpu.async_wait {num_inst = 1 : i32}
 scf.yield %4, %3: !ttg.async.token, tensor<64x1xf16, #blocked>
 }
test/Conversion/amd/async_ops_to_llvm.mlir

Lines changed: 5 additions & 5 deletions

@@ -106,24 +106,24 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
 // CHECK: rocdl.s.waitcnt -49168
 // CHECK: rocdl.s.waitcnt -7937
 // CHECK: rocdl.s.barrier
-ttg.async_wait {num = 0 : i32}
+amdgpu.async_wait {num_inst = 0 : i32}
 // CHECK: rocdl.s.waitcnt -49167
 // CHECK: rocdl.s.waitcnt -7937
 // CHECK: rocdl.s.barrier
-ttg.async_wait {num = 1 : i32}
+amdgpu.async_wait {num_inst = 1 : i32}
 // CHECK: rocdl.s.waitcnt -2
 // CHECK: rocdl.s.waitcnt -7937
 // CHECK: rocdl.s.barrier
-ttg.async_wait {num = 62 : i32}
+amdgpu.async_wait {num_inst = 62 : i32}
 // CHECK: rocdl.s.waitcnt -1
 // CHECK: rocdl.s.waitcnt -7937
 // CHECK: rocdl.s.barrier
-ttg.async_wait {num = 63 : i32}
+amdgpu.async_wait {num_inst = 63 : i32}
 // Check that we clamp values > 63
 // CHECK: rocdl.s.waitcnt -1
 // CHECK: rocdl.s.waitcnt -7937
 // CHECK: rocdl.s.barrier
-ttg.async_wait {num = 64 : i32}
+amdgpu.async_wait {num_inst = 64 : i32}
 tt.return
 }
 }

test/TritonGPU/amd/amd-update-async-wait-count.mlir

Lines changed: 23 additions & 23 deletions

@@ -18,10 +18,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
 %3 = ttg.async_commit_group tokens %2

 // Do not wait on the second async_copy => waitcnt 2
-// CHECK: ttg.async_wait {{.*}} {num = 2
+// CHECK: amdgpu.async_wait {{.*}} {num_inst = 2
 %9 = ttg.async_wait %1 {num = 0 : i32}
 // No async_copies in between => waitcnt 0
-// CHECK: ttg.async_wait {{.*}} {num = 0
+// CHECK: amdgpu.async_wait {{.*}} {num_inst = 0
 %10 = ttg.async_wait %3 {num = 0 : i32}
 tt.return
 }

@@ -47,10 +47,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
 %3 = ttg.async_commit_group tokens %2

 // Do not wait on the second async_copy => waitcnt 2
-// CHECK: ttg.async_wait {{.*}} {num = 0
+// CHECK: amdgpu.async_wait {{.*}} {num_inst = 0
 %9 = ttg.async_wait %3 {num = 0 : i32}
 // No async_copies in between => waitcnt 0
-// CHECK: ttg.async_wait {{.*}} {num = 2
+// CHECK: amdgpu.async_wait {{.*}} {num_inst = 2
 %10 = ttg.async_wait %1 {num = 0 : i32}
 tt.return
 }

@@ -77,9 +77,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ

 %4 = tt.load %arg3 : tensor<128x16x!tt.ptr<f16>, #blocked>

-// CHECK: ttg.async_wait {{.*}} {num = 2
+// CHECK: amdgpu.async_wait {{.*}} {num_inst = 2
 %9 = ttg.async_wait %1 {num = 0 : i32}
-// CHECK: ttg.async_wait {{.*}} {num = 0
+// CHECK: amdgpu.async_wait {{.*}} {num_inst = 0
 %10 = ttg.async_wait %3 {num = 0 : i32}
 tt.return
 }

@@ -106,15 +106,15 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
 %2 = ttg.async_copy_global_to_local %arg4, %arg2 : tensor<16x256x!tt.ptr<f16>, #blocked1> -> <16x256xf16, #shared1, #smem, mutable>
 %3 = ttg.async_commit_group tokens %2
 %8:2 = scf.for %arg14 = %c0_i32 to %arg0 step %c1_i32 iter_args(%arg15 = %1, %arg16 = %3) -> (!ttg.async.token, !ttg.async.token) : i32 {
-// CHECK: ttg.async_wait {{.*}}, {{.*}} {num = 0
+// CHECK: amdgpu.async_wait {{.*}}, {{.*}} {num_inst = 0
 %10 = ttg.async_wait %arg15, %arg16 {num = 2 : i32}
 %11 = ttg.async_copy_global_to_local %arg3, %arg1 : tensor<128x16x!tt.ptr<f16>, #blocked> -> <128x16xf16, #shared, #smem, mutable>
 %12 = ttg.async_commit_group tokens %11
 %13 = ttg.async_copy_global_to_local %arg4, %arg2 : tensor<16x256x!tt.ptr<f16>, #blocked1> -> <16x256xf16, #shared1, #smem, mutable>
 %14 = ttg.async_commit_group tokens %13
 scf.yield %12, %14: !ttg.async.token, !ttg.async.token
 }
-// CHECK: ttg.async_wait {{.*}}, {{.*}} {num = 0
+// CHECK: amdgpu.async_wait {{.*}}, {{.*}} {num_inst = 0
 %9 = ttg.async_wait %8#0, %8#1 {num = 0 : i32}
 tt.return
 }

@@ -145,15 +145,15 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
 %6 = ttg.async_copy_global_to_local %arg4, %arg2 : tensor<16x256x!tt.ptr<f16>, #blocked1> -> <16x256xf16, #shared1, #smem, mutable>
 %7 = ttg.async_commit_group tokens %6
 %8:4 = scf.for %arg14 = %c0_i32 to %arg0 step %c1_i32 iter_args(%arg15 = %1, %arg16 = %5, %arg17 = %3, %arg18 = %7) -> (!ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token) : i32 {
-// CHECK: ttg.async_wait {{.*}}, {{.*}} {num = 3
+// CHECK: amdgpu.async_wait {{.*}}, {{.*}} {num_inst = 3
 %10 = ttg.async_wait %arg15, %arg17 {num = 2 : i32}
 %11 = ttg.async_copy_global_to_local %arg3, %arg1 : tensor<128x16x!tt.ptr<f16>, #blocked> -> <128x16xf16, #shared, #smem, mutable>
 %12 = ttg.async_commit_group tokens %11
 %13 = ttg.async_copy_global_to_local %arg4, %arg2 : tensor<16x256x!tt.ptr<f16>, #blocked1> -> <16x256xf16, #shared1, #smem, mutable>
 %14 = ttg.async_commit_group tokens %13
 scf.yield %arg16, %12, %arg18, %14 : !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token
 }
-// CHECK: ttg.async_wait {{.*}}, {{.*}} {num = 0
+// CHECK: amdgpu.async_wait {{.*}}, {{.*}} {num_inst = 0
 %9 = ttg.async_wait %8#0, %8#1, %8#2, %8#3 {num = 0 : i32}
 tt.return
 }

@@ -185,12 +185,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
 %8:4 = scf.for %arg14 = %c0_i32 to %arg0 step %c1_i32 iter_args(%arg15 = %1, %arg16 = %5, %arg17 = %3, %arg18 = %7) -> (!ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token) : i32 {
 %103 = scf.if %cond -> (!ttg.async.token) {
 // We wait on both tokens so we interleave with one iteration => 3
-// CHECK: ttg.async_wait {{.*}}, {{.*}} {num = 3
+// CHECK: amdgpu.async_wait {{.*}}, {{.*}} {num_inst = 3
 %token1 = ttg.async_wait %arg15, %arg17 {num = 2 : i32}
 scf.yield %token1 : !ttg.async.token
 } else {
 // We only wait on the token of the first load so we can interleave one more load => 3 + 2
-// CHECK: ttg.async_wait {{.*}} {num = 5
+// CHECK: amdgpu.async_wait {{.*}} {num_inst = 5
 %token2 = ttg.async_wait %arg15 {num = 1 : i32}
 scf.yield %token2 : !ttg.async.token
 }

@@ -200,7 +200,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
 %14 = ttg.async_commit_group tokens %13
 scf.yield %arg16, %12, %arg18, %14 : !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token
 }
-// CHECK: ttg.async_wait {{.*}}, {{.*}} {num = 0
+// CHECK: amdgpu.async_wait {{.*}}, {{.*}} {num_inst = 0
 %9 = ttg.async_wait %8#0, %8#1, %8#2, %8#3 {num = 0 : i32}
 tt.return
 }

@@ -235,7 +235,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
 %cond_load = ttg.async_copy_global_to_local %arg4, %arg2 : tensor<16x256x!tt.ptr<f16>, #blocked1> -> <16x256xf16, #shared1, #smem, mutable>
 %cond_load_commit = ttg.async_commit_group tokens %cond_load
 // We wait on both tokens (3) and additionally we should count the load inside our block (+2) => 5
-// CHECK: ttg.async_wait {{.*}}, {{.*}} {num = 5
+// CHECK: amdgpu.async_wait {{.*}}, {{.*}} {num_inst = 5
 %token1 = ttg.async_wait %arg15, %arg17 {num = 2 : i32}
 scf.yield %token1 : !ttg.async.token
 } else {

@@ -247,7 +247,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
 %14 = ttg.async_commit_group tokens %13
 scf.yield %arg16, %12, %arg18, %14 : !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token
 }
-// CHECK: ttg.async_wait {{.*}}, {{.*}} {num = 0
+// CHECK: amdgpu.async_wait {{.*}}, {{.*}} {num_inst = 0
 %9 = ttg.async_wait %8#0, %8#1, %8#2, %8#3 {num = 0 : i32}
 tt.return
 }

@@ -279,7 +279,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
 %7 = ttg.async_commit_group tokens %6
 %8:4 = scf.for %arg14 = %c0_i32 to %arg0 step %c1_i32 iter_args(%arg15 = %1, %arg16 = %5, %arg17 = %3, %arg18 = %7) -> (!ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token) : i32 {
 // The then block contains 3 instructions and the else 1 so we expect the count to be 3 (1 + 2) because there are also 2 instructions outside the scf.if in the loop body
-// CHECK: ttg.async_wait {{.*}}, {{.*}} {num = 3
+// CHECK: amdgpu.async_wait {{.*}}, {{.*}} {num_inst = 3
 %token1 = ttg.async_wait %arg15, %arg17 {num = 2 : i32}

 %103 = scf.if %cond -> (!ttg.async.token) {

@@ -296,7 +296,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
 %14 = ttg.async_commit_group tokens %13
 scf.yield %arg16, %103, %arg18, %14 : !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token
 }
-// CHECK: ttg.async_wait {{.*}}, {{.*}} {num = 0
+// CHECK: amdgpu.async_wait {{.*}}, {{.*}} {num_inst = 0
 %9 = ttg.async_wait %8#0, %8#1, %8#2, %8#3 {num = 0 : i32}
 tt.return
 }

@@ -323,14 +323,14 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
 %7 = ttg.async_commit_group tokens %6
 // Dynamic iteration count so we should not count its body
 %30 = scf.for %arg21 = %c0_i32 to %arg0 step %c1_i32 iter_args(%arg30 = %6) -> (!ttg.async.token) : i32 {
-// CHECK: ttg.async_wait {{.*}} {num = 0
+// CHECK: amdgpu.async_wait {{.*}} {num_inst = 0
 %31 = ttg.async_wait %arg30 {num = 1 : i32}
 // Emits 1 direct to lds instruction
 %32 = ttg.async_copy_global_to_local %arg3, %arg1 : tensor<128x16x!tt.ptr<f16>, #blocked> -> <128x16xf16, #shared, #smem, mutable>
 %33 = ttg.async_commit_group tokens %32
 scf.yield %33 : !ttg.async.token
 }
-// CHECK: ttg.async_wait {{.*}} {num = 1
+// CHECK: amdgpu.async_wait {{.*}} {num_inst = 1
 %10 = ttg.async_wait %1 {num = 1 : i32}
 tt.return
 }

@@ -357,14 +357,14 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
 %7 = ttg.async_commit_group tokens %6
 // Loop with 4 iterations => 4 instructions
 %30 = scf.for %arg21 = %c0_i32 to %c4_i32 step %c1_i32 iter_args(%arg30 = %6) -> (!ttg.async.token) : i32 {
-// CHECK: ttg.async_wait {{.*}} {num = 0
+// CHECK: amdgpu.async_wait {{.*}} {num_inst = 0
 %31 = ttg.async_wait %arg30 {num = 1 : i32}
 // Emits 1 direct to lds instruction
 %32 = ttg.async_copy_global_to_local %arg3, %arg1 : tensor<128x16x!tt.ptr<f16>, #blocked> -> <128x16xf16, #shared, #smem, mutable>
 %33 = ttg.async_commit_group tokens %32
 scf.yield %33 : !ttg.async.token
 }
-// CHECK: ttg.async_wait {{.*}} {num = 5
+// CHECK: amdgpu.async_wait {{.*}} {num_inst = 5
 %10 = ttg.async_wait %1 {num = 1 : i32}
 tt.return
 }

@@ -397,10 +397,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ

 // Check that we do not take other TDM loads into account (they use a different HW counter)

-// CHECK: ttg.async_wait {{.*}} {num = 2
+// CHECK: amdgpu.async_wait {{.*}} {num_inst = 2
 %cw1 = ttg.async_wait %21 {num = 0 : i32}

-// CHECK: ttg.async_wait {{.*}} {num = 0
+// CHECK: amdgpu.async_wait {{.*}} {num_inst = 0
 %cw2 = ttg.async_wait %51 {num = 0 : i32}
 tt.return
 }
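Distilled from the CHECK lines above, the rewrite rule works roughly like this (a sketch; the per-copy instruction counts are illustrative, the pass derives them from the copy's shape and the target architecture):

%a = ttg.async_copy_global_to_local %p0, %b0 ... // assume: lowers to 2 instructions
%ca = ttg.async_commit_group tokens %a
%b = ttg.async_copy_global_to_local %p1, %b1 ... // assume: lowers to 2 instructions
%cb = ttg.async_commit_group tokens %b
// Waiting only on %ca leaves %b's 2 instructions legitimately in flight, so
%w = ttg.async_wait %ca {num = 0 : i32}
// is rewritten to
%w = amdgpu.async_wait %ca {num_inst = 2 : i32}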

third_party/amd/backend/compiler.py

Lines changed: 1 addition & 2 deletions

@@ -256,8 +256,6 @@ def make_ttgir(mod, metadata, options):
     passes.common.add_canonicalizer(pm)
     passes.common.add_cse(pm)
     passes.common.add_symbol_dce(pm)
-    if use_async_copy:
-        amd.passes.ttgpuir.add_update_async_wait_count(pm, options.arch)
     pm.run(mod, 'make_ttgir')
     return mod

@@ -283,6 +281,7 @@ def make_llir(src, metadata, options):
     # TritonGPU -> LLVM-IR (MLIR)
     pm = ir.pass_manager(mod.context)
     pm.enable_debug()
+    amd.passes.ttgpuir.add_update_async_wait_count(pm, options.arch)
     # custom_lds_size is an experimental parameter that defines amount of LDS available
     # for one thread block. Measured in bytes.
     #

third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td

Lines changed: 17 additions & 0 deletions

@@ -775,4 +775,21 @@ def AsyncTDMWait : TT_AMDGPU_Op<"async_tdm_wait"> {
   let assemblyFormat = "$asyncToken attr-dict";
 }

+//===----------------------------------------------------------------------===//
+// AsyncWait
+//===----------------------------------------------------------------------===//
+
+def AsyncWaitOp : TT_AMDGPU_Op<"async_wait"> {
+  let summary = "Wait until there are less than or equal to the given number of outstanding async intrinsics";
+  let description = [{
+    Similar to ttg.async_wait, but instead of waiting on outstanding ttg.async_commit_groups
+    this op waits on the number of outstanding async instructions/intrinsics, as required for the
+    lowering to LLVM on the AMD backend.
+  }];
+
+  let arguments = (ins Variadic<TTG_AsyncToken>:$asyncToken, I32Attr:$num_inst);
+  let results = (outs TTG_AsyncToken:$retToken);
+  let assemblyFormat = "($asyncToken^)? attr-dict";
+}
+
 #endif
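Per the assembly format above, the token list is optional and the count lives in the attribute dictionary; both forms appear in the updated tests (%tok0 and %tok1 are placeholder tokens):

%w0 = amdgpu.async_wait {num_inst = 0 : i32}
%w1 = amdgpu.async_wait %tok0, %tok1 {num_inst = 3 : i32}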

third_party/amd/include/TritonAMDGPUTransforms/Passes.td

Lines changed: 1 addition & 1 deletion

@@ -257,7 +257,7 @@ def TritonAMDGPUUpdateAsyncWaitCount: Pass<"tritonamdgpu-update-async-wait-count
   compute the number of interleaving global memory instructions to emit the correct waitcnt during lowering.
   }];

-  let dependentDialects = [];
+  let dependentDialects = ["mlir::triton::amdgpu::TritonAMDGPUDialect"];

   let options = [
     Option<"archGenerationName", "arch-generation-name",

third_party/amd/lib/TritonAMDGPUToLLVM/AsyncUtility.cpp

Lines changed: 18 additions & 1 deletion

@@ -2,6 +2,7 @@

 #include "Dialect/TritonAMDGPU/IR/Dialect.h"
 #include "TargetInfo.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "llvm/ADT/TypeSwitch.h"

@@ -13,7 +14,7 @@ constexpr const char *syncedViaAsyncWaitAttrName =
 // if all defining operations are an AsyncWait
 bool comesFromAsyncWait(Value token) {
   if (auto defOp = token.getDefiningOp()) {
-    return isa<triton::gpu::AsyncWaitOp>(defOp);
+    return isa<triton::gpu::AsyncWaitOp, amdgpu::AsyncWaitOp>(defOp);
   }

   auto blockArg = dyn_cast<BlockArgument>(token);

@@ -50,6 +51,22 @@ bool comesFromAsyncWait(Value token) {
 }
 } // namespace

+void addLocalBarrierAfterAmdGpuAsyncWait(ModuleOp mod) {
+  auto *ctx = mod->getContext();
+
+  SmallVector<amdgpu::AsyncWaitOp> waits;
+  mod->walk([&waits](amdgpu::AsyncWaitOp waitOp) { waits.push_back(waitOp); });
+
+  IRRewriter builder(mod.getContext());
+  for (auto waitOp : waits) {
+    if (isa<mlir::gpu::BarrierOp, gpu::LocalBarrierOp>(waitOp->getNextNode()))
+      continue;
+
+    builder.setInsertionPointAfter(waitOp);
+    builder.create<triton::gpu::LocalBarrierOp>(waitOp->getLoc());
+  }
+}
+
 void annotateLocalLoadsSyncedViaAsyncWait(ModuleOp mod) {
   auto *ctx = mod->getContext();
third_party/amd/lib/TritonAMDGPUToLLVM/AsyncUtility.h

Lines changed: 6 additions & 0 deletions

@@ -9,6 +9,12 @@
 namespace mlir::triton::AMD {
 class TargetInfo;

+// Walks the module and adds a LocalBarrier after any amdgpu.async_wait if there
+// is not already a barrier following it. This mimics what Membar does for
+// common async wait operations and avoids AMD-specific modifications to Membar.
+// This yields the same behaviour as when Membar adds the barrier.
+void addLocalBarrierAfterAmdGpuAsyncWait(ModuleOp mod);
+
 // Annotates LocalLoadOps with ttg.amdgpu.syncedByAsyncWait=true if they are
 // synced by an AsyncWait.
 void annotateLocalLoadsSyncedViaAsyncWait(ModuleOp mod);
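In IR terms, the helper performs the following rewrite (a sketch; %tok and the memdesc/tensor types are placeholders):

%w = amdgpu.async_wait %tok {num_inst = 0 : i32}
%v = ttg.local_load %buf token %w : !ttg.memdesc<64x1xf16, #shared, #smem, mutable> -> tensor<64x1xf16, #blocked>

becomes

%w = amdgpu.async_wait %tok {num_inst = 0 : i32}
ttg.local_barrier
%v = ttg.local_load %buf token %w : !ttg.memdesc<64x1xf16, #shared, #smem, mutable> -> tensor<64x1xf16, #blocked>

unless a gpu.barrier or ttg.local_barrier already follows the wait, so Membar then sees the explicit barrier and needs no AMD-specific handling of the new op.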

third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 5 additions & 4 deletions

@@ -1887,14 +1887,15 @@ struct AtomicRMWOpConversion
 }
 };

-struct AsyncWaitOpConversion : public ConvertOpToLLVMPattern<AsyncWaitOp> {
+struct AsyncWaitOpConversion
+    : public ConvertOpToLLVMPattern<amdgpu::AsyncWaitOp> {
   AsyncWaitOpConversion(LLVMTypeConverter &converter,
                         const AMD::TargetInfo &targetInfo,
                         PatternBenefit benefit)
       : ConvertOpToLLVMPattern(converter, benefit), targetInfo(targetInfo) {}

   LogicalResult
-  matchAndRewrite(AsyncWaitOp op, OpAdaptor adaptor,
+  matchAndRewrite(amdgpu::AsyncWaitOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = op->getLoc();
     auto b = TritonLLVMOpBuilder(loc, rewriter);

@@ -1912,7 +1913,7 @@ struct AsyncWaitOpConversion : public ConvertOpToLLVMPattern<AsyncWaitOp> {
     // interested in those.

     // Clamp vmcnt to 6bits; a lower vmcnt will produce a conservative wait
-    unsigned vmCnt = std::min(63u, op.getNum());
+    unsigned vmCnt = std::min(63u, op.getNumInst());

     // Extract low and high bits and combine while setting all other bits to 1
     unsigned lowBits = vmCnt & 0xF;

@@ -1925,7 +1926,7 @@ struct AsyncWaitOpConversion : public ConvertOpToLLVMPattern<AsyncWaitOp> {
   }
   case ISAFamily::GFX1250: {
     // Clamp asyncCnt to 6bits(hw imit); lower means conservative
-    unsigned asyncCnt = std::min(63u, op.getNum());
+    unsigned asyncCnt = std::min(63u, op.getNumInst());
     LLVM::createLLVMIntrinsicCallOp(rewriter, loc,
                                     "llvm.amdgcn.s.wait.asynccnt", {},
                                     {b.i16_val(asyncCnt)});
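The split fields explain the waitcnt constants checked in async_ops_to_llvm.mlir above: on gfx9-class targets the 6-bit vmcnt is encoded in bits [3:0] (low nibble) and [15:14] (high two bits) of the s_waitcnt immediate, with every other bit set to 1 so no other counter is waited on. For num_inst = 1 this gives lowBits = 1 and highBits = 0, i.e. 0xFFFF3FF1 = -49167; num_inst >= 63 saturates everything to 0xFFFFFFFF = -1. A sketch of the emitted op for num_inst = 1:

// vmcnt = 1: all bits 1 except the vmcnt fields -> 0xFFFF3FF1 (-49167)
rocdl.s.waitcnt -49167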

third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp

Lines changed: 2 additions & 0 deletions

@@ -121,6 +121,8 @@ struct ConvertTritonAMDGPUToLLVM

   if (targetInfo.requiresAliasInfoForAsyncOps())
     AMD::annotateLocalLoadsSyncedViaAsyncWait(mod);
+
+  AMD::addLocalBarrierAfterAmdGpuAsyncWait(mod);
   ModuleMembarAnalysis membarPass(&allocation,
                                   mlir::triton::AMD::membarFilter);
   membarPass.run();
