@@ -1413,8 +1413,10 @@ def _fetch_remote_kv(
14131413 * num_tensors ,
14141414 input = local_kv if self .concat_kv else local_kv [0 ],
14151415 )
1416- internode_output_seqlen : int = group_cast_args .get (
1417- "internode_output_seqlen" , - 1
1416+ internode_output_seqlen : int = group_cast_args .get ("internode_output_seqlen" , 0 )
1417+ group_cast_kv_rdma_bytes = self .compute_group_comm_bytes (
1418+ comm_tokens = internode_output_seqlen * num_tensors ,
1419+ input = local_kv if self .concat_kv else local_kv [0 ],
14181420 )
14191421
14201422 with nvtx .add_nvtx_event (
@@ -1426,7 +1428,7 @@ def _fetch_remote_kv(
14261428 f"{ output_kv_shape = } | "
14271429 f"{ output_kv_dtype = } |"
14281430 f"{ num_tensors = } | "
1429- f"{ internode_output_seqlen = } "
1431+ f"{ group_cast_kv_rdma_bytes = } "
14301432 )
14311433 ):
14321434 # launch group cast kernel
@@ -1498,8 +1500,10 @@ def _fetch_remote_q(
14981500 ].group_cast_comm_tokens ,
14991501 input = local_q ,
15001502 )
1501- internode_output_seqlen : int = group_cast_args .get (
1502- "internode_output_seqlen" , - 1
1503+ internode_output_seqlen : int = group_cast_args .get ("internode_output_seqlen" , 0 )
1504+ group_cast_q_rdma_bytes = self .compute_group_comm_bytes (
1505+ comm_tokens = internode_output_seqlen ,
1506+ input = local_q ,
15031507 )
15041508
15051509 with nvtx .add_nvtx_event (
@@ -1511,7 +1515,7 @@ def _fetch_remote_q(
15111515 f"output_q.shape={ remote_q_buffer .shape } | "
15121516 f"output_q.dtype={ remote_q_buffer .dtype } | "
15131517 f"num_tensors=1 | "
1514- f"{ internode_output_seqlen = } "
1518+ f"{ group_cast_q_rdma_bytes = } "
15151519 )
15161520 ):
15171521 # launch group cast kernel
@@ -1627,8 +1631,20 @@ def _fetch_remote_qo_do_lse(
16271631 )
16281632
16291633 group_cast_qo_do_lse_bytes = group_cast_qo_do_bytes + group_cast_lse_bytes
1634+
16301635 internode_output_seqlen : int = group_cast_args .get (
1631- "internode_output_seqlen" , - 1
1636+ "internode_output_seqlen" , 0
1637+ )
1638+ group_cast_qo_do_rdma_bytes = self .compute_group_comm_bytes (
1639+ comm_tokens = internode_output_seqlen * 3 ,
1640+ input = local_qo_do [0 ],
1641+ )
1642+ group_cast_lse_rdma_bytes = self .compute_group_comm_bytes (
1643+ comm_tokens = internode_output_seqlen ,
1644+ input = local_lse ,
1645+ )
1646+ group_cast_qo_do_lse_rdma_bytes = (
1647+ group_cast_qo_do_rdma_bytes + group_cast_lse_rdma_bytes
16321648 )
16331649
16341650 with nvtx .add_nvtx_event (
@@ -1645,7 +1661,7 @@ def _fetch_remote_qo_do_lse(
16451661 f"output_lse_shape={ remote_lse_buffer .shape } | "
16461662 f"output_lse_dtype={ remote_lse_buffer .dtype } | "
16471663 f"num_tensors_lse=1 | "
1648- f"{ internode_output_seqlen = } "
1664+ f"{ group_cast_qo_do_lse_rdma_bytes = } "
16491665 )
16501666 ):
16511667 # launch group cast kernel
@@ -1696,9 +1712,6 @@ def _fetch_remote_qo_do_lse(
16961712 ].group_cast_comm_tokens ,
16971713 lse = local_lse ,
16981714 )
1699- internode_output_seqlen_lse : int = group_cast_args_lse .get (
1700- "internode_output_seqlen" , - 1
1701- )
17021715
17031716 with nvtx .add_nvtx_event (
17041717 (
@@ -1708,8 +1721,7 @@ def _fetch_remote_qo_do_lse(
17081721 f"input_lse.dtype={ local_lse .dtype } | "
17091722 f"output_lse.shape={ remote_lse_buffer .shape } | "
17101723 f"output_lse.dtype={ remote_lse_buffer .dtype } | "
1711- f"num_tensors=1 | "
1712- f"{ internode_output_seqlen_lse = } "
1724+ f"num_tensors=1"
17131725 )
17141726 ):
17151727 # launch group cast kernel for lse
@@ -1746,9 +1758,6 @@ def _fetch_remote_qo_do_lse(
17461758 ].group_cast_comm_tokens ,
17471759 input = local_qo_do ,
17481760 )
1749- internode_output_seqlen_qo_do : int = group_cast_args_qo_do .get (
1750- "internode_output_seqlen" , - 1
1751- )
17521761
17531762 with nvtx .add_nvtx_event (
17541763 (
@@ -1758,8 +1767,7 @@ def _fetch_remote_qo_do_lse(
17581767 f"input_qo_do.dtype={ local_qo_do .dtype } | "
17591768 f"output_qo_do.shape={ remote_qo_do_buffer .shape } | " # type: ignore
17601769 f"output_qo_do.dtype={ remote_qo_do_buffer .dtype } | " # type: ignore
1761- f"num_tensors=1 | "
1762- f"{ internode_output_seqlen_qo_do = } "
1770+ f"num_tensors=1"
17631771 )
17641772 ):
17651773 # launch group cast kernel for qo_do
@@ -1876,7 +1884,12 @@ def _reduce_partial_out_lse(
18761884 lse = partial_remote_lse ,
18771885 )
18781886 internode_output_seqlen : int = group_reduce_args .get (
1879- "internode_output_seqlen" , - 1
1887+ "internode_output_seqlen" , 0
1888+ )
1889+ group_cast_out_lse_rdma_bytes = self .compute_group_comm_bytes (
1890+ comm_tokens = internode_output_seqlen ,
1891+ input = partial_remote_out ,
1892+ lse = partial_remote_lse ,
18801893 )
18811894
18821895 with nvtx .add_nvtx_event (
@@ -1891,7 +1904,7 @@ def _reduce_partial_out_lse(
18911904 f"input_lse.dtype={ partial_remote_lse .dtype } | "
18921905 f"output_lse.shape={ partial_local_lse .shape } | "
18931906 f"output_lse.dtype={ partial_local_lse .dtype } | "
1894- f"{ internode_output_seqlen = } "
1907+ f"{ group_cast_out_lse_rdma_bytes = } "
18951908 )
18961909 ):
18971910 # launch group-reduce kernel
@@ -2022,7 +2035,11 @@ def _reduce_partial_dkv(
20222035 input = partial_remote_dkv if self .concat_dkv else partial_remote_dkv [0 ], # type: ignore
20232036 )
20242037 internode_output_seqlen : int = group_reduce_args .get (
2025- "internode_output_seqlen" , - 1
2038+ "internode_output_seqlen" , 0
2039+ )
2040+ group_cast_dkv_rdma_bytes = self .compute_group_comm_bytes (
2041+ comm_tokens = internode_output_seqlen * num_tensors_of_dkv ,
2042+ input = partial_remote_dkv if self .concat_dkv else partial_remote_dkv [0 ], # type: ignore
20262043 )
20272044 with nvtx .add_nvtx_event (
20282045 (
@@ -2033,7 +2050,7 @@ def _reduce_partial_dkv(
20332050 f"{ output_dkv_shape = } | "
20342051 f"{ output_dkv_dtype = } | "
20352052 f"{ num_tensors_of_dkv = } | "
2036- f"{ internode_output_seqlen = } "
2053+ f"{ group_cast_dkv_rdma_bytes = } "
20372054 )
20382055 ):
20392056 # launch group-reduce kernel
@@ -2116,7 +2133,11 @@ def _reduce_partial_dq(
21162133 input = partial_remote_dq ,
21172134 )
21182135 internode_output_seqlen : int = group_reduce_args .get (
2119- "internode_output_seqlen" , - 1
2136+ "internode_output_seqlen" , 0
2137+ )
2138+ group_cast_dq_rdma_bytes = self .compute_group_comm_bytes (
2139+ comm_tokens = internode_output_seqlen ,
2140+ input = partial_remote_dq ,
21202141 )
21212142
21222143 with nvtx .add_nvtx_event (
@@ -2128,7 +2149,7 @@ def _reduce_partial_dq(
21282149 f"output_dq.shape={ partial_local_dq .shape } | "
21292150 f"output_dq.dtype={ partial_local_dq .dtype } | "
21302151 f"tensors_num_of_dq=1 | "
2131- f"{ internode_output_seqlen = } "
2152+ f"{ group_cast_dq_rdma_bytes = } "
21322153 )
21332154 ):
21342155 # launch group-reduce kernel
0 commit comments