Skip to content

Commit d83575a

Browse files
committed
add rdma_comm_bytes log
1 parent df64896 commit d83575a

File tree

1 file changed

+45
-24
lines changed

1 file changed

+45
-24
lines changed

magi_attention/functional/dist_attn.py

Lines changed: 45 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -1413,8 +1413,10 @@ def _fetch_remote_kv(
14131413
* num_tensors,
14141414
input=local_kv if self.concat_kv else local_kv[0],
14151415
)
1416-
internode_output_seqlen: int = group_cast_args.get(
1417-
"internode_output_seqlen", -1
1416+
internode_output_seqlen: int = group_cast_args.get("internode_output_seqlen", 0)
1417+
group_cast_kv_rdma_bytes = self.compute_group_comm_bytes(
1418+
comm_tokens=internode_output_seqlen * num_tensors,
1419+
input=local_kv if self.concat_kv else local_kv[0],
14181420
)
14191421

14201422
with nvtx.add_nvtx_event(
@@ -1426,7 +1428,7 @@ def _fetch_remote_kv(
14261428
f"{output_kv_shape=} | "
14271429
f"{output_kv_dtype=} |"
14281430
f"{num_tensors=} | "
1429-
f"{internode_output_seqlen=}"
1431+
f"{group_cast_kv_rdma_bytes=}"
14301432
)
14311433
):
14321434
# launch group cast kernel
@@ -1498,8 +1500,10 @@ def _fetch_remote_q(
14981500
].group_cast_comm_tokens,
14991501
input=local_q,
15001502
)
1501-
internode_output_seqlen: int = group_cast_args.get(
1502-
"internode_output_seqlen", -1
1503+
internode_output_seqlen: int = group_cast_args.get("internode_output_seqlen", 0)
1504+
group_cast_q_rdma_bytes = self.compute_group_comm_bytes(
1505+
comm_tokens=internode_output_seqlen,
1506+
input=local_q,
15031507
)
15041508

15051509
with nvtx.add_nvtx_event(
@@ -1511,7 +1515,7 @@ def _fetch_remote_q(
15111515
f"output_q.shape={remote_q_buffer.shape} | "
15121516
f"output_q.dtype={remote_q_buffer.dtype} | "
15131517
f"num_tensors=1 | "
1514-
f"{internode_output_seqlen=}"
1518+
f"{group_cast_q_rdma_bytes=}"
15151519
)
15161520
):
15171521
# launch group cast kernel
@@ -1627,8 +1631,20 @@ def _fetch_remote_qo_do_lse(
16271631
)
16281632

16291633
group_cast_qo_do_lse_bytes = group_cast_qo_do_bytes + group_cast_lse_bytes
1634+
16301635
internode_output_seqlen: int = group_cast_args.get(
1631-
"internode_output_seqlen", -1
1636+
"internode_output_seqlen", 0
1637+
)
1638+
group_cast_qo_do_rdma_bytes = self.compute_group_comm_bytes(
1639+
comm_tokens=internode_output_seqlen * 3,
1640+
input=local_qo_do[0],
1641+
)
1642+
group_cast_lse_rdma_bytes = self.compute_group_comm_bytes(
1643+
comm_tokens=internode_output_seqlen,
1644+
input=local_lse,
1645+
)
1646+
group_cast_qo_do_lse_rdma_bytes = (
1647+
group_cast_qo_do_rdma_bytes + group_cast_lse_rdma_bytes
16321648
)
16331649

16341650
with nvtx.add_nvtx_event(
@@ -1645,7 +1661,7 @@ def _fetch_remote_qo_do_lse(
16451661
f"output_lse_shape={remote_lse_buffer.shape} | "
16461662
f"output_lse_dtype={remote_lse_buffer.dtype} | "
16471663
f"num_tensors_lse=1 | "
1648-
f"{internode_output_seqlen=}"
1664+
f"{group_cast_qo_do_lse_rdma_bytes=}"
16491665
)
16501666
):
16511667
# launch group cast kernel
@@ -1696,9 +1712,6 @@ def _fetch_remote_qo_do_lse(
16961712
].group_cast_comm_tokens,
16971713
lse=local_lse,
16981714
)
1699-
internode_output_seqlen_lse: int = group_cast_args_lse.get(
1700-
"internode_output_seqlen", -1
1701-
)
17021715

