@@ -955,124 +955,119 @@ def build_model():
         model.model_type = model_type
         return model
 
-    # Setup stream for model building/ddp initialization. The side-stream may be necessary for
-    # cuda graph capture support with DDP, but we sync it with the current stream to avoid race
-    # conditions.
-    setup_stream = torch.cuda.Stream()
-    # Wait for the default stream to complete before starting setup_stream
-    setup_stream.wait_stream(torch.cuda.current_stream())
-    # Make setup_stream start after whatever the default stream already queued
-    with torch.cuda.stream(setup_stream):
-        if args.init_model_with_meta_device:
-            with torch.device('meta'):
-                model = build_model()
-        else:
+
+    if args.init_model_with_meta_device:
+        with torch.device('meta'):
             model = build_model()
+    else:
+        model = build_model()
 
-        if not isinstance(model, list):
-            model = [model]
+    if not isinstance(model, list):
+        model = [model]
 
-        # Set tensor model parallel attributes if not set.
-        # Only parameters that are already tensor model parallel have these
-        # attributes set for them. We should make sure the default attributes
-        # are set for all params so the optimizer can use them.
-        for model_module in model:
-            for param in model_module.parameters():
-                tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param)
+    # Set tensor model parallel attributes if not set.
+    # Only parameters that are already tensor model parallel have these
+    # attributes set for them. We should make sure the default attributes
+    # are set for all params so the optimizer can use them.
+    for model_module in model:
+        for param in model_module.parameters():
+            tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param)
 
-        # Print number of parameters.
-        num_parameters = sum(
-            [sum([p.nelement() for p in model_module.parameters()]) for model_module in model]
+    # Print number of parameters.
+    num_parameters = sum(
+        [sum([p.nelement() for p in model_module.parameters()]) for model_module in model]
+    )
+    if get_pg_rank(pg_collection.dp) == 0 and get_pg_rank(pg_collection.cp) == 0:
+        print(
+            ' > number of parameters on (tensor, pipeline) '
+            'model parallel rank ({}, {}): {}'.format(
+                get_pg_rank(pg_collection.tp),
+                get_pg_rank(pg_collection.pp),
+                num_parameters,
+            ),
+            flush=True,
         )
-        if get_pg_rank(pg_collection.dp) == 0 and get_pg_rank(pg_collection.cp) == 0:
-            print(
-                ' > number of parameters on (tensor, pipeline) '
-                'model parallel rank ({}, {}): {}'.format(
-                    get_pg_rank(pg_collection.tp),
-                    get_pg_rank(pg_collection.pp),
-                    num_parameters,
-                ),
-                flush=True,
-            )
-
-        # GPU allocation.
-        # For FSDP2, we don't allocate GPU memory here. We allocate GPU memory
-        # in the fully_shard function of FSDP2 instead.
-        if (
-            not (args.use_torch_fsdp2 and args.use_cpu_initialization)
-            and not args.init_model_with_meta_device
-        ):
-            for model_module in model:
-                model_module.cuda(torch.cuda.current_device())
-
-        # Fp16 conversion.
-        if args.fp16 or args.bf16:
-            config = get_model_config(model[0])
-            model = [Float16Module(config, model_module) for model_module in model]
-
-        # Materialize tensors on meta device (GPU allocation) if not using FSDP2 and not using Megatron FSDP.
-        if args.init_model_with_meta_device and not args.use_torch_fsdp2 and not args.use_megatron_fsdp:
-            #for model_module in model:
-            model = [to_empty_if_meta_device(model_module, device=torch.device("cuda")) for model_module in model]
-
 
+    # GPU allocation.
+    # For FSDP2, we don't allocate GPU memory here. We allocate GPU memory
+    # in the fully_shard function of FSDP2 instead.
+    if (
+        not (args.use_torch_fsdp2 and args.use_cpu_initialization)
+        and not args.init_model_with_meta_device
+    ):
+        for model_module in model:
+            model_module.cuda(torch.cuda.current_device())
+
+    # Fp16 conversion.
+    if args.fp16 or args.bf16:
+        config = get_model_config(model[0])
+        model = [Float16Module(config, model_module) for model_module in model]
+
+    # Materialize tensors on meta device (GPU allocation) if not using FSDP2 and not using Megatron FSDP.
+    if args.init_model_with_meta_device and not args.use_torch_fsdp2 and not args.use_megatron_fsdp:
+        model = [to_empty_if_meta_device(model_module, device=torch.device("cuda")) for model_module in model]
+
+    # Before TE2.x: The model_module.bfloat16()/model_module.half() above will call the inplace
+    # copy of TE's Float8Tensor, which will write an unwanted value (amax calculated
+    # from the current fp8 param) to its amax_history. The below function will correct
+    # the amax_history back.
+    # After TE2.x: Below function is an empty function and does nothing.
+    correct_amax_history_if_needed(model)
+
+    if wrap_with_ddp:
+        if args.use_torch_fsdp2:
+            assert HAVE_FSDP2, "Torch FSDP2 requires torch>=2.4.0"
+            DP = torch_FSDP
+        elif args.use_megatron_fsdp:
+            DP = megatron_FSDP
+        else:
+            DP = DDP
 
+        config = get_model_config(model[0])
 
-        # Before TE2.x: The model_module.bfloat16()/model_module.half() above will call the inplace
-        # copy of TE's Float8Tensor, which will write an unwanted value (amax calculated
-        # from the current fp8 param) to its amax_history. The below function will correct
-        # the amax_history back.
-        # After TE2.x: Below function is an empty function and does nothing.
-        correct_amax_history_if_needed(model)
-
-        if wrap_with_ddp:
-            if args.use_torch_fsdp2:
-                assert HAVE_FSDP2, "Torch FSDP2 requires torch>=2.4.0"
-                DP = torch_FSDP
-            elif args.use_megatron_fsdp:
-                DP = megatron_FSDP
-            else:
-                DP = DDP
-
-            config = get_model_config(model[0])
-
-            if getattr(args, "use_torch_fsdp2", False):
-                reshard_after_forward = getattr(args, "torch_fsdp2_reshard_after_forward", True)
-                ddp_config = TorchFullyShardedDataParallelConfig(reshard_after_forward=reshard_after_forward)
+        if getattr(args, "use_torch_fsdp2", False):
+            reshard_after_forward = getattr(args, "torch_fsdp2_reshard_after_forward", True)
+            ddp_config = TorchFullyShardedDataParallelConfig(reshard_after_forward=reshard_after_forward)
+        else:
+            kwargs = {}
+            for f in dataclasses.fields(DistributedDataParallelConfig):
+                if hasattr(args, f.name):
+                    kwargs[f.name] = getattr(args, f.name)
+            kwargs['grad_reduce_in_fp32'] = args.accumulate_allreduce_grads_in_fp32
+            kwargs['check_for_nan_in_grad'] = args.check_for_nan_in_loss_and_grad
+            kwargs['check_for_large_grads'] = args.check_for_large_grads
+            if args.ddp_num_buckets is not None:
+                assert args.ddp_bucket_size is None, \
+                    "Cannot specify both --ddp-num-buckets and --ddp-bucket-size"
+                assert args.ddp_num_buckets > 0, \
+                    "--ddp-num-buckets must be greater than 0"
+                kwargs['bucket_size'] = num_parameters // args.ddp_num_buckets
             else:
-                kwargs = {}
-                for f in dataclasses.fields(DistributedDataParallelConfig):
-                    if hasattr(args, f.name):
-                        kwargs[f.name] = getattr(args, f.name)
-                kwargs['grad_reduce_in_fp32'] = args.accumulate_allreduce_grads_in_fp32
-                kwargs['check_for_nan_in_grad'] = args.check_for_nan_in_loss_and_grad
-                kwargs['check_for_large_grads'] = args.check_for_large_grads
-                if args.ddp_num_buckets is not None:
-                    assert args.ddp_bucket_size is None, \
-                        "Cannot specify both --ddp-num-buckets and --ddp-bucket-size"
-                    assert args.ddp_num_buckets > 0, \
-                        "--ddp-num-buckets must be greater than 0"
-                    kwargs['bucket_size'] = num_parameters // args.ddp_num_buckets
-                else:
-                    kwargs['bucket_size'] = args.ddp_bucket_size
-                kwargs['pad_buckets_for_high_nccl_busbw'] = args.ddp_pad_buckets_for_high_nccl_busbw
-                kwargs['reduce_scatter_with_fp32_accumulation'] = args.ddp_reduce_scatter_with_fp32_accumulation
-                kwargs['average_in_collective'] = args.ddp_average_in_collective
-                ddp_config = DistributedDataParallelConfig(**kwargs)
-
-            # In the Megatron FSDP and DDP use path, we need to initialize the bucket size.
-            # If bucket_size is not provided as an input, use sane default.
-            # If using very large dp_sizes, make buckets larger to ensure that chunks used in NCCL
-            # ring-reduce implementations are large enough to remain bandwidth-bound rather than
-            # latency-bound.
-            if ddp_config.bucket_size is None:
-                ddp_config.bucket_size = max(
-                    40000000, 1000000 * mpu.get_data_parallel_world_size(with_context_parallel=True)
-                )
-            # Set bucket_size to infinity if overlap_grad_reduce is False.
-            if not ddp_config.overlap_grad_reduce:
-                ddp_config.bucket_size = None
-
+                kwargs['bucket_size'] = args.ddp_bucket_size
+            kwargs['pad_buckets_for_high_nccl_busbw'] = args.ddp_pad_buckets_for_high_nccl_busbw
+            kwargs['reduce_scatter_with_fp32_accumulation'] = args.ddp_reduce_scatter_with_fp32_accumulation
+            kwargs['average_in_collective'] = args.ddp_average_in_collective
+            ddp_config = DistributedDataParallelConfig(**kwargs)
+
+        # In the Megatron FSDP and DDP use path, we need to initialize the bucket size.
+        # If bucket_size is not provided as an input, use sane default.
+        # If using very large dp_sizes, make buckets larger to ensure that chunks used in NCCL
+        # ring-reduce implementations are large enough to remain bandwidth-bound rather than
+        # latency-bound.
+        if ddp_config.bucket_size is None:
+            ddp_config.bucket_size = max(
+                40000000, 1000000 * mpu.get_data_parallel_world_size(with_context_parallel=True)
+            )
+        # Set bucket_size to infinity if overlap_grad_reduce is False.
+        if not ddp_config.overlap_grad_reduce:
+            ddp_config.bucket_size = None
+        # Setup stream for ddp initialization. The side-stream may be necessary for cuda graph
+        # capture support with DDP, but we sync it with the current stream to avoid races.
+        ddp_stream = torch.cuda.Stream()
+        # Wait for the default stream to complete before starting ddp_stream
+        ddp_stream.wait_stream(torch.cuda.current_stream())
+        # Make ddp_stream start after whatever the default stream already queued
+        with torch.cuda.stream(ddp_stream):
             model = [
                 DP(
                     config=config,
@@ -1085,15 +1080,14 @@ def build_model():
                 )
                 for (model_chunk_idx, model_chunk) in enumerate(model)
             ]
+        # End of ddp_stream
+        # Critical: ensure side-stream work completes before touching params on default stream
+        torch.cuda.current_stream().wait_stream(ddp_stream)
 
-            # Broadcast params from data parallel src rank to other data parallel ranks.
-            if args.data_parallel_random_init:
-                for model_module in model:
-                    model_module.broadcast_params()
-
-    # End of setup_stream
-    # Critical: ensure side-stream work completes before touching params on default stream
-    torch.cuda.current_stream().wait_stream(setup_stream)
+        # Broadcast params from data parallel src rank to other data parallel ranks.
+        if args.data_parallel_random_init:
+            for model_module in model:
+                model_module.broadcast_params()
 
     return model
 
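For reference, the side-stream handshake that this patch narrows from model building down to DDP wrapping is the usual three-step `torch.cuda.Stream` pattern: fence the side stream behind the default stream, queue the work, then fence the default stream behind the side stream before anything else touches the results. A minimal sketch of that pattern follows; `wrap_on_side_stream` and `wrap_fn` are hypothetical stand-ins for the `DP(...)` wrapping done in the diff, not Megatron-LM APIs.

```python
import torch


def wrap_on_side_stream(modules, wrap_fn):
    """Apply ``wrap_fn`` to each module on a CUDA side stream.

    ``wrap_fn`` is a hypothetical stand-in for the DDP/FSDP wrapper used in the
    patch; the point here is the stream synchronization, not the wrapper.
    """
    side_stream = torch.cuda.Stream()
    # Work later queued on the side stream waits for everything already queued
    # on the current (default) stream.
    side_stream.wait_stream(torch.cuda.current_stream())
    # Queue the wrapping work (parameter copies, bucket allocation, ...) on the side stream.
    with torch.cuda.stream(side_stream):
        wrapped = [wrap_fn(m) for m in modules]
    # Critical: the default stream must wait for the side stream before any later
    # kernel touches the wrapped parameters; otherwise there is a race.
    torch.cuda.current_stream().wait_stream(side_stream)
    return wrapped
```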
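The bucket-size defaulting that the patch keeps (only re-indented out of the old setup stream) is easy to sanity-check numerically. Below is a small, hypothetical helper mirroring that logic, with the data-parallel world size passed in directly instead of read from `mpu`; it is a sketch, not a Megatron-LM function.

```python
def default_ddp_bucket_size(bucket_size, dp_world_size, overlap_grad_reduce):
    """Sketch of the bucket-size defaulting shown in the diff (not a real API)."""
    if bucket_size is None:
        # At least 40M elements, growing with data-parallel size so that NCCL
        # ring-reduce chunks stay bandwidth-bound rather than latency-bound.
        bucket_size = max(40_000_000, 1_000_000 * dp_world_size)
    if not overlap_grad_reduce:
        # Without overlapped grad reduce, an unbounded bucket (None) is used.
        bucket_size = None
    return bucket_size


# Examples: dp_world_size=16 -> 40,000,000 elements; dp_world_size=64 -> 64,000,000.
```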