Commit 4ba9f46

Reduce the scope of the side stream around DDP initialization (#2852)
Signed-off-by: John St. John <[email protected]>
1 parent 7e5e16b commit 4ba9f46
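
For context: the change below keeps the standard PyTorch side-stream handshake (fork a stream, have it wait on the current stream, run work under it, then have the current stream wait on it before anything touches the results), but applies it only around the DDP wrapping instead of around the whole model build. The sketch below is illustrative only and is not part of the diff; wrap_on_side_stream and wrap_fn are hypothetical names, and it assumes a CUDA device is available.

import torch

def wrap_on_side_stream(wrap_fn, *args, **kwargs):
    # Hypothetical helper mirroring the narrowed pattern in this commit:
    # only the wrapping step runs on a side stream.
    side_stream = torch.cuda.Stream()
    # Make the side stream start after work already queued on the current stream.
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        wrapped = wrap_fn(*args, **kwargs)
    # Ensure side-stream work completes before the default stream touches the result.
    torch.cuda.current_stream().wait_stream(side_stream)
    return wrapped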

File tree

1 file changed (+111, -117 lines)

megatron/training/training.py

Lines changed: 111 additions & 117 deletions
@@ -955,124 +955,119 @@ def build_model():
         model.model_type = model_type
         return model
 
-    # Setup stream for model building/ddp initialization. The side-stream may be necessary for
-    # cuda graph capture support with DDP, but we sync it with the current stream to avoid race
-    # conditions.
-    setup_stream = torch.cuda.Stream()
-    # Wait for the default stream to complete before starting setup_stream
-    setup_stream.wait_stream(torch.cuda.current_stream())
-    # Make setup_stream start after whatever the default stream already queued
-    with torch.cuda.stream(setup_stream):
-        if args.init_model_with_meta_device:
-            with torch.device('meta'):
-                model = build_model()
-        else:
+
+    if args.init_model_with_meta_device:
+        with torch.device('meta'):
             model = build_model()
+    else:
+        model = build_model()
 
-        if not isinstance(model, list):
-            model = [model]
+    if not isinstance(model, list):
+        model = [model]
 
-        # Set tensor model parallel attributes if not set.
-        # Only parameters that are already tensor model parallel have these
-        # attributes set for them. We should make sure the default attributes
-        # are set for all params so the optimizer can use them.
-        for model_module in model:
-            for param in model_module.parameters():
-                tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param)
+    # Set tensor model parallel attributes if not set.
+    # Only parameters that are already tensor model parallel have these
+    # attributes set for them. We should make sure the default attributes
+    # are set for all params so the optimizer can use them.
+    for model_module in model:
+        for param in model_module.parameters():
+            tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param)
 
-        # Print number of parameters.
-        num_parameters = sum(
-            [sum([p.nelement() for p in model_module.parameters()]) for model_module in model]
+    # Print number of parameters.
+    num_parameters = sum(
+        [sum([p.nelement() for p in model_module.parameters()]) for model_module in model]
+    )
+    if get_pg_rank(pg_collection.dp) == 0 and get_pg_rank(pg_collection.cp) == 0:
+        print(
+            ' > number of parameters on (tensor, pipeline) '
+            'model parallel rank ({}, {}): {}'.format(
+                get_pg_rank(pg_collection.tp),
+                get_pg_rank(pg_collection.pp),
+                num_parameters,
+            ),
+            flush=True,
         )
-        if get_pg_rank(pg_collection.dp) == 0 and get_pg_rank(pg_collection.cp) == 0:
-            print(
-                ' > number of parameters on (tensor, pipeline) '
-                'model parallel rank ({}, {}): {}'.format(
-                    get_pg_rank(pg_collection.tp),
-                    get_pg_rank(pg_collection.pp),
-                    num_parameters,
-                ),
-                flush=True,
-            )
-
-        # GPU allocation.
-        # For FSDP2, we don't allocate GPU memory here. We allocate GPU memory
-        # in the fully_shard function of FSDP2 instead.
-        if (
-            not (args.use_torch_fsdp2 and args.use_cpu_initialization)
-            and not args.init_model_with_meta_device
-        ):
-            for model_module in model:
-                model_module.cuda(torch.cuda.current_device())
-
-        # Fp16 conversion.
-        if args.fp16 or args.bf16:
-            config = get_model_config(model[0])
-            model = [Float16Module(config, model_module) for model_module in model]
-
-        # Materialize tensors on meta device (GPU allocation) if not using FSDP2 and not using Megatron FSDP.
-        if args.init_model_with_meta_device and not args.use_torch_fsdp2 and not args.use_megatron_fsdp:
-            #for model_module in model:
-            model = [to_empty_if_meta_device(model_module, device=torch.device("cuda")) for model_module in model]
-
 
+    # GPU allocation.
+    # For FSDP2, we don't allocate GPU memory here. We allocate GPU memory
+    # in the fully_shard function of FSDP2 instead.
+    if (
+        not (args.use_torch_fsdp2 and args.use_cpu_initialization)
+        and not args.init_model_with_meta_device
+    ):
+        for model_module in model:
+            model_module.cuda(torch.cuda.current_device())
+
+    # Fp16 conversion.
+    if args.fp16 or args.bf16:
+        config = get_model_config(model[0])
+        model = [Float16Module(config, model_module) for model_module in model]
+
+    # Materialize tensors on meta device (GPU allocation) if not using FSDP2 and not using Megatron FSDP.
+    if args.init_model_with_meta_device and not args.use_torch_fsdp2 and not args.use_megatron_fsdp:
+        model = [to_empty_if_meta_device(model_module, device=torch.device("cuda")) for model_module in model]
+
+    # Before TE2.x: The model_module.bfloat16()/model_module.half() above will call the inplace
+    # copy of TE's Float8Tensor, which will write an unwanted value (amax calculated
+    # from the current fp8 param) to its amax_history. The below function will correct
+    # the amax_history back.
+    # After TE2.x: Below function is an empty function and does nothing.
+    correct_amax_history_if_needed(model)
+
+    if wrap_with_ddp:
+        if args.use_torch_fsdp2:
+            assert HAVE_FSDP2, "Torch FSDP2 requires torch>=2.4.0"
+            DP = torch_FSDP
+        elif args.use_megatron_fsdp:
+            DP = megatron_FSDP
+        else:
+            DP = DDP
 
+        config = get_model_config(model[0])
 
-        # Before TE2.x: The model_module.bfloat16()/model_module.half() above will call the inplace
-        # copy of TE's Float8Tensor, which will write an unwanted value (amax calculated
-        # from the current fp8 param) to its amax_history. The below function will correct
-        # the amax_history back.
-        # After TE2.x: Below function is an empty function and does nothing.
-        correct_amax_history_if_needed(model)
-
-        if wrap_with_ddp:
-            if args.use_torch_fsdp2:
-                assert HAVE_FSDP2, "Torch FSDP2 requires torch>=2.4.0"
-                DP = torch_FSDP
-            elif args.use_megatron_fsdp:
-                DP = megatron_FSDP
-            else:
-                DP = DDP
-
-            config = get_model_config(model[0])
-
-            if getattr(args, "use_torch_fsdp2", False):
-                reshard_after_forward = getattr(args, "torch_fsdp2_reshard_after_forward", True)
-                ddp_config = TorchFullyShardedDataParallelConfig(reshard_after_forward=reshard_after_forward)
+        if getattr(args, "use_torch_fsdp2", False):
+            reshard_after_forward = getattr(args, "torch_fsdp2_reshard_after_forward", True)
+            ddp_config = TorchFullyShardedDataParallelConfig(reshard_after_forward=reshard_after_forward)
+        else:
+            kwargs = {}
+            for f in dataclasses.fields(DistributedDataParallelConfig):
+                if hasattr(args, f.name):
+                    kwargs[f.name] = getattr(args, f.name)
+            kwargs['grad_reduce_in_fp32'] = args.accumulate_allreduce_grads_in_fp32
+            kwargs['check_for_nan_in_grad'] = args.check_for_nan_in_loss_and_grad
+            kwargs['check_for_large_grads'] = args.check_for_large_grads
+            if args.ddp_num_buckets is not None:
+                assert args.ddp_bucket_size is None, \
+                    "Cannot specify both --ddp-num-buckets and --ddp-bucket-size"
+                assert args.ddp_num_buckets > 0, \
+                    "--ddp-num-buckets must be greater than 0"
+                kwargs['bucket_size'] = num_parameters // args.ddp_num_buckets
             else:
-                kwargs = {}
-                for f in dataclasses.fields(DistributedDataParallelConfig):
-                    if hasattr(args, f.name):
-                        kwargs[f.name] = getattr(args, f.name)
-                kwargs['grad_reduce_in_fp32'] = args.accumulate_allreduce_grads_in_fp32
-                kwargs['check_for_nan_in_grad'] = args.check_for_nan_in_loss_and_grad
-                kwargs['check_for_large_grads'] = args.check_for_large_grads
-                if args.ddp_num_buckets is not None:
-                    assert args.ddp_bucket_size is None, \
-                        "Cannot specify both --ddp-num-buckets and --ddp-bucket-size"
-                    assert args.ddp_num_buckets > 0, \
-                        "--ddp-num-buckets must be greater than 0"
-                    kwargs['bucket_size'] = num_parameters // args.ddp_num_buckets
-                else:
-                    kwargs['bucket_size'] = args.ddp_bucket_size
-                kwargs['pad_buckets_for_high_nccl_busbw'] = args.ddp_pad_buckets_for_high_nccl_busbw
-                kwargs['reduce_scatter_with_fp32_accumulation'] = args.ddp_reduce_scatter_with_fp32_accumulation
-                kwargs['average_in_collective'] = args.ddp_average_in_collective
-                ddp_config = DistributedDataParallelConfig(**kwargs)
-
-            # In the Megatron FSDP and DDP use path, we need to initialize the bucket size.
-            # If bucket_size is not provided as an input, use sane default.
-            # If using very large dp_sizes, make buckets larger to ensure that chunks used in NCCL
-            # ring-reduce implementations are large enough to remain bandwidth-bound rather than
-            # latency-bound.
-            if ddp_config.bucket_size is None:
-                ddp_config.bucket_size = max(
-                    40000000, 1000000 * mpu.get_data_parallel_world_size(with_context_parallel=True)
-                )
-            # Set bucket_size to infinity if overlap_grad_reduce is False.
-            if not ddp_config.overlap_grad_reduce:
-                ddp_config.bucket_size = None
-
+                kwargs['bucket_size'] = args.ddp_bucket_size
+            kwargs['pad_buckets_for_high_nccl_busbw'] = args.ddp_pad_buckets_for_high_nccl_busbw
+            kwargs['reduce_scatter_with_fp32_accumulation'] = args.ddp_reduce_scatter_with_fp32_accumulation
+            kwargs['average_in_collective'] = args.ddp_average_in_collective
+            ddp_config = DistributedDataParallelConfig(**kwargs)
+
+        # In the Megatron FSDP and DDP use path, we need to initialize the bucket size.
+        # If bucket_size is not provided as an input, use sane default.
+        # If using very large dp_sizes, make buckets larger to ensure that chunks used in NCCL
+        # ring-reduce implementations are large enough to remain bandwidth-bound rather than
+        # latency-bound.
+        if ddp_config.bucket_size is None:
+            ddp_config.bucket_size = max(
+                40000000, 1000000 * mpu.get_data_parallel_world_size(with_context_parallel=True)
+            )
+        # Set bucket_size to infinity if overlap_grad_reduce is False.
+        if not ddp_config.overlap_grad_reduce:
+            ddp_config.bucket_size = None
+        # Setup stream for ddp initialization. The side-stream may be necessary for cuda graph
+        # capture support with DDP, but we sync it with the current stream to avoid races.
+        ddp_stream = torch.cuda.Stream()
+        # Wait for the default stream to complete before starting ddp_stream
+        ddp_stream.wait_stream(torch.cuda.current_stream())
+        # Make ddp_stream start after whatever the default stream already queued
+        with torch.cuda.stream(ddp_stream):
             model = [
                 DP(
                     config=config,
@@ -1085,15 +1080,14 @@ def build_model():
                 )
                 for (model_chunk_idx, model_chunk) in enumerate(model)
             ]
+        # End of setup_stream
+        # Critical: ensure side-stream work completes before touching params on default stream
+        torch.cuda.current_stream().wait_stream(ddp_stream)
 
-            # Broadcast params from data parallel src rank to other data parallel ranks.
-            if args.data_parallel_random_init:
-                for model_module in model:
-                    model_module.broadcast_params()
-
-    # End of setup_stream
-    # Critical: ensure side-stream work completes before touching params on default stream
-    torch.cuda.current_stream().wait_stream(setup_stream)
+        # Broadcast params from data parallel src rank to other data parallel ranks.
+        if args.data_parallel_random_init:
+            for model_module in model:
+                model_module.broadcast_params()
 
     return model
 

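A note on the DistributedDataParallelConfig construction in the first hunk: the loop over dataclasses.fields(...) copies every parsed argument whose name matches a config field, and the explicit kwargs assignments afterwards override the few fields whose CLI names differ. A minimal, self-contained sketch of that pattern follows; ToyDDPConfig and the argument names are hypothetical stand-ins, not Megatron-LM code.

import dataclasses
from argparse import Namespace
from typing import Optional

@dataclasses.dataclass
class ToyDDPConfig:
    # Hypothetical stand-in for a DDP config dataclass.
    overlap_grad_reduce: bool = False
    bucket_size: Optional[int] = None
    average_in_collective: bool = False

# Pretend these came from argparse; attributes with no matching field are ignored.
args = Namespace(overlap_grad_reduce=True, bucket_size=None, unrelated_flag=123)

kwargs = {}
for f in dataclasses.fields(ToyDDPConfig):
    if hasattr(args, f.name):  # copy only args the config actually declares
        kwargs[f.name] = getattr(args, f.name)
config = ToyDDPConfig(**kwargs)
print(config)  # ToyDDPConfig(overlap_grad_reduce=True, bucket_size=None, average_in_collective=False)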