17031716
with nvtx.add_nvtx_event(
17041717
(
@@ -1708,8 +1721,7 @@ def _fetch_remote_qo_do_lse(
17081721
f"input_lse.dtype={local_lse.dtype} | "
17091722
f"output_lse.shape={remote_lse_buffer.shape} | "
17101723
f"output_lse.dtype={remote_lse_buffer.dtype} | "
1711-
f"num_tensors=1 | "
1712-
f"{internode_output_seqlen_lse=}"
1724+
f"num_tensors=1"
17131725
)
17141726
):
17151727
# launch group cast kernel for lse
@@ -1746,9 +1758,6 @@ def _fetch_remote_qo_do_lse(
17461758
].group_cast_comm_tokens,
17471759
input=local_qo_do,
17481760
)
1749-
internode_output_seqlen_qo_do: int = group_cast_args_qo_do.get(
1750-
"internode_output_seqlen", -1
1751-
)
17521761

17531762
with nvtx.add_nvtx_event(
17541763
(
@@ -1758,8 +1767,7 @@ def _fetch_remote_qo_do_lse(
17581767
f"input_qo_do.dtype={local_qo_do.dtype} | "
17591768
f"output_qo_do.shape={remote_qo_do_buffer.shape} | " # type: ignore
17601769
f"output_qo_do.dtype={remote_qo_do_buffer.dtype} | " # type: ignore
1761-
f"num_tensors=1 | "
1762-
f"{internode_output_seqlen_qo_do=}"
1770+
f"num_tensors=1"
17631771
)
17641772
):
17651773
# launch group cast kernel for qo_do
@@ -1876,7 +1884,12 @@ def _reduce_partial_out_lse(
18761884
lse=partial_remote_lse,
18771885
)
18781886
internode_output_seqlen: int = group_reduce_args.get(
1879-
"internode_output_seqlen", -1
1887+
"internode_output_seqlen", 0
1888+
)
1889+
group_cast_out_lse_rdma_bytes = self.compute_group_comm_bytes(
1890+
comm_tokens=internode_output_seqlen,
1891+
input=partial_remote_out,
1892+
lse=partial_remote_lse,
18801893
)
18811894

18821895
with nvtx.add_nvtx_event(
@@ -1891,7 +1904,7 @@ def _reduce_partial_out_lse(
18911904
f"input_lse.dtype={partial_remote_lse.dtype} | "
18921905
f"output_lse.shape={partial_local_lse.shape} | "
18931906
f"output_lse.dtype={partial_local_lse.dtype} | "
1894-
f"{internode_output_seqlen=}"
1907+
f"{group_cast_out_lse_rdma_bytes=}"
18951908
)
18961909
):
18971910
# launch group-reduce kernel
@@ -2022,7 +2035,11 @@ def _reduce_partial_dkv(
20222035
input=partial_remote_dkv if self.concat_dkv else partial_remote_dkv[0], # type: ignore
20232036
)
20242037
internode_output_seqlen: int = group_reduce_args.get(
2025-
"internode_output_seqlen", -1
2038+
"internode_output_seqlen", 0
2039+
)
2040+
group_cast_dkv_rdma_bytes = self.compute_group_comm_bytes(
2041+
comm_tokens=internode_output_seqlen * num_tensors_of_dkv,
2042+
input=partial_remote_dkv if self.concat_dkv else partial_remote_dkv[0], # type: ignore
20262043
)
20272044
with nvtx.add_nvtx_event(
20282045
(
@@ -2033,7 +2050,7 @@ def _reduce_partial_dkv(
20332050
f"{output_dkv_shape=} | "
20342051
f"{output_dkv_dtype=} | "
20352052
f"{num_tensors_of_dkv=} | "
2036-
f"{internode_output_seqlen=}"
2053+
f"{group_cast_dkv_rdma_bytes=}"
20372054
)
20382055
):
20392056
# launch group-reduce kernel
@@ -2116,7 +2133,11 @@ def _reduce_partial_dq(
21162133
input=partial_remote_dq,
21172134
)
21182135
internode_output_seqlen: int = group_reduce_args.get(
2119-
"internode_output_seqlen", -1
2136+
"internode_output_seqlen", 0
2137+
)
2138+
group_cast_dq_rdma_bytes = self.compute_group_comm_bytes(
2139+
comm_tokens=internode_output_seqlen,
2140+
input=partial_remote_dq,
21202141
)
21212142

21222143
with nvtx.add_nvtx_event(
@@ -2128,7 +2149,7 @@ def _reduce_partial_dq(
21282149
f"output_dq.shape={partial_local_dq.shape} | "
21292150
f"output_dq.dtype={partial_local_dq.dtype} | "
21302151
f"tensors_num_of_dq=1 | "
2131-
f"{internode_output_seqlen=}"
2152+
f"{group_cast_dq_rdma_bytes=}"
21322153
)
21332154
):
21342155
# launch group-reduce kernel

0 commit comments

Comments
 (0)