File tree Expand file tree Collapse file tree 2 files changed +3
-6
lines changed
Expand file tree Collapse file tree 2 files changed +3
-6
lines changed Original file line number Diff line number Diff line change 1- import copy
21import math
32import pickle # nosec B403
43from abc import ABC , abstractmethod
Original file line number Diff line number Diff line change @@ -372,8 +372,6 @@ def _fetch_new_requests_attention_dp(
372372 responses_list = self .dist .tp_cp_allgather (
373373 [len (activate_requests ), num_active_tokens ])
374374
375- # @B: Do we really need to check for all CP ranks? Should num_tokens be 1
376- # for all generation requests?
377375 aggregated_responses = []
378376 for dp_group_idx in range (self .dist .tp_size ):
379377 # Get all entries for this DP group (cp_size entries per group).
@@ -385,9 +383,9 @@ def _fetch_new_requests_attention_dp(
385383 assert all (entry [0 ] == group_entries [0 ][0 ] for entry in group_entries ), \
386384 f"CP ranks within DP group { dp_group_idx } have mismatched request counts: " \
387385 f"{ [entry [0 ] for entry in group_entries ]} "
388- # Sum the token counts across CP ranks (sequence is split) .
389- total_tokens = sum ( entry [ 1 ] for entry in group_entries )
390- aggregated_responses . append ( [group_entries [0 ][0 ], total_tokens ])
386+ # Use token count from cp_rank0 .
387+ aggregated_responses . append (
388+ [group_entries [0 ][0 ], group_entries [ 0 ][ 1 ] ])
391389 responses_list = aggregated_responses
392390 else :
393391 responses_list = self .dist .tp_allgather (
You can’t perform that action at this time.
0 commit comments