
Commit 48a673d

[Ring Attention] Add more detailed references (#6294)
* fix
* fix
1 parent 4ac2227 commit 48a673d


colossalai/shardformer/layer/attn.py

Lines changed: 10 additions & 5 deletions
@@ -406,13 +406,18 @@ def _rescale_out_lse(out, block_out, lse, block_lse):
 class RingAttention(torch.autograd.Function):
     """Implements the Ring Attention from `Ring Attention with Blockwise Transformers for Near-Infinite Context`
     (https://arxiv.org/abs/2310.01889).
-    For load-balancing we adopted the "zigzag" attention scheme from https://github.com/zhuzilin/ring-flash-attention/tree/main
-    For portable integration with more models, we don't follow the spirit of "block-wise FFN" in the original paper,
-    which requires fusing the FFN with the Flash Attention kernel/function (see https://arxiv.org/pdf/2305.19370;
-    implemented in Jax and not optimized).
-    We adopt the double ring topology from LoongTrain (https://arxiv.org/pdf/2406.18485) to fully utilize available
+    For load-balancing, we adopted the "zigzag" dataloading scheme from ring-flash-attention.
+    We also adopt the double ring topology from LoongTrain to fully utilize available
     NICs on each node, by computing attention within an inner ring first and then sending all KVs to the next
     ring at once.
+    Our implementation references code from
+    - ring-flash-attention: https://github.com/zhuzilin/ring-flash-attention/tree/main
+    - Megatron Context Parallel: https://github.com/NVIDIA/TransformerEngine/pull/726
+    References:
+    - Ring Attention with Blockwise Transformers for Near-Infinite Context
+      https://arxiv.org/abs/2310.01889
+    - LoongTrain: Efficient Training of Long-Sequence LLMs with Head-Context Parallelism
+      https://arxiv.org/abs/2406.18485
     """

     # Global cache to avoid recomputation for same-length sequences

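The hunk context above, `_rescale_out_lse(out, block_out, lse, block_lse)`, is where partial results from successive ring steps are merged: each step produces a block output together with the log-sum-exp (LSE) of its attention scores, and the running output is rescaled so the final result matches full-softmax attention over all keys. Below is a minimal out-of-place sketch of that merge using standard online-softmax algebra; the actual ColossalAI function updates its arguments in place and its tensor layout may differ.

```python
import torch

def merge_attn_outputs(out, lse, block_out, block_lse):
    """Merge two partial attention results using their log-sum-exp statistics.

    out, block_out: partial outputs over two disjoint key/value blocks, (..., q_len, head_dim)
    lse, block_lse: log-sum-exp of the corresponding attention scores, (..., q_len, 1)
    """
    new_lse = torch.logaddexp(lse, block_lse)   # logsumexp over the union of both blocks
    new_out = torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out
    return new_out, new_lse

def _attn(q, k, v):
    """Plain softmax attention that also returns the per-query LSE."""
    scores = q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5
    lse = torch.logsumexp(scores, dim=-1, keepdim=True)
    return torch.softmax(scores, dim=-1) @ v, lse

# Sanity check: merging attention over two key blocks equals attention over their union.
q, k, v = torch.randn(1, 4, 8), torch.randn(1, 6, 8), torch.randn(1, 6, 8)
out1, lse1 = _attn(q, k[:, :3], v[:, :3])
out2, lse2 = _attn(q, k[:, 3:], v[:, 3:])
merged, _ = merge_attn_outputs(out1, lse1, out2, lse2)
full, _ = _attn(q, k, v)
assert torch.allclose(merged, full, atol=1e-5)
```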