Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
424 changes: 116 additions & 308 deletions megatron/core/datasets/data_schedule.py

Large diffs are not rendered by default.

519 changes: 510 additions & 9 deletions megatron/core/datasets/data_schedule_utils.py

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions megatron/core/extensions/transformer_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1371,6 +1371,8 @@ def forward(
packed_seq_params.cp_group is not None
), "cp_group is not set in packed_seq_params for dynamic CP"
self.cp_group = packed_seq_params.cp_group
if TEDotProductAttention.cp_stream is None:
TEDotProductAttention.cp_stream = torch.cuda.Stream()
super().set_context_parallel_group(
self.cp_group,
torch.distributed.get_process_group_ranks(self.cp_group),
Expand Down
38 changes: 37 additions & 1 deletion megatron/core/model_parallel_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,18 @@ class ModelParallelConfig:
Please set max_seqlen_per_dp_cp_rank when using dynamic_context_parallel.
"""

min_dynamic_context_parallel_size: int = 1
"""Minimum CP group size for dynamic context parallel. Defaults to 1, which
allows single-rank groups (i.e. no CP for that group).
The maximum is always context_parallel_size."""

hybrid_context_parallel: bool = False
"""Deprecated. Use ``dynamic_context_parallel`` instead."""

sequence_packing_scheduler: Optional[Literal['dp_balanced']] = None
sequence_packing_scheduler: Optional[Literal['dp_balanced', 'default_dynamic_cp']] = None
"""
Scheduler for sequence packing and dynamic context parallel.
dp_balanced: DP-balanced scheduler for sequence packing.
default_dynamic_cp: Dynamic-CP scheduler for packed sequence balancing. Required when
dynamic_context_parallel is enabled; selected automatically if no scheduler is set.
"""

expert_model_parallel_size: int = 1
Expand Down Expand Up @@ -428,6 +433,37 @@ def __post_init__(self):
)
self.dynamic_context_parallel = True

if self.dynamic_context_parallel:
if self.sequence_packing_scheduler is None:
self.sequence_packing_scheduler = 'default_dynamic_cp'
if self.sequence_packing_scheduler != 'default_dynamic_cp':
raise ValueError(
'Dynamic context parallelism requires '
'sequence_packing_scheduler=default_dynamic_cp'
)

if self.min_dynamic_context_parallel_size < 1:
raise ValueError(
f"min_dynamic_context_parallel_size must be >= 1, "
f"got {self.min_dynamic_context_parallel_size}"
)

if self.min_dynamic_context_parallel_size > self.context_parallel_size:
raise ValueError(
f"min_dynamic_context_parallel_size ({self.min_dynamic_context_parallel_size}) "
f"must be <= context_parallel_size ({self.context_parallel_size}), "
f"since context_parallel_size is the maximum dynamic CP group size."
)

if self.min_dynamic_context_parallel_size > 1:
warnings.warn(
f"min_dynamic_context_parallel_size is set to {self.min_dynamic_context_parallel_size}. "
f"Dynamic CP groups will range from {self.min_dynamic_context_parallel_size} "
f"to {self.context_parallel_size} (context_parallel_size). "
f"This may cause padding overhead for short sequences.",
UserWarning,
)

if self.sequence_parallel:
if self.tensor_model_parallel_size <= 1:
raise ValueError("Cannot use sequence parallelism without tensor parallelism")
Expand Down
35 changes: 24 additions & 11 deletions megatron/core/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,16 +421,20 @@ def create_hierarchical_groups(
return hierarchical_groups, hierarchical_groups_gloo


def create_dynamic_dp_cp_groups(rank, ranks, pg_options):
def create_dynamic_dp_cp_groups(rank, ranks, pg_options, min_cp_size=1, max_cp_size=None):
"""
Creates groups required for dynamic DPxCP.
Creates a new group for every power of 2 up to the number of DPxCP ranks.
Creates a new group for every power of 2 from min_cp_size up to max_cp_size.
max_cp_size defaults to len(ranks) (the full DPxCP group size).
Returns a dictionary indexed by group size.
"""
if max_cp_size is None:
max_cp_size = len(ranks)
dynamic_dp_cp_groups = {}
# Generate a group for every power of 2 that lies within [min_cp_size, max_cp_size]
# and is smaller than len(ranks); the full-size group is not created here.
# We limit the allowed group sizes in order to avoid excessive overhead.
group_sizes = [2**i for i in range(int(log2(len(ranks))))]
group_sizes = [
2**i for i in range(int(log2(len(ranks))))
if 2**i >= min_cp_size and 2**i <= max_cp_size
]
for group_size in group_sizes:
for i in range(0, len(ranks), group_size):
group = create_group(
Expand Down Expand Up @@ -556,6 +560,7 @@ def initialize_model_parallel(
context_parallel_size: int = 1,
hierarchical_context_parallel_sizes: Optional[List[int]] = None,
dynamic_context_parallel: bool = False,
min_dynamic_context_parallel_size: int = 1,
expert_model_parallel_size: int = 1,
num_distributed_optimizer_instances: int = 1,
expert_tensor_parallel_size: Optional[int] = None,
Expand Down Expand Up @@ -946,16 +951,21 @@ def initialize_model_parallel(
), "Dynamic context parallel requires an even number of ranks"
_DYNAMIC_DP_CP_GROUPS.update(
create_dynamic_dp_cp_groups(
rank, ranks_with_cp, get_nccl_options("dp_cp", nccl_comm_cfgs)
rank,
ranks_with_cp,
get_nccl_options("dp_cp", nccl_comm_cfgs),
min_cp_size=min_dynamic_context_parallel_size,
max_cp_size=context_parallel_size,
)
)

# PyTorch initializes communicator groups lazily.
# Therefore, we issue an NCCL call (a barrier) on each group to ensure that its communicator is created.
data_parallel_size_with_cp = data_parallel_size * context_parallel_size
group_sizes = [2**i for i in range(0, int(log2(data_parallel_size_with_cp)))]
if group_sizes[-1] * 2 == data_parallel_size_with_cp:
group_sizes.append(data_parallel_size_with_cp)
group_sizes = [
2**i for i in range(int(log2(data_parallel_size_with_cp)))
if 2**i >= min_dynamic_context_parallel_size and 2**i <= context_parallel_size
]
if context_parallel_size == data_parallel_size_with_cp:
group_sizes.append(context_parallel_size)
for group_size in group_sizes:
group = get_dynamic_data_context_parallel_groups(group_size=group_size)
torch.distributed.barrier(group=group, device_ids=[torch.cuda.current_device()])
Expand Down Expand Up @@ -2101,6 +2111,9 @@ def destroy_model_parallel():
global _CONTEXT_PARALLEL_GLOBAL_RANKS
_CONTEXT_PARALLEL_GLOBAL_RANKS = None

global _DYNAMIC_DP_CP_GROUPS
_DYNAMIC_DP_CP_GROUPS = {}

global _EMBEDDING_GROUP
_EMBEDDING_GROUP = None

Expand Down
Loading