5252from megatron .core .utils import (
5353 check_param_hashes_across_dp_replicas ,
5454 get_model_config ,
55+ get_pg_size ,
56+ get_pg_rank ,
5557 StragglerDetector ,
5658)
5759from megatron .core .fp8_utils import correct_amax_history_if_needed
60+ from megatron .core .process_groups_config import ProcessGroupCollection
61+ from megatron .core .pipeline_parallel .utils import (
62+ is_pp_first_stage ,
63+ is_pp_last_stage ,
64+ is_vp_first_stage ,
65+ is_vp_last_stage ,
66+ )
5867from megatron .training .checkpointing import load_checkpoint
5968from megatron .training .checkpointing import save_checkpoint
6069from megatron .training .checkpointing import checkpoint_exists
@@ -873,10 +882,12 @@ def update_train_iters(args):
873882 print_rank_0 (f'setting training iterations to { args .train_iters } ' )
874883
875884
876- def get_model (model_provider_func , model_type = ModelType .encoder_or_decoder , wrap_with_ddp = True ):
885+ def get_model (model_provider_func , model_type = ModelType .encoder_or_decoder , wrap_with_ddp = True , config = None , pg_collection = None ):
877886 """Build the model."""
878887 args = get_args ()
879888 args .model_type = model_type
889+ if pg_collection is None :
890+ pg_collection = ProcessGroupCollection .use_mpu_process_groups ()
880891
881892 if has_nvidia_modelopt :
882893 from megatron .post_training .checkpointing import has_modelopt_state
@@ -893,23 +904,38 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
893904 # Build model.
def build_model():
    """Instantiate the model for this rank.

    Returns a single model chunk, or — when virtual pipeline parallelism is
    enabled — a list with one chunk per virtual pipeline stage. Each chunk is
    tagged with ``model_type`` (and ``vp_stage`` in the virtual case) so later
    wrapping/checkpointing code can identify it.
    """
    vp_size = args.virtual_pipeline_model_parallel_size
    virtual_stages_enabled = get_pg_size(pg_collection.pp) > 1 and vp_size is not None

    if not virtual_stages_enabled:
        # No interleaving: one chunk per pipeline rank; pre/post processing is
        # decided purely by this rank's position in the pipeline group.
        chunk = model_provider_func(
            pre_process=is_pp_first_stage(pg_collection.pp),
            post_process=is_pp_last_stage(pg_collection.pp),
            config=config,
            pg_collection=pg_collection,
        )
        chunk.model_type = model_type
        return chunk

    # Interleaved schedule: build one chunk per virtual stage. A chunk does
    # pre/post processing only when it is first/last in BOTH the pipeline
    # group and the virtual-stage ordering.
    chunks = []
    for vp_rank in range(vp_size):
        # Set pre_process and post_process only after the virtual rank is set.
        first = is_pp_first_stage(pg_collection.pp) and is_vp_first_stage(
            vp_stage=vp_rank, vp_size=vp_size
        )
        last = is_pp_last_stage(pg_collection.pp) and is_vp_last_stage(
            vp_stage=vp_rank, vp_size=vp_size
        )
        chunk = model_provider_func(
            pre_process=first,
            post_process=last,
            vp_stage=vp_rank,
            config=config,
            pg_collection=pg_collection,
        )
        chunk.model_type = model_type
        chunk.vp_stage = vp_rank
        chunks.append(chunk)
    return chunks
915941
@@ -934,12 +960,12 @@ def build_model():
934960 num_parameters = sum (
935961 [sum ([p .nelement () for p in model_module .parameters ()]) for model_module in model ]
936962 )
937- if mpu . get_data_parallel_rank ( ) == 0 and mpu . get_context_parallel_rank ( ) == 0 :
963+ if get_pg_rank ( pg_collection . dp ) == 0 and get_pg_rank ( pg_collection . cp ) == 0 :
938964 print (
939965 ' > number of parameters on (tensor, pipeline) '
940966 'model parallel rank ({}, {}): {}' .format (
941- mpu . get_tensor_model_parallel_rank ( ),
942- mpu . get_pipeline_model_parallel_rank ( ),
967+ get_pg_rank ( pg_collection . tp ),
968+ get_pg_rank ( pg_collection . pp ),
943969 num_parameters ,
944970 ),
945971 flush = True ,
0 commit comments