@@ -362,33 +362,16 @@ def _fetch_new_requests_attention_dp(
362362 # Get active request counts across all ranks.
363363 all_ranks_num_active_requests = []
364364 all_ranks_num_active_tokens = []
365- num_active_tokens = sum (
366- [req .py_orig_prompt_len for req in activate_requests ])
367365
368366 if self .dist .has_cp_helix :
369- # When CP is enabled with Helix parallelism, we need to gather from all ranks
370- # in the TP x CP space. CP ranks within the same DP group (same tp_rank) handle
371- # the same requests with different token portions (sequence is split across CP ranks).
372- responses_list = self .dist .tp_cp_allgather (
373- [len (activate_requests ), num_active_tokens ])
374-
375- aggregated_responses = []
376- for dp_group_idx in range (self .dist .tp_size ):
377- # Get all entries for this DP group (cp_size entries per group).
378- group_start = dp_group_idx * self .dist .cp_size
379- group_end = (dp_group_idx + 1 ) * self .dist .cp_size
380- group_entries = responses_list [group_start :group_end ]
381-
382- # All CP ranks within a DP group should have the same number of requests.
383- assert all (entry [0 ] == group_entries [0 ][0 ] for entry in group_entries ), \
384- f"CP ranks within DP group { dp_group_idx } have mismatched request counts: " \
385- f"{ [entry [0 ] for entry in group_entries ]} "
386- # Use token count from cp_rank0.
387- aggregated_responses .append (group_entries [0 ])
388- responses_list = aggregated_responses
367+ num_active_tokens = sum (
368+ [req .total_input_len_cp for req in activate_requests ])
389369 else :
390- responses_list = self .dist .tp_allgather (
391- [len (activate_requests ), num_active_tokens ])
370+ num_active_tokens = sum (
371+ [req .py_orig_prompt_len for req in activate_requests ])
372+
373+ responses_list = self .dist .tp_allgather (
374+ [len (activate_requests ), num_active_tokens ])
392375
393376 for num_active_requests , num_active_tokens in responses_list :
394377 all_ranks_num_active_requests .append (num_active_requests )
0 commit comments