
Commit 7025305 (parent 4c2175f)
[AMD][GLUON] Wait outstanding async commit groups instead of instructions (#8605)
Currently, `async_wait` in Gluon on CDNA4 requires the kernel writer to pass the number of outstanding hardware instructions/LLVM intrinsics to `async_wait`. This count is very difficult to compute, as it depends on layouts, sizes, contiguity, and so on. This PR changes the semantics of `async_wait` so that the argument counts outstanding commit groups instead, following the semantics used for NVIDIA in Gluon. Gluon kernels therefore need to commit outstanding async operations via `commit_group` and then wait on them via `wait_group`. I also adapted the names so that existing Gluon kernels using the old semantics error out.

`UpdateAsyncWaitCount` is extended to compute the number of outstanding hardware instructions based on the number of outstanding commit groups. Previously, it only worked on `async_wait` ops carrying tokens of the commit groups, which are not available when compiling a Gluon kernel. The new computation walks the IR backwards, following *all* possible control-flow paths, and finds the smallest number of emitted instructions for N outstanding commit groups.
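As described above, a kernel now pairs its async copies with an explicit commit and waits on groups rather than raw instructions. A minimal sketch of the new call sequence (kernel boilerplate is elided; `smem`, `offsets`, and `blocked` are assumed to be set up as in the updated test below):

```python
# Sketch only: issue async copies, seal them into a commit group, then wait
# until no commit groups remain in flight before reading shared memory.
cdna4_async_copy.global_load_to_shared(smem, a_ptr + offsets)
cdna4_async_copy.commit_group()  # finalize the copies issued so far
cdna4_async_copy.wait_group(0)   # block until 0 groups are outstanding
a = cdna4_async_copy.load_shared_relaxed(smem, blocked)
```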
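A pseudocode sketch of that backward walk (illustrative only, not the actual `UpdateAsyncWaitCount` implementation; `backward_paths`, `is_commit_group`, `is_async_copy`, and `instructions_emitted` are hypothetical helpers):

```python
def hardware_wait_count(wait_op, num_outstanding):
    # wait_group(N) may leave the N newest commit groups in flight, so the
    # hardware wait count is the number of async instructions those groups
    # emitted. That count differs per control-flow path (layouts, sizes,
    # contiguity), so take the minimum over all paths: a smaller count only
    # makes the wait stricter, never unsafe.
    best = float("inf")
    for path in backward_paths(wait_op):  # ops along one path, newest first
        groups_passed, count = 0, 0
        for op in path:
            if is_commit_group(op):
                groups_passed += 1
                if groups_passed > num_outstanding:
                    break  # everything older than this group must complete
            elif is_async_copy(op) and groups_passed > 0:
                # Copies not yet committed contribute nothing, so wait_group(0)
                # still waits for them, matching the docstring below.
                count += instructions_emitted(op)
        best = min(best, count)
    return best
```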

File tree

6 files changed: +836 -96 lines changed

python/test/gluon/test_core.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -541,8 +541,9 @@ def kernel(a_ptr, b_ptr, use_buffer_load: ttgl.constexpr):
         cdna4_async_copy.buffer_load_to_shared(smem, a_ptr, offsets)
     else:
         cdna4_async_copy.global_load_to_shared(smem, a_ptr + offsets)
+    cdna4_async_copy.commit_group()
 
-    cdna4_async_copy.async_wait(0)
+    cdna4_async_copy.wait_group(0)
     a = cdna4_async_copy.load_shared_relaxed(smem, blocked)
 
     ttgl.store(b_ptr + offsets, a)
```

python/test/gluon/test_frontend.py

Lines changed: 23 additions & 4 deletions

```diff
@@ -1950,17 +1950,36 @@ def test_infer_layout_for_amd_wmma(target):
 
 
 @gluon.jit
-def amd_async_wait():
-    cdna4_async_copy.async_wait(0)
+def amd_commit_group():
+    cdna4_async_copy.commit_group()
+
+
+@pytest.mark.parametrize("target", [HIP_TARGET_CDNA4])
+def test_amd_commit_group(target):
+    mod = run_parser(amd_wait_group, target=target)
+    expecttest.assert_expected_inline(
+        anonymize_ir(mod.str_nodebug()), """\
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @amd_wait_group() attributes {noinline = false} {
+    %0 = ttg.async_wait {num = 0 : i32}
+    tt.return
+  }
+}
+""")
+
+
+@gluon.jit
+def amd_wait_group():
+    cdna4_async_copy.wait_group(0)
 
 
 @pytest.mark.parametrize("target", [HIP_TARGET_CDNA4])
 def test_amd_async_wait(target):
-    mod = run_parser(amd_async_wait, target=target)
+    mod = run_parser(amd_wait_group, target=target)
     expecttest.assert_expected_inline(
         anonymize_ir(mod.str_nodebug()), """\
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 64 : i32} {
-  tt.func public @amd_async_wait() attributes {noinline = false} {
+  tt.func public @amd_wait_group() attributes {noinline = false} {
     %0 = ttg.async_wait {num = 0 : i32}
     tt.return
   }
```

python/triton/experimental/gluon/language/amd/cdna4/async_copy.py

Lines changed: 25 additions & 10 deletions

```diff
@@ -6,7 +6,8 @@
 __all__ = [
     "global_load_to_shared",
     "buffer_load_to_shared",
-    "async_wait",
+    "commit_group",
+    "wait_group",
     "load_shared_relaxed",
 ]
 
@@ -17,7 +18,10 @@ def global_load_to_shared(dest, ptr, mask=None, other=None, cache_modifier="", _
     AMD global load to shared operation. This operation loads data directly
     from global memory to shared memory without going through registers. It
     happens asynchronously and requires a subsequent `async_wait` to ensure the
-    data is available in shared memory.
+    data is available in shared memory. Note that this operation does still
+    complete in order with ttgl.loads/stores or buffer_loads/stores on CDNA4,
+    so interleaving with them will hurt performance.
+
     Compared to `buffer_load_to_shared`, it requires a tensor pointer which
     supports 64-bit indexing range for each thread in a block, which gives more
     flexibility, but at the cost of higher register pressure and no hardware
@@ -72,7 +76,10 @@ def buffer_load_to_shared(dest, ptr, offsets, mask=None, other=None, cache_modif
     32-bit offsets instead of a tensor of pointers. This operation loads data
     directly from global memory to shared memory without going through
     registers. It happens asynchronously and requires a subsequent `async_wait`
-    to ensure the data is available in shared memory.
+    to ensure the data is available in shared memory. Note that this operation
+    does still complete in order with ttgl.loads/stores or buffer_loads/stores
+    on CDNA4, so interleaving with them will hurt performance.
+
     Compared to `global_load_to_shared`, it has better performance and also
     supports hardware out-of-bound masking. But it strictly requires a
     32-bit offset instead of a 64-bit tensor pointer.
@@ -118,16 +125,24 @@ def buffer_load_to_shared(dest, ptr, offsets, mask=None, other=None, cache_modif
 
 
 @builtin
-def async_wait(num_outstanding=0, _semantic=None):
+def commit_group(_semantic=None):
+    """
+    Commit outstanding async operations.
+
+    This finalizes a set of async copy operations which can be waited upon via `wait_group`.
+    """
+    _semantic.builder.create_async_commit_group()
+
+
+@builtin
+def wait_group(num_outstanding=0, _semantic=None):
     """
-    Wait for outstanding memory operations, this includes normal load like
-    `load` and `buffer_load`, as well as direct load to shared memory
-    like `global_load_to_shared` and `buffer_load_to_shared`.
-    It will block until the number of outstanding memory operations is less than
-    or equal to `num_outstanding`.
+    Wait for outstanding commit groups. It will block until the number of
+    outstanding commit groups is less than or equal to `num_outstanding`. Note that
+    uncommitted async operations will be waited upon even if `num_outstanding` is 0.
 
     Args:
-        num_outstanding (int): The number of outstanding operations to wait for. Defaults to 0.
+        num_outstanding (int): The number of outstanding commit groups to wait for. Defaults to 0.
     """
     num_outstanding = _unwrap_if_constexpr(num_outstanding)
     _semantic.builder.create_async_wait_group(num_outstanding)
```
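The `num_outstanding` argument is what makes software pipelining expressible. A hedged sketch of how two groups interact (the buffers `smem_a`/`smem_b`, the offsets, and `blocked` are hypothetical names, not taken from this diff):

```python
cdna4_async_copy.buffer_load_to_shared(smem_a, ptr, offs_a)
cdna4_async_copy.commit_group()  # group 1
cdna4_async_copy.buffer_load_to_shared(smem_b, ptr, offs_b)
cdna4_async_copy.commit_group()  # group 2
# Returns once at most one group is still in flight: group 1 is then
# guaranteed complete, while group 2 may still be filling smem_b.
cdna4_async_copy.wait_group(1)
a = cdna4_async_copy.load_shared_relaxed(smem_a, blocked)
```

Because these Gluon ops carry no tokens, `UpdateAsyncWaitCount` later translates the group count into a hardware instruction count as sketched in the commit message above.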
