
Commit 445e4ec

Copilot and mawad-amd committed
Implement Ring Attention (arxiv:2310.01889)
Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com>
1 parent d21bb20 commit 445e4ec

File tree

5 files changed: +800 −0 lines changed
Lines changed: 115 additions & 0 deletions
<!--
SPDX-License-Identifier: MIT
Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
-->

# Ring Attention

An implementation of **Ring Attention with Blockwise Transformers** for
near-infinite context on AMD GPUs using [Iris](../../README.md).

> Liu, H., Li, M., Hall, A., Dao, T., & Abbeel, P. (2023).
> *Ring Attention with Blockwise Transformers for Near-Infinite Context.*
> arXiv:2310.01889. <https://arxiv.org/pdf/2310.01889>

---

## Algorithm

Standard self-attention requires O(n²) memory in the sequence length n.
Ring Attention enables sequences far longer than what fits on a single device
by distributing them across a *ring* of GPUs:

1. The full sequence is split evenly across **N GPUs** along the sequence
   dimension. Each device holds a chunk of Q, K, and V of length
   `seq_total / N`.
2. **Q stays local**. K and V rotate around the ring one step at a time.
3. At each of the **N steps**, every device runs a local
   [Flash Attention](https://arxiv.org/abs/2205.14135) pass and accumulates
   the result using **online softmax**.
4. After all N steps the accumulator is normalised to yield the final output.

For **causal (autoregressive) attention** only the steps where the KV chunk
precedes or coincides with the Q chunk contribute, allowing early termination
for some ranks and reducing total compute.

```
Step 0:         rank r processes its own K_r, V_r (causal block diagonal)
Step 1:         rank r receives K_{r-1}, V_{r-1}  (full attention, past)
...
Step r:         rank r receives K_0, V_0          (full attention, past)
Step r+1..N-1:  all-future chunks – skipped       (causal mode only)
```
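Steps 3–4 above can be sketched as a plain-PyTorch reference (a numerical sketch only, not the repository's Triton kernel; `ring_attention_reference` and its arguments are illustrative names):

```python
import torch

def ring_attention_reference(q, chunks, scale):
    """Accumulate attention over KV chunks with online softmax.

    q: [s_q, d]; chunks: list of (k, v) pairs, each [s_kv, d].
    Equivalent to softmax((q @ K_all.T) * scale) @ V_all.
    """
    m = torch.full((q.shape[0], 1), float("-inf"))  # running row max
    l = torch.zeros(q.shape[0], 1)                  # running sum of exp
    o = torch.zeros_like(q)                         # unnormalised output

    for k, v in chunks:                             # one ring step per chunk
        s = (q @ k.T) * scale                       # local attention scores
        m_new = torch.maximum(m, s.max(dim=-1, keepdim=True).values)
        alpha = torch.exp(m - m_new)                # rescale old accumulators
        p = torch.exp(s - m_new)
        l = alpha * l + p.sum(dim=-1, keepdim=True)
        o = alpha * o + p @ v
        m = m_new

    return o / l                                    # final normalisation

# Agrees with full attention over the concatenated sequence
torch.manual_seed(0)
q = torch.randn(4, 8)
ks = [torch.randn(4, 8) for _ in range(3)]
vs = [torch.randn(4, 8) for _ in range(3)]
out = ring_attention_reference(q, list(zip(ks, vs)), scale=8 ** -0.5)
ref = torch.softmax((q @ torch.cat(ks).T) * 8 ** -0.5, dim=-1) @ torch.cat(vs)
assert torch.allclose(out, ref, atol=1e-5)
```

Because each chunk only rescales the running accumulators, the result is independent of the order in which KV chunks arrive around the ring.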
---

## Files

| File | Description |
|------|-------------|
| `ring_attention_kernels.py` | Triton flash-attention kernel + Python ring-rotation helper |
| `ring_attention_layer.py` | `RingAttention` – a `torch.nn.Module` wrapper |
| `example_run.py` | End-to-end demo with timing |

---

## Usage

### Quick demo

```bash
# 2 GPUs, causal attention (default)
python examples/32_ring_attention/example_run.py

# 4 GPUs, bidirectional
python examples/32_ring_attention/example_run.py --num_ranks 4 --no_causal

# Custom sizes
python examples/32_ring_attention/example_run.py \
    --num_ranks 8 \
    --total_seq_len 131072 \
    --num_heads 32 \
    --head_dim 128
```

### Validation

```bash
python tests/run_tests_distributed.py tests/examples/test_ring_attention.py --num_ranks 2 -v
```

---

## Python API

```python
import iris
from examples.ring_attention.ring_attention_layer import RingAttention

shmem = iris.iris()

# Each rank holds its local chunk
layer = RingAttention(
    shmem,
    num_heads=16,
    head_dim=64,
    causal=True,  # autoregressive masking
)

# q, k, v: [seq_local, num_heads, head_dim] (float16 or bfloat16)
output = layer(q, k, v)  # [seq_local, num_heads, head_dim]
```

---

## Design Notes

* **Communication**: KV rotation uses `torch.distributed.isend` / `irecv`
  (point-to-point), launching overlapping sends and receives to maximise
  throughput.
* **Online softmax**: The kernel maintains running max (`M`) and sum (`L`)
  accumulators in float32 for numerical stability. The final output is
  `O / L` after all ring steps.
* **Causal masking**: Handled entirely at the granularity of KV *chunks* –
  full attention, diagonal block attention, or skip – so the per-element mask
  is applied only in the same-block diagonal case.
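The chunk-level decision in the last bullet can be sketched as a small helper (a sketch only; `chunk_mode` is an illustrative name, not part of the repository):

```python
def chunk_mode(q_chunk: int, kv_chunk: int) -> str:
    """Classify a (Q-chunk, KV-chunk) pair under causal masking.

    Chunks are indexed by position in the global sequence: a KV chunk
    strictly before the Q chunk is fully visible, the same chunk needs the
    per-element diagonal mask, and a future chunk is skipped entirely.
    """
    if kv_chunk < q_chunk:
        return "full"      # entire chunk is in the past
    if kv_chunk == q_chunk:
        return "diagonal"  # apply the per-element causal mask
    return "skip"          # entirely in the future


# Rank 2 of 4 sees KV chunks 2, 1, 0, 3 (its own chunk first, then the ring)
modes = [chunk_mode(2, kv) for kv in (2, 1, 0, 3)]
# modes == ["diagonal", "full", "full", "skip"]
```

Only the `"diagonal"` case pays for per-element masking; `"skip"` steps do no attention work at all, which is where the causal-mode compute savings come from.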
Lines changed: 129 additions & 0 deletions
#!/usr/bin/env python3
# SPDX-License-Identifier: MIT
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.

"""
Minimal example demonstrating ring attention using the RingAttention layer.

The sequence is split evenly across GPUs along the sequence dimension.
Each rank computes its share of the attention output. After the ring passes,
the per-step partial results are combined via online softmax, yielding the
same result as a single device running full attention on the entire sequence.

Usage::

    # Run on 2 GPUs (default)
    python examples/32_ring_attention/example_run.py

    # Run on 4 GPUs
    python examples/32_ring_attention/example_run.py --num_ranks 4

    # Non-causal (bidirectional) attention
    python examples/32_ring_attention/example_run.py --no_causal
"""

import argparse

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

import iris
from ring_attention_layer import RingAttention


def parse_args():
    parser = argparse.ArgumentParser(description="Ring Attention example")
    parser.add_argument("--total_seq_len", type=int, default=4096, help="Total sequence length (split across GPUs)")
    parser.add_argument("--num_heads", type=int, default=16, help="Number of attention heads")
    parser.add_argument("--head_dim", type=int, default=64, help="Head dimension")
    parser.add_argument("--num_ranks", type=int, default=2, help="Number of GPUs")
    parser.add_argument("--no_causal", action="store_true", help="Use bidirectional (non-causal) attention")
    parser.add_argument(
        "--dtype",
        type=str,
        default="float16",
        choices=["float16", "bfloat16"],
        help="Input tensor dtype",
    )
    return parser.parse_args()


def run(rank: int, world_size: int, init_url: str, args):
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(
        backend=backend,
        init_method=init_url,
        world_size=world_size,
        rank=rank,
        device_id=torch.device(f"cuda:{rank}"),
    )

    shmem = iris.iris()
    torch.manual_seed(42)
    torch.set_default_device("cuda")

    dtype = getattr(torch, args.dtype)
    causal = not args.no_causal

    seq_local = args.total_seq_len // world_size
    num_heads = args.num_heads
    head_dim = args.head_dim

    if rank == 0:
        attn_type = "causal" if causal else "bidirectional"
        print(f"--- Ring Attention Example ({attn_type}) ---")
        print(f"  GPUs          : {world_size}")
        print(f"  Total seq len : {args.total_seq_len}")
        print(f"  Seq per GPU   : {seq_local}")
        print(f"  Heads × dim   : {num_heads} × {head_dim}")
        print(f"  dtype         : {dtype}")

    # Each rank creates its local Q, K, V chunk
    q = torch.randn(seq_local, num_heads, head_dim, dtype=dtype)
    k = torch.randn(seq_local, num_heads, head_dim, dtype=dtype)
    v = torch.randn(seq_local, num_heads, head_dim, dtype=dtype)

    shmem.barrier()

    layer = RingAttention(shmem, num_heads=num_heads, head_dim=head_dim, causal=causal)

    # Warm-up pass
    _ = layer(q, k, v)
    torch.cuda.synchronize()
    shmem.barrier()

    # Timed pass
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    output = layer(q, k, v)
    end.record()

    torch.cuda.synchronize()
    elapsed_ms = start.elapsed_time(end)

    if rank == 0:
        print(f"\nOutput shape : {output.shape}")
        print(f"Output dtype : {output.dtype}")
        print(f"Elapsed time : {elapsed_ms:.2f} ms")
        print(f"Output[0, 0, :4] = {output[0, 0, :4].float()}")

    shmem.barrier()
    dist.destroy_process_group()


def main():
    args = parse_args()
    init_url = "tcp://127.0.0.1:29500"
    mp.spawn(
        fn=run,
        args=(args.num_ranks, init_url, args),
        nprocs=args.num_ranks,
        join=True,
    )


if __name__ == "__main__":
    main()
