     maybe_finalize_async_save,
     save_checkpoint,
 )
+from megatron.bridge.training.utils.pg_utils import get_pg_collection
 from megatron.bridge.training.utils.train_utils import (
     logical_and_across_model_parallel_group,
     reduce_max_stat_across_model_parallel_group,
@@ -415,18 +416,20 @@ def train(
         else:
             update_successful, grad_norm, num_zeros_in_grad = (True, 0.0, 0.0)

+        pg_collection = get_pg_collection(self.model)
+
         # when freezing sub-models we may have a mixture of successful and unsucessful ranks,
         # so we must gather across mp ranks
         update_successful = logical_and_across_model_parallel_group(
-            update_successful
+            update_successful, mp_group=pg_collection.mp
         )
         # grad_norm and num_zeros_in_grad will be None on ranks without trainable params,
         # so we must gather across mp ranks
         grad_norm: float = reduce_max_stat_across_model_parallel_group(
-            grad_norm
+            grad_norm, mp_group=pg_collection.mp
         )
         num_zeros_in_grad: float = reduce_max_stat_across_model_parallel_group(
-            num_zeros_in_grad
+            num_zeros_in_grad, mp_group=pg_collection.mp
         )

         if update_successful:
@@ -1036,9 +1039,6 @@ def generate(
         ]
         enable_chunked_prefill = mcore_generation_config["enable_chunked_prefill"]
         unified_memory_level = mcore_generation_config["unified_memory_level"]
-        buffer_guaranteed_fraction = mcore_generation_config[
-            "buffer_guaranteed_fraction"
-        ]
         max_tokens = mcore_generation_config["max_tokens"]

         model_config = self.model.config
@@ -1050,7 +1050,6 @@ def generate(
             kv_channels=model_config.kv_channels,
             num_attention_heads=model_config.num_query_groups,
             max_sequence_length=self.cfg["generation"]["max_new_tokens"],
-            buffer_guaranteed_fraction=buffer_guaranteed_fraction,
             buffer_size_gb=buffer_size_gb,
             materialize_only_last_token_logits=False,
             num_cuda_graphs=num_cuda_graphs,
@@ -1061,7 +1060,7 @@ def generate(
             use_cuda_graphs_for_non_decode_steps=use_cuda_graphs_for_non_decode_steps,
             use_flashinfer_fused_rope=False,
             unified_memory_level=unified_memory_level,
-            max_tokens_override=max_tokens,
+            max_tokens=max_tokens,
         )
         inference_wrapped_model = GPTInferenceWrapper(
             self.model, inference_wrapper_config, dynamic_context
@@ -1134,23 +1133,27 @@ def generate(

         result = []
         while dynamic_engine.has_unfinished_requests():
-            result_step = dynamic_engine.step_modern(verbose=False)
-            finished_requests = result_step.get("finished_requests", [])
-            for finished_request in finished_requests:
-                result.append(finished_request)
+            result_step = dynamic_engine.step_modern()
+            result.extend(result_step["finished_request_records"])

         # Sort results by request_id to maintain original batch order
         result.sort(key=lambda x: x.request_id)

         out = {
-            "tokens": [x.prompt_tokens.tolist() + x.generated_tokens for x in result],
-            "logprobs": [x.prompt_log_probs + x.generated_log_probs for x in result],
+            "tokens": [
+                x.requests[0].prompt_tokens.tolist() + x.requests[0].generated_tokens
+                for x in result
+            ],
+            "logprobs": [
+                x.requests[0].prompt_log_probs + x.requests[0].generated_log_probs
+                for x in result
+            ],
         }

         input_lengths = data["input_lengths"]
         # pad the out "tokens" and "logprobs" and make them into tensors from lists
         batch_size = data["input_ids"].size(0)
-        max_gen_seq_len = max([len(x.generated_tokens) for x in result])
+        max_gen_seq_len = max([len(x.requests[0].generated_tokens) for x in result])
         padded_input_length = input_ids.size(1)

         max_seq_len = padded_input_length + max_gen_seq_len
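For context on the new result shape, here is a minimal sketch of how the padding step that follows this hunk might consume `out["tokens"]` and `out["logprobs"]` once `max_seq_len` is known. The function name `pad_outputs` and the `pad_token_id` / zero-logprob fillers are hypothetical illustration, not part of the PR.

```python
# Illustrative sketch only (not part of the diff above): right-pad each
# per-request token/logprob list to max_seq_len and stack into
# [batch_size, max_seq_len] tensors.
import torch


def pad_outputs(out: dict, max_seq_len: int, pad_token_id: int = 0):
    batch_size = len(out["tokens"])
    tokens = torch.full((batch_size, max_seq_len), pad_token_id, dtype=torch.long)
    logprobs = torch.zeros(batch_size, max_seq_len, dtype=torch.float)
    for i, (toks, lps) in enumerate(zip(out["tokens"], out["logprobs"])):
        tokens[i, : len(toks)] = torch.tensor(toks, dtype=torch.long)
        logprobs[i, : len(lps)] = torch.tensor(lps, dtype=torch.float)
    return tokens, logprobs
```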