
Commit a77026d

feature: kernel barrier and fast compile for grpcoll (#223)
* add kernel barrier for groupcoll
* split compile
* refactor grpcoll for fast compile
* code polish
* lint and update contributing.md
* resolve issues
* fix ffa build
* fix ffa build
* fix ffa build
* fix test
1 parent 9233922 commit a77026d


49 files changed: +3815 / -2391 lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -1,6 +1,7 @@
 # magi_attention
 magi_attention/_version.py
 magi_attention/flex_flash_attn*
+magi_attention/csrc/comm/grpcoll/instantiations/
 *.nsys-rep
 *.ncu-rep
 

CONTRIBUTING.md

Lines changed: 15 additions & 0 deletions
@@ -65,3 +65,18 @@ pre-commit run -a
 
 > [!NOTE]
 > Code format checking will be automatically executed when you commit your changes.
+
+
+### Type Stubs (C++ Extension)
+
+If you modify the C++ extension (`magi_attn_ext`), please remember to regenerate the Python type stubs (`.pyi` files). This ensures that static type checkers (like MyPy) and IDEs can correctly recognize the updated C++ signatures.
+
+Ensure the extension is installed in your environment, then run:
+
+```bash
+pybind11-stubgen magi_attention.magi_attn_ext -o .
+```
+
+> [!IMPORTANT]
+> Failure to update stubs after modifying C++ code may cause type checking errors during CI.
+```

exps/dist_attn/run_benchmark.py

Lines changed: 1 addition & 1 deletion
@@ -508,7 +508,7 @@ def run_magi_attn(
     num_sms = int(getattr(ATTN_CONFIG, "num_sms", 24))
     nvl_chunk_size = int(getattr(ATTN_CONFIG, "nvl_chunk_size", 8))
     nvl_buffer_size = int(getattr(ATTN_CONFIG, "nvl_buffer_size", 256))
-    rdma_chunk_size = int(getattr(ATTN_CONFIG, "rdma_chunk_size", 4))
+    rdma_chunk_size = int(getattr(ATTN_CONFIG, "rdma_chunk_size", 16))
     rdma_buffer_size = int(getattr(ATTN_CONFIG, "rdma_buffer_size", 128))
     num_nvl_bytes = int(getattr(ATTN_CONFIG, "num_nvl_bytes", int(3e9)))  # ~3GB
     # only valid for internode
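
For clarity, this change only raises the fallback used when `ATTN_CONFIG` does not define `rdma_chunk_size`; an explicitly set attribute still wins. A minimal sketch of that behavior, using a `SimpleNamespace` as a hypothetical stand-in for the benchmark's `ATTN_CONFIG` object:

```python
from types import SimpleNamespace

# Hypothetical stand-in for the benchmark's ATTN_CONFIG; its real type is not shown here.
ATTN_CONFIG = SimpleNamespace()

# With no attribute set, the getattr fallback now yields 16 (previously 4).
assert int(getattr(ATTN_CONFIG, "rdma_chunk_size", 16)) == 16

# An explicit attribute on the config object still overrides the fallback.
ATTN_CONFIG.rdma_chunk_size = 4
assert int(getattr(ATTN_CONFIG, "rdma_chunk_size", 16)) == 4
```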

exps/grpcoll/test_internode_grpcoll.py

Lines changed: 8 additions & 8 deletions
@@ -47,15 +47,15 @@
 from magi_attention.comm.primitive.grpcoll._buffer import GrpCollBuffer
 from magi_attention.comm.primitive.grpcoll._config import GrpCollConfig
 from magi_attention.comm.primitive.grpcoll._handle import GrpCollInterHandle
-from magi_attention.comm.primitive.grpcoll._mgr import grpcoll_mgr
+from magi_attention.comm.primitive.grpcoll._mgr import grpcoll_buffer_mgr
 from magi_attention.comm.primitive.grpcoll.utils import (
     get_a2av_perm_idxs_from_group_cast_meta,
     get_native_group_cast_meta,
     get_num_rdma_recv_tokens,
     transfer_splits_and_dst_idxs_to_t2r_idx,
     unpermute_output,
 )
-from magi_attention.common.enum import GroupReduceOp
+from magi_attention.common.enum import GroupReduceOp, GrpCollBufferName
 from magi_attention.testing.precision import assert_close
 from magi_attention.utils import pad_and_pack_tensors, setup_dist_env
 
@@ -1547,7 +1547,7 @@ def test_loop(args: argparse.Namespace):
     assert num_local_ranks == 8 and num_ranks > 8
 
     # set grpcoll config
-    use_grpcoll_mgr = True
+    use_grpcoll_buffer_mgr = True
     if args.test_ll_compatibility:
         ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9
         if local_rank == 0:
@@ -1591,13 +1591,13 @@ def test_loop(args: argparse.Namespace):
         explicitly_destroy=True,
     )
 
-    if use_grpcoll_mgr:
-        grpcoll_mgr.register_buffer(
+    if use_grpcoll_buffer_mgr:
+        grpcoll_buffer_mgr.initialize(
             group=group,
             config=buffer_config,
             **extra_buffer_kwargs,
         )
-        buffer = grpcoll_mgr.get_buffer(group)
+        buffer = grpcoll_buffer_mgr.get_buffer(GrpCollBufferName.GroupCastDefault)
     else:
         buffer_args = buffer_config.to_buffer_args()
         buffer_args.update(extra_buffer_kwargs)
@@ -1619,8 +1619,8 @@
     )
 
     # Destroy the buffer runtime
-    if use_grpcoll_mgr:
-        grpcoll_mgr.release_buffer(group)
+    if use_grpcoll_buffer_mgr:
+        grpcoll_buffer_mgr.release_buffer(GrpCollBufferName.GroupCastDefault)
     else:
         buffer.destroy()
     dist.barrier()
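
Taken together with the intranode test below, which follows the same pattern, the manager change amounts to: register the buffer runtime via `grpcoll_buffer_mgr.initialize(...)`, then look it up and release it by a `GrpCollBufferName` instead of by process group. A minimal sketch of that lifecycle, assuming the default `GrpCollConfig` is acceptable and no extra buffer kwargs are required:

```python
import torch.distributed as dist

from magi_attention.comm.primitive.grpcoll._config import GrpCollConfig
from magi_attention.comm.primitive.grpcoll._mgr import grpcoll_buffer_mgr
from magi_attention.common.enum import GrpCollBufferName


def run_with_managed_buffer(group: dist.ProcessGroup):
    # Register the buffer runtime with the renamed manager (was grpcoll_mgr.register_buffer).
    grpcoll_buffer_mgr.initialize(
        group=group,
        config=GrpCollConfig(),  # assumes the default config fields are fine here
    )

    # Buffers are now fetched by name rather than by process group.
    buffer = grpcoll_buffer_mgr.get_buffer(GrpCollBufferName.GroupCastDefault)

    # ... issue group_cast / group_reduce calls on `buffer` here ...

    # Release by the same name once the workload is done.
    grpcoll_buffer_mgr.release_buffer(GrpCollBufferName.GroupCastDefault)
```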

exps/grpcoll/test_intranode_grpcoll.py

Lines changed: 8 additions & 8 deletions
@@ -47,14 +47,14 @@
 from magi_attention.comm.primitive.grpcoll._buffer import GrpCollBuffer
 from magi_attention.comm.primitive.grpcoll._config import GrpCollConfig
 from magi_attention.comm.primitive.grpcoll._handle import GrpCollIntraHandle
-from magi_attention.comm.primitive.grpcoll._mgr import grpcoll_mgr
+from magi_attention.comm.primitive.grpcoll._mgr import grpcoll_buffer_mgr
 from magi_attention.comm.primitive.grpcoll.utils import (
     get_a2av_perm_idxs_from_group_cast_meta,
     get_native_group_cast_meta,
     transfer_splits_and_dst_idxs_to_t2r_idx,
     unpermute_output,
 )
-from magi_attention.common.enum import GroupReduceOp
+from magi_attention.common.enum import GroupReduceOp, GrpCollBufferName
 from magi_attention.utils import pad_and_pack_tensors
 
 # isort: split
@@ -1424,7 +1424,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
     rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
 
     # set grpcoll config
-    use_grpcoll_mgr = True
+    use_grpcoll_buffer_mgr = True
     test_ll_compatibility, num_rdma_bytes = False, 0
     if test_ll_compatibility:
         ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9
@@ -1466,13 +1466,13 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
         explicitly_destroy=True,
     )
 
-    if use_grpcoll_mgr:
-        grpcoll_mgr.register_buffer(
+    if use_grpcoll_buffer_mgr:
+        grpcoll_buffer_mgr.initialize(
             group=group,
             config=buffer_config,
             **extra_buffer_kwargs,
         )
-        buffer = grpcoll_mgr.get_buffer(group)
+        buffer = grpcoll_buffer_mgr.get_buffer(GrpCollBufferName.GroupCastDefault)
     else:
         buffer_args = buffer_config.to_buffer_args()
         buffer_args.update(extra_buffer_kwargs)
@@ -1492,8 +1492,8 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
     )
 
     # Destroy the buffer runtime
-    if use_grpcoll_mgr:
-        grpcoll_mgr.release_buffer(group)
+    if use_grpcoll_buffer_mgr:
+        grpcoll_buffer_mgr.release_buffer(GrpCollBufferName.GroupCastDefault)
     else:
         buffer.destroy()
     dist.barrier()

magi_attention/__init__.py

Lines changed: 41 additions & 0 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import importlib.util
+import logging
 import os
 import warnings
 
@@ -22,6 +23,36 @@
     init_dist_attn_runtime_mgr,
 )
 
+try:
+    from . import magi_attn_ext  # type: ignore[attr-defined] # noqa: F401
+except ImportError as e:
+    warnings.warn(
+        f"Failed to import magi_attn_ext extension module. "
+        f"Please make sure MagiAttention is properly installed. "
+        f"Original error message: {e}"
+    )
+
+try:
+    from . import magi_attn_comm  # type: ignore[attr-defined] # noqa: F401
+except ImportError as e:
+    warnings.warn(
+        f"Failed to import magi_attn_comm extension module. "
+        f"Please make sure MagiAttention is properly installed. "
+        f"Original error message: {e}"
+    )
+
+try:
+    from . import (  # type: ignore[attr-defined] # noqa: F401
+        flexible_flash_attention_utils_cuda,
+    )
+except ImportError as e:
+    warnings.warn(
+        f"Failed to import flexible_flash_attention_utils_cuda extension module. "
+        f"Please make sure MagiAttention is properly installed. "
+        f"Original error message: {e}"
+    )
+
+
 if importlib.util.find_spec("magi_attention._version") is None:
     warnings.warn(
         "You are using magi_attention without installing it. This may cause some unexpected errors."
@@ -34,6 +65,13 @@
 
 __version__: str | None = version
 
+# Initialize a logger specific to this module/namespace
+logger = logging.getLogger(__name__)
+
+# Add a NullHandler to prevent logging warnings ("No handlers could be found...")
+# if the application using this library hasn't configured logging.
+logger.addHandler(logging.NullHandler())
+
 
 def is_sanity_check_enable() -> bool:
     """
@@ -123,4 +161,7 @@ def dist_attn_runtime_dict_size() -> int:
     "config",
     "comm",
     "functional",
+    "magi_attn_ext",
+    "magi_attn_comm",
+    "flexible_flash_attention_utils_cuda",
 ]
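
Since the package only attaches a `NullHandler`, log records from `magi_attention` stay silent until the host application configures logging itself. A short sketch of the application-side opt-in (standard-library `logging` only, nothing specific to this commit):

```python
import logging

# Route all log records (including magi_attention's) to stderr at INFO level ...
logging.basicConfig(level=logging.INFO)

# ... or raise verbosity for the magi_attention namespace alone.
logging.getLogger("magi_attention").setLevel(logging.DEBUG)
```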

magi_attention/comm/primitive/grpcoll/_buffer.py

Lines changed: 14 additions & 0 deletions
@@ -383,6 +383,7 @@ def group_cast(
         post_perm_idx: torch.Tensor | None = None,
         config: GrpCollConfig | None = None,
         previous_event: EventOverlap | None = None,
+        kernel_barrier=None,
         async_op: bool = False,
         allocate_on_comm_stream: bool = False,
         cast_lse: bool = False,
@@ -495,6 +496,7 @@
             is_token_in_rank=is_token_in_rank,
             post_perm_idx=post_perm_idx,
             previous_event=previous_event,
+            kernel_barrier=kernel_barrier,
             async_op=async_op,
             allocate_on_comm_stream=allocate_on_comm_stream,
             cast_lse=cast_lse,
@@ -514,6 +516,7 @@
             is_token_in_rank=is_token_in_rank,
             post_perm_idx=post_perm_idx,
             previous_event=previous_event,
+            kernel_barrier=kernel_barrier,
             async_op=async_op,
             allocate_on_comm_stream=allocate_on_comm_stream,
             cast_lse=cast_lse,
@@ -531,6 +534,7 @@ def group_reduce(
         pre_perm_idx: torch.Tensor | None = None,
         config: GrpCollConfig | None = None,
         previous_event: EventOverlap | None = None,
+        kernel_barrier=None,
         async_op: bool = False,
         allocate_on_comm_stream: bool = False,
         comm_dtype: torch.dtype | None = None,
@@ -625,6 +629,7 @@
             acc_reduce=acc_reduce,
             pre_perm_idx=pre_perm_idx,
             previous_event=previous_event,
+            kernel_barrier=kernel_barrier,
             async_op=async_op,
             allocate_on_comm_stream=allocate_on_comm_stream,
             comm_dtype=comm_dtype,
@@ -643,6 +648,7 @@
             acc_reduce=acc_reduce,
             pre_perm_idx=pre_perm_idx,
             previous_event=previous_event,
+            kernel_barrier=kernel_barrier,
             async_op=async_op,
             allocate_on_comm_stream=allocate_on_comm_stream,
             comm_dtype=comm_dtype,
@@ -661,6 +667,7 @@ def _intranode_group_cast(
         is_token_in_rank: torch.Tensor | None = None,
         post_perm_idx: torch.Tensor | None = None,
         previous_event: EventOverlap | None = None,
+        kernel_barrier=None,
         async_op: bool = False,
         allocate_on_comm_stream: bool = False,
         cast_lse: bool = False,
@@ -747,6 +754,7 @@
             post_perm_idx,
             config.to_kernel_config(),
             getattr(previous_event, "event", None),
+            kernel_barrier,
             async_op,
             allocate_on_comm_stream,
         )
@@ -791,6 +799,7 @@ def _intranode_group_reduce(
         acc_reduce: bool = False,
         pre_perm_idx: torch.Tensor | None = None,
         previous_event: EventOverlap | None = None,
+        kernel_barrier=None,
         async_op: bool = False,
         allocate_on_comm_stream: bool = False,
         comm_dtype: torch.dtype | None = None,
@@ -843,6 +852,7 @@
             pre_perm_idx,
             config.to_kernel_config(),
             getattr(previous_event, "event", None),
+            kernel_barrier,
             async_op,
             allocate_on_comm_stream,
             reduce_op,
@@ -873,6 +883,7 @@ def _internode_group_cast(
         is_token_in_rank: torch.Tensor | None = None,
         post_perm_idx: torch.Tensor | None = None,
         previous_event: EventOverlap | None = None,
+        kernel_barrier=None,
         async_op: bool = False,
         allocate_on_comm_stream: bool = False,
         cast_lse: bool = False,
@@ -975,6 +986,7 @@
             post_perm_idx,
             config.to_kernel_config(),
             getattr(previous_event, "event", None),
+            kernel_barrier,
             async_op,
             allocate_on_comm_stream,
         )
@@ -1023,6 +1035,7 @@ def _internode_group_reduce(
         acc_reduce: bool = False,
         pre_perm_idx: torch.Tensor | None = None,
         previous_event: EventOverlap | None = None,
+        kernel_barrier=None,
         async_op: bool = False,
         allocate_on_comm_stream: bool = False,
         comm_dtype: torch.dtype | None = None,
@@ -1078,6 +1091,7 @@
             pre_perm_idx,
             config.to_kernel_config(),
             getattr(previous_event, "event", None),
+            kernel_barrier,
             async_op,
             allocate_on_comm_stream,
             reduce_op,
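
The change is uniform across this file: `group_cast`, `group_reduce`, and their intranode/internode helpers gain an optional `kernel_barrier` argument that defaults to `None` and is forwarded down to the kernel launch. A hypothetical sketch of threading it through from caller code; apart from the `kernel_barrier` keyword itself, the input `x`, the return value, and the remaining keyword arguments are placeholders rather than the real signature:

```python
from magi_attention.comm.primitive.grpcoll._buffer import GrpCollBuffer


def group_cast_with_barrier(buffer: GrpCollBuffer, x, kernel_barrier=None, **meta_kwargs):
    # Forward the new optional keyword alongside whatever group-cast metadata the
    # caller already supplies (e.g. is_token_in_rank, post_perm_idx, config, ...).
    return buffer.group_cast(
        x,
        kernel_barrier=kernel_barrier,  # new in this commit; None means no barrier is passed
        **meta_kwargs,
    )
```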

magi_attention/comm/primitive/grpcoll/_config.py

Lines changed: 1 addition & 1 deletion
@@ -42,7 +42,7 @@ class GrpCollConfig:
     num_sms: int = 24
     nvl_chunk_size: int = 8
     nvl_buffer_size: int = 256
-    rdma_chunk_size: int = 4
+    rdma_chunk_size: int = 16
     rdma_buffer_size: int = 128
 
     # for buffer initialization
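
This mirrors the benchmark change above: the default RDMA chunk size moves from 4 to 16. A short sketch, assuming `GrpCollConfig` is a dataclass-style config (as the class-level defaults suggest) whose fields can be overridden at construction:

```python
from magi_attention.comm.primitive.grpcoll._config import GrpCollConfig

# An unmodified config now carries the new default.
cfg = GrpCollConfig()
assert cfg.rdma_chunk_size == 16

# The previous value can still be requested explicitly if a workload prefers smaller RDMA chunks.
legacy_cfg = GrpCollConfig(rdma_chunk_size=4)
assert legacy_cfg.rdma_chunk_size == 4
```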
