Commit 62ea289

xiaohuguo2023, Copilot, and mawad-amd authored
Initial development of triton based communication primitives using Iris - try 2 (ROCm#1607)
* add iris as aiter optional dependency
* prevents torch.distributed from leaking into aiter namespace
* add reduce_scatter triton ops
* add all_gather triton ops
* add fused rs_rms_quant_ag triton comm op
* add unit tests
* complete integration
* update format
* update fused kernel
* more format updates
* suspect this may break vllm tests
* add iris.py and remove unnecessary iris function api
* fix CI import error by enabling conditional import of iris
* remove unnecessary tests
* Refactored the reduce-scatter and all-gather kernel implementations to eliminate code duplication between standalone operations and fused kernels
* use new restructured reduce_scatter and all_gather
* add calculate_heap_size to automatically allocate heap from M, N
* make sure calculate_heap_size is exported at the different levels
* fix undefined shmem issue
* use aiter rmsnorm triton kernel instead
* add howto
* add test for fused rs_rmsnorm_quant_ag kernel
* use multiprocessing for this test as well
* resolve format issue
* fix typo for black
* fix another black format issue
* ctx is always going to be required
* remove unused variables
* remove unnecessary pass
* remove unused math module
* Update aiter/ops/triton/comms/iris.py (Co-authored-by: Copilot <[email protected]>)
* Update op_tests/multigpu_tests/triton/test_reduce_scatter_all_gather.py (Co-authored-by: Copilot <[email protected]>)
* remove dead link
* Update aiter/ops/triton/comms/fused/__init__.py (Co-authored-by: Muhammad Awad <[email protected]>)
* Update aiter/ops/triton/comms/fused/reduce_scatter_rmsnorm_quant_all_gather.py (Co-authored-by: Muhammad Awad <[email protected]>)
* Update aiter/ops/triton/comms/fused/reduce_scatter_rmsnorm_quant_all_gather.py (Co-authored-by: Muhammad Awad <[email protected]>)
* Update op_tests/multigpu_tests/triton/test_fused_rs_rmsnorm_quant_ag.py (Co-authored-by: Muhammad Awad <[email protected]>)
* Update aiter/ops/triton/comms/reduce_scatter.py (Co-authored-by: Muhammad Awad <[email protected]>)
* fix iterations for distributed tests to prevent deadlocks
* update license
* pin iris SHA
* fix format
* the warnings were at import time (too early); now they're only at usage time (when it matters)
* add IRIS_DEP to pin the Iris release and avoid sudden Iris API changes (Co-authored-by: Copilot <[email protected]>)
* remove commented lines
* Fixed fp8 dtype check inconsistency
* Add public is_initialized property to fix encapsulation issue
* remove unused
* Fixed multiprocessing race condition
* fix aiter imports and remove unused M_shard
* fix inefficient mem allocation
* fix format issue
* completely silence iris log when not using iris
* fix the import crash
* move triton_comms README
* add triton comms info to README.md

---------

Co-authored-by: Copilot <[email protected]>
Co-authored-by: Muhammad Awad <[email protected]>
1 parent c7b3bee commit 62ea289

File tree

16 files changed, +2370 -8 lines

README.md

Lines changed: 18 additions & 0 deletions
````diff
@@ -25,6 +25,24 @@ If you happen to forget the `--recursive` during `clone`, you can use the follow
 git submodule sync && git submodule update --init --recursive
 ```
 
+### Triton-based Communication (Iris)
+
+AITER supports GPU-initiated communication using the [Iris library](https://github.com/ROCm/iris). This enables high-performance Triton-based communication primitives like reduce-scatter and all-gather.
+
+**Installation**
+
+Install with Triton communication support:
+
+```bash
+# Option 1: Install via extras
+pip install -e ".[triton_comms]"
+
+# Option 2: Install all optional dependencies
+pip install -e ".[all]"
+```
+
+For more details, see [docs/triton_comms.md](docs/triton_comms.md).
+
 ## Run operators supported by aiter
 
 There are number of op test, you can run them with: `python3 op_tests/test_layernorm2d.py`
````
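
For orientation, a minimal per-rank usage sketch of the primitives the README advertises. This assumes Iris is installed and that the code runs inside an already-launched multi-GPU job (one process per rank); the shapes follow the `all_gather` docstring later in this commit.

```python
# Minimal per-rank sketch (assumption: an external launcher has already started
# one process per GPU; shapes are illustrative, not required by the API).
import torch
import aiter

def all_gather_demo():
    # IrisCommContext owns the Iris symmetric heap and is used as a context manager.
    with aiter.IrisCommContext() as ctx:
        # Buffers that take part in Iris communication are allocated from the heap.
        shard = ctx.iris_ctx.zeros((1024, 7168), dtype=torch.float32)
        shard.fill_(ctx.cur_rank)            # stand-in for real activations
        full = aiter.all_gather(shard, ctx)  # shape: [1024 * world_size, 7168]
        return full
```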

aiter/__init__.py

Lines changed: 14 additions & 0 deletions
```diff
@@ -76,3 +76,17 @@ def getLogger():
 from .ops.sample import *
 from .ops.fused_mrope_rms import *
 from . import mla
+
+# Import Triton-based communication primitives from ops.triton.comms (optional, only if Iris is available)
+try:
+    from .ops.triton.comms import (
+        IrisCommContext,
+        calculate_heap_size,
+        reduce_scatter,
+        all_gather,
+        reduce_scatter_rmsnorm_quant_all_gather,
+        IRIS_COMM_AVAILABLE,
+    )
+except ImportError:
+    # Iris not available, skip import
+    IRIS_COMM_AVAILABLE = False
```
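
Because the `except` branch still defines `IRIS_COMM_AVAILABLE`, downstream code can feature-detect the comms primitives without its own try/except. A small sketch; the `torch.distributed` fallback is illustrative only and not something this commit provides.

```python
import aiter

if aiter.IRIS_COMM_AVAILABLE:
    # Iris-backed Triton communication primitives re-exported at the top level.
    from aiter import reduce_scatter, all_gather
else:
    # Fall back to stock collectives (illustrative placeholder only).
    import torch.distributed as dist
```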

aiter/ops/triton/__init__.py

Lines changed: 33 additions & 5 deletions
```diff
@@ -1,10 +1,38 @@
 # SPDX-License-Identifier: MIT
 # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
 
-# SPDX-License-Identifier: MIT
-# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
 from . import quant
 
-# __all__ = [
-# "quant",
-# ]
+# Try to import comms module (requires iris)
+try:
+    from . import comms
+
+    # Re-export communication primitives at this level for convenience
+    from .comms import (
+        IrisCommContext,
+        reduce_scatter,
+        all_gather,
+        reduce_scatter_rmsnorm_quant_all_gather,
+        IRIS_COMM_AVAILABLE,
+    )
+
+    _COMMS_AVAILABLE = True
+except ImportError:
+    # Iris not available - comms module won't be available
+    _COMMS_AVAILABLE = False
+    IRIS_COMM_AVAILABLE = False
+    comms = None
+
+__all__ = ["quant"]
+
+if _COMMS_AVAILABLE:
+    __all__.extend(
+        [
+            "comms",
+            "IrisCommContext",
+            "reduce_scatter",
+            "all_gather",
+            "reduce_scatter_rmsnorm_quant_all_gather",
+            "IRIS_COMM_AVAILABLE",
+        ]
+    )
```

aiter/ops/triton/comms/__init__.py

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@

```python
# SPDX-License-Identifier: MIT
# Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.

"""
Triton-based communication primitives for AITER.

This submodule contains communication operations implemented using Triton,
including Iris-based GPU-initiated communication.

If Iris is not available, importing this module will raise ImportError.
"""

# Import all Iris-based communication primitives
# If Iris is not installed, this import will fail and the entire
# aiter.ops.triton.comms module will be unavailable
from .iris import IrisCommContext, calculate_heap_size
from .reduce_scatter import reduce_scatter
from .all_gather import all_gather
from .fused import reduce_scatter_rmsnorm_quant_all_gather

__all__ = [
    "IrisCommContext",
    "calculate_heap_size",
    "reduce_scatter",
    "all_gather",
    "reduce_scatter_rmsnorm_quant_all_gather",
    "IRIS_COMM_AVAILABLE",
]

# If we got here, Iris is available
IRIS_COMM_AVAILABLE = True
```

aiter/ops/triton/comms/all_gather.py

Lines changed: 246 additions & 0 deletions
@@ -0,0 +1,246 @@

```python
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

"""
All-Gather communication primitive using Iris.

This module provides an all-gather operation along the M dimension using
GPU-initiated communication via the Iris library.
"""

import torch
from torch import Tensor
import triton
import triton.language as tl
import logging

import iris

# If we got here, iris is available
IRIS_AVAILABLE = True

logger = logging.getLogger("aiter")


@triton.jit
def _all_gather_impl(
    pid,
    shard_ptr,
    out_ptr,
    M,
    M_shard,
    N,
    stride_sm,
    stride_sn,
    stride_om,
    stride_on,
    cur_rank: tl.constexpr,
    world_size: tl.constexpr,
    heap_bases: tl.tensor,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    NUM_SMS: tl.constexpr,
):
    """
    Shared all-gather implementation using push-based approach with iris.put. 1D persistent-style PID mapping

    Each rank sends its (M_shard)×N to all other ranks at the appropriate offset.

    Args:
        pid: Program ID, 1D persistent-style PID mapping
            from tl.program_id(0) or passed from parent kernel
    """
    num_pid_m = tl.cdiv(M_shard, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    total_tiles = num_pid_m * num_pid_n

    # Persistent loop over tiles
    for tile_id in range(pid, total_tiles, NUM_SMS):
        # Swizzle pattern
        num_pid_in_group = GROUP_SIZE_M * num_pid_n
        group_id = tile_id // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m)
        pid_n = (tile_id % num_pid_in_group) // group_size_m

        # Local indices
        rm_local = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        rm_local = tl.max_contiguous(tl.multiple_of(rm_local, BLOCK_M), BLOCK_M)
        rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_N), BLOCK_N)
        mask_m_local = rm_local < M_shard
        mask_n = rn < N

        # Load local shard
        shard_ptrs = shard_ptr + rm_local[:, None] * stride_sm + rn[None, :] * stride_sn
        shard_data = tl.load(
            shard_ptrs, mask=mask_m_local[:, None] & mask_n[None, :], other=0.0
        )

        # Send to all ranks at the appropriate M offset
        for dst in range(world_size):
            # Calculate global M indices
            rm_global = cur_rank * M_shard + rm_local
            mask_m_global = rm_global < M
            final_mask = mask_m_global[:, None] & mask_n[None, :]

            out_ptrs = (
                out_ptr + rm_global[:, None] * stride_om + rn[None, :] * stride_on
            )

            if dst == cur_rank:
                # Local store
                tl.store(out_ptrs, shard_data, mask=final_mask)
            else:
                # Remote store using iris.put
                # from_ptr: local source, to_ptr: remote destination
                iris.put(
                    shard_ptr + rm_local[:, None] * stride_sm + rn[None, :] * stride_sn,
                    out_ptrs,
                    cur_rank,
                    dst,
                    heap_bases,
                    mask=final_mask,
                )


@triton.jit
def _all_gather_kernel(
    shard_ptr,  # *[M_shard, N]
    out_ptr,  # *[M, N]
    M,
    M_shard,
    N,
    stride_sm,
    stride_sn,
    stride_om,
    stride_on,
    cur_rank: tl.constexpr,
    world_size: tl.constexpr,
    heap_bases: tl.tensor,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    NUM_SMS: tl.constexpr,
):
    """
    All-gather kernel entry point.

    This is a wrapper around _all_gather_impl that gets the program ID.
    """
    pid = tl.program_id(0)
    _all_gather_impl(
        pid,
        shard_ptr,
        out_ptr,
        M,
        M_shard,
        N,
        stride_sm,
        stride_sn,
        stride_om,
        stride_on,
        cur_rank,
        world_size,
        heap_bases,
        BLOCK_M,
        BLOCK_N,
        GROUP_SIZE_M,
        NUM_SMS,
    )


def all_gather(
    input_shard: Tensor,
    ctx: "IrisCommContext" = None,
    block_m: int = 64,
    block_n: int = 64,
    group_size_m: int = 8,
    num_sms: int = 256,
) -> Tensor:
    """
    Perform all-gather along the M (row) dimension.

    This operation:
    1. Each rank has a shard of shape [M_shard, N]
    2. All ranks send their shards to all other ranks
    3. Each rank receives a full tensor of shape [M, N] where M = M_shard * world_size

    Args:
        input_shard (Tensor): Input shard of shape [M_shard, N] in Iris shared memory
        ctx (IrisCommContext): Iris communication context. Optional if global context exists.
        block_m (int): Block size for M dimension. Default: 64
        block_n (int): Block size for N dimension. Default: 64
        group_size_m (int): Group size for swizzling. Default: 8
        num_sms (int): Number of SMs to use (persistent kernel). Default: 256

    Returns:
        Tensor: Full tensor of shape [M, N] where M = M_shard * world_size

    Example:
        >>> with IrisCommContext() as ctx:
        >>>     input_shard = ctx.iris_ctx.zeros((1024, 7168), dtype=torch.float32)
        >>>     # ... initialize input_shard ...
        >>>     full_tensor = all_gather(input_shard, ctx)
        >>>     print(full_tensor.shape)  # [8192, 7168] for world_size=8
    """
    if not IRIS_AVAILABLE:
        raise RuntimeError("Iris library is not available. Cannot perform all-gather.")

    if not ctx.is_initialized:
        raise RuntimeError(
            "Iris context not initialized. Use IrisCommContext as context manager."
        )

    # Get distributed parameters from context
    cur_rank = ctx.cur_rank
    world_size = ctx.num_ranks
    heap_bases = ctx.get_heap_bases()
    iris_ctx = ctx.iris_ctx

    # Input shape
    M_shard, N = input_shard.shape
    M = M_shard * world_size

    logger.info(
        f"Rank {cur_rank}/{world_size}: All-gather M_shard={M_shard}, N={N} -> M={M}"
    )

    # Allocate output buffer in IRIS shared memory
    full_output = iris_ctx.zeros((M, N), dtype=input_shard.dtype)

    # Launch kernel
    grid = (num_sms,)
    _all_gather_kernel[grid](
        input_shard,
        full_output,
        M,
        M_shard,
        N,
        input_shard.stride(0),
        input_shard.stride(1),
        full_output.stride(0),
        full_output.stride(1),
        cur_rank,
        world_size,
        heap_bases,
        BLOCK_M=block_m,
        BLOCK_N=block_n,
        GROUP_SIZE_M=group_size_m,
        NUM_SMS=num_sms,
        num_warps=16,
        num_stages=4,
        waves_per_eu=4,
    )

    # Synchronize
    torch.cuda.synchronize()
    iris_ctx.barrier()

    logger.info(
        f"Rank {cur_rank}: All-gather complete, output shape: {full_output.shape}"
    )

    return full_output
```
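
The grouped "swizzle" indexing in `_all_gather_impl` mirrors the grouped-ordering scheme used in persistent Triton matmul kernels. The sketch below reproduces the tile-to-(pid_m, pid_n) mapping in plain Python so it can be inspected on the host; it is an illustration only, not code from this commit, and the default block sizes come from `all_gather`'s signature.

```python
# Host-side reproduction of the grouped tile mapping used in _all_gather_impl.
# Tiles are grouped GROUP_SIZE_M rows of blocks at a time to improve locality.
from math import ceil

def tile_to_pid(tile_id, M_shard, N, BLOCK_M=64, BLOCK_N=64, GROUP_SIZE_M=8):
    num_pid_m = ceil(M_shard / BLOCK_M)
    num_pid_n = ceil(N / BLOCK_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = tile_id // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m)
    pid_n = (tile_id % num_pid_in_group) // group_size_m
    return pid_m, pid_n

# A persistent program with id `pid` walks tiles pid, pid + NUM_SMS, pid + 2*NUM_SMS, ...
# With M_shard=1024, N=7168 there are 16 * 112 = 1792 tiles shared among NUM_SMS programs;
# program 0 with NUM_SMS=256 visits tiles 0, 256, 512, ...
print([tile_to_pid(t, 1024, 7168) for t in range(0, 1792, 256)])
```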

aiter/ops/triton/comms/fused/__init__.py

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@

```python
# SPDX-License-Identifier: MIT
# Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.

"""
Fused Communication + Computation Kernels

This submodule contains Triton kernels that fuse communication operations
with computation operations for improved performance.

Examples:
- reduce_scatter + rmsnorm + quant + all_gather
- all_reduce + rmsnorm + quant
- reduce_scatter + gemm + all_gather
"""

from .reduce_scatter_rmsnorm_quant_all_gather import (
    reduce_scatter_rmsnorm_quant_all_gather,
)

__all__ = [
    "reduce_scatter_rmsnorm_quant_all_gather",
]
```
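
For orientation, here is a single-process reference sketch of the math that the fused kernel presumably collapses into one launch: sum-reduce across ranks, scatter along M, RMSNorm over the last dimension, FP8 quantization, and an all-gather back to the full tensor. The actual signature, scaling scheme, and FP8 dtype of `reduce_scatter_rmsnorm_quant_all_gather` live in the kernel file and are not shown in this diff; everything below (including the per-shard max-abs scale and the `float8_e4m3fnuz` dtype) is an assumption.

```python
# Reference math only; not the aiter API and not code from this commit.
import torch

# Assumed FP8 dtype; ROCm builds of PyTorch expose the *fnuz variants.
FP8 = torch.float8_e4m3fnuz

def reference_rs_rmsnorm_quant_ag(per_rank_inputs, weight, eps=1e-6):
    """per_rank_inputs: list of [M, N] tensors, one per simulated rank."""
    world_size = len(per_rank_inputs)
    M, N = per_rank_inputs[0].shape
    M_shard = M // world_size

    # 1) Reduce (sum over ranks), then scatter along M into world_size shards.
    reduced = torch.stack(per_rank_inputs).sum(dim=0)
    shards = reduced.split(M_shard, dim=0)

    quantized, scales = [], []
    for shard in shards:
        # 2) RMSNorm over the last dimension.
        rms = torch.rsqrt(shard.pow(2).mean(dim=-1, keepdim=True) + eps)
        normed = shard * rms * weight
        # 3) FP8 quantization with a per-shard max-abs scale (scheme is assumed).
        scale = normed.abs().amax().clamp(min=1e-12) / torch.finfo(FP8).max
        quantized.append(normed / scale)
        scales.append(scale)

    # 4) All-gather: concatenate the shards back into the full [M, N] tensor.
    return torch.cat(quantized, dim=0).to(FP8), torch.stack(scales)
```

A real multi-GPU run would instead call the fused aiter op inside an `IrisCommContext`, analogous to the `all_gather` example shown earlier.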
