diff --git a/megatron/core/datasets/data_schedule.py b/megatron/core/datasets/data_schedule.py index 45b78e625c6..d43598fb618 100644 --- a/megatron/core/datasets/data_schedule.py +++ b/megatron/core/datasets/data_schedule.py @@ -1,317 +1,29 @@ # Copyright (c) 2025 NVIDIA CORPORATION. All rights reserved. -from typing import Any, Dict, List, Optional, Type +from typing import Dict, Optional, Type import torch from megatron.core import parallel_state from megatron.core.datasets.data_schedule_utils import ( + align_sample_id_groups, broadcast_scalars, broadcast_tensor, broadcast_to_pp_group, build_packed_microbatches, create_data_iterator, + dcp_get_total_workload, + dcp_gpus_needed, + dcp_make_buckets_equal, get_batch_and_global_seqlens, get_cp_slice_for_thd, + next_hdp_group, reroute_samples_to_dcp_ranks, ) from megatron.core.packed_seq_params import PackedSeqParams -from megatron.core.pipeline_parallel.dynamic_cp_schedule import BalancedCPScheduler from megatron.core.process_groups_config import ProcessGroupCollection -class DynamicCPDataLoaderWrapper: - """ - A wrapper class that wraps around an existing data_iterator. - For every __next__ call, - 1. Each DP rank pulls a batch of packed samples. - 2. Extracts the sequence lengths of each sub-sample and all-gathers across the DP group. - 3. Schedules the sub-samples to the DPxCP ranks using the BalancedCPScheduler. - 4. Based on the schedule, reroutes the sub-samples to the correct rank using all-to-all. - 5. Returns the assigned sub-samples to this rank. - - Args: - data_iterator: The original data_iterator to wrap around - config: The config object containing the max_seqlen_per_dp_cp_rank - dp_cp_group: Data parallel context parallel group. 
- """ - - def __init__( - self, data_iterator, config, pg_collection: Optional[ProcessGroupCollection] = None - ): - self.data_iterator = data_iterator - self.config = config - if pg_collection is None: - self.dp_cp_group = parallel_state.get_data_parallel_group(with_context_parallel=True) - self.dp_group = parallel_state.get_data_parallel_group() - self.tp_group = parallel_state.get_tensor_model_parallel_group() - else: - self.dp_cp_group = pg_collection.dp_cp - self.dp_group = pg_collection.dp - self.tp_group = pg_collection.tp - assert ( - self.dp_cp_group is not None and self.dp_group is not None and self.tp_group is not None - ), "dp_cp_group, dp_group, tp_group must not be None when using dynamic context parallel" - - self.cp_balancing_scheduler = BalancedCPScheduler( - max_seq_len_per_rank=self.config.max_seqlen_per_dp_cp_rank, dp_cp_group=self.dp_cp_group - ) - - self.total_hdp_gpus = self.dp_cp_group.size() - - def __iter__(self): - """Return self as an iterator.""" - return self - - def get_global_seqlens(self, subsample_seqlens: torch.Tensor) -> List[int]: - """ - Gathers the sequence lengths of all subsamples from all DP ranks. - Each DP rank loads the same number of microbatches but each microbatch - may have a different number of subsamples. - - We find the number of subsamples each rank holds and then gather the - sequence lengths of all subsamples from all ranks. 
- """ - # Collect the number of subsamples from all ranks - local_len = torch.tensor([subsample_seqlens.shape[0]], dtype=torch.int32).cuda() - dp_subsample_count = [torch.zeros_like(local_len) for _ in range(self.dp_group.size())] - torch.distributed.all_gather(dp_subsample_count, local_len, group=self.dp_group) - - # Find the max number of subsamples across all ranks and pad subsample_seqlens to max length - dp_subsample_counts = torch.stack(dp_subsample_count, dim=0).cpu().view(-1) - max_sub_samples = int(dp_subsample_counts.max().item()) - - if local_len.item() < max_sub_samples: - subsample_seqlens_padded = torch.cat( - [ - subsample_seqlens, - torch.zeros(max_sub_samples - local_len.item(), dtype=torch.int32).cuda(), - ], - dim=0, - ) - else: - subsample_seqlens_padded = subsample_seqlens - - # Gather the subsample_seqlens from all ranks - seqlens_gathered = [ - torch.empty_like(subsample_seqlens_padded) for _ in range(self.dp_group.size()) - ] - torch.distributed.all_gather( - seqlens_gathered, subsample_seqlens_padded, group=self.dp_group - ) - - # Trim each seqlens_gathered to the length of the correct sample - for dp_rank, seqlen in enumerate(seqlens_gathered): - seqlens_gathered[dp_rank] = seqlen[: dp_subsample_counts[dp_rank]] - - seqlens_gathered = torch.cat(seqlens_gathered, dim=0) - seqlens_gathered = seqlens_gathered.cpu().tolist() - - # Calculate the offsets to assign unique global ID to each subsample. - csum = torch.cumsum(dp_subsample_counts, dim=0, dtype=torch.int32) - offsets = torch.cat([torch.zeros(1, dtype=torch.int32), csum[:-1]], dim=0) - - return seqlens_gathered, offsets - - def get_global_id_seqlens(self, num_local_subsamples, offsets, seqlens_gathered): - """ - Calculates the global ID for each subsample. - - We assign a unique global ID to each subsample. - - Returns: - global_id_seqlens: list of (global_id, seqlen) tuples for scheduling. - global_ids_this_rank: list of global IDs locally present on this rank. 
- """ - dp_rank = self.dp_group.rank() - global_ids = torch.arange(len(seqlens_gathered), dtype=torch.int32).cuda() - # Create a list of (global_id, seqlen) tuples for scheduling - global_id_seqlens = [(i, seqlens_gathered[i]) for i in range(len(global_ids))] - # Get the global IDs locally present on this rank - global_ids_this_rank = global_ids[ - offsets[dp_rank] : offsets[dp_rank] + num_local_subsamples - ] - - return global_id_seqlens, global_ids_this_rank - - def _gid_to_src_rank(self, gid: int, offsets: List[int]) -> int: - dp_src_rank = torch.bucketize(gid, offsets[1:] - 1) - # Since the torch.distributed.get_process_group_ranks - # provides the global rank, we need to consider TP - hdp_rank = ( - torch.distributed.get_process_group_ranks(self.dp_group)[dp_src_rank] - // self.tp_group.size() - ) - return hdp_rank - - def reroute_samples_to_hdp_ranks( - self, batch, global_ids_this_rank, global_id_seqlens, sample_id_groups, offsets - ): - """ - Reroutes the sub-samples to the correct rank after scheduling. - - For each key in the batch dict, we perform an all-to-all communication - to transfer the data to the correct ranks. - Since all CP ranks within a DP group have the same data, we only need - to transfer data between matching CP ranks. 
- """ - gid2local_id = {int(gid): i for i, gid in enumerate(global_ids_this_rank)} - hdp_rank = self.dp_cp_group.rank() - dp_ranks = torch.distributed.get_process_group_ranks(self.dp_group) - # Here we actually want to get the DP group's rank within the HDP group, - # we need to consider TP - dp_ranks = [r // self.tp_group.size() for r in dp_ranks] - - data_keys = batch[0].keys() - - # Create the send plan - combined_sample_id_groups: List[List[int]] = [[] for _ in range(self.total_hdp_gpus)] - - for d in range(self.total_hdp_gpus): - for sample_id_group in sample_id_groups: - combined_sample_id_groups[d].extend(sample_id_group[d]) - - for dest_rank in range(self.total_hdp_gpus): - combined_sample_id_groups[dest_rank].sort() - - # Filter out samples that are not present on this rank - send_ids_sorted = [ - gid - for d in dp_ranks - for gid in combined_sample_id_groups[d] - if gid in global_ids_this_rank - ] - # send_counts = [len(combined_sample_id_groups[d]) for d in range(self.total_hdp_gpus)] - - send_lens_split = [0] * self.total_hdp_gpus - for dest_rank in range(self.total_hdp_gpus): - if dest_rank in dp_ranks: - send_lens_split[dest_rank] = sum( - [ - global_id_seqlens[gid][1] - for gid in combined_sample_id_groups[dest_rank] - if gid in global_ids_this_rank - ] - ) - else: - # We only need to share local data with DP ranks that have different data. 
- send_lens_split[dest_rank] = 0 - - # Create the recv plan - recv_sample_id_groups = [[] for _ in range(self.total_hdp_gpus)] - for gid in combined_sample_id_groups[hdp_rank]: - src_rank = self._gid_to_src_rank(gid, offsets) - recv_sample_id_groups[src_rank].append(gid) - - recv_lens_split = [0] * self.total_hdp_gpus - for src_rank in range(self.total_hdp_gpus): - recv_lens_split[src_rank] = sum( - [global_id_seqlens[gid][1] for gid in recv_sample_id_groups[src_rank]] - ) - - recv_ids_sorted = [ - gid for d in range(self.total_hdp_gpus) for gid in recv_sample_id_groups[d] - ] - recv_counts = [len(recv_sample_id_groups[d]) for d in range(self.total_hdp_gpus)] - - recv_samples = [{k: None for k in data_keys} for _ in range(sum(recv_counts))] - - def _pack_sample_by_key(key: str) -> torch.Tensor: - flattened_tensors = [] - for gid in send_ids_sorted: - t = batch[gid2local_id[gid]][key].to(torch.cuda.current_device(), non_blocking=True) - flattened_tensors.append(t) - return ( - torch.cat(flattened_tensors, dim=0) - if flattened_tensors - else torch.empty(0, device=torch.cuda.current_device(), dtype=batch[0][key].dtype) - ) - - def _unpack_sample_by_key(key: str, recv_tensor: torch.Tensor): - cursor = 0 - for i, gid in enumerate(recv_ids_sorted): - sample_len = global_id_seqlens[gid][1] - recv_samples[i][key] = recv_tensor[cursor : cursor + sample_len] - cursor += sample_len - - for key in data_keys: - send_tensor = _pack_sample_by_key(key) - recv_tensor = torch.empty( - sum(recv_lens_split), device=torch.cuda.current_device(), dtype=send_tensor.dtype - ) - torch.distributed.all_to_all_single( - output=recv_tensor, - input=send_tensor, - output_split_sizes=recv_lens_split, - input_split_sizes=send_lens_split, - group=self.dp_cp_group, - ) - _unpack_sample_by_key(key, recv_tensor) - - recv_sample_with_id = { - recv_id: recv_samples[i] for i, recv_id in enumerate(recv_ids_sorted) - } - return recv_sample_with_id - - def unpack_batch(self, batch): - """ - Unpacks the 
packed samples into a list of sub-samples. - Since each sub-sample may be routed to different DPxCP ranks, - we unpack the sample here to avoid unnecessarily transferring - the entire packed sample. - """ - batch_unpacked = [] - for sample in batch: - for sub_sample in range(sample["cu_seqlens"].shape[0] - 1): - sub_sample_dict = {} - start_idx = sample["cu_seqlens"][sub_sample] - end_idx = sample["cu_seqlens"][sub_sample + 1] - if end_idx - start_idx == 0: - continue - for key in sample.keys(): - if key in ["cu_seqlens", "batch_idx", "max_seqlen"]: - continue - sub_sample_dict[key] = sample[key][start_idx:end_idx] - batch_unpacked.append(sub_sample_dict) - return batch_unpacked - - def __next__(self) -> Any: - """ - Get the next item from the dataset, pull scheduling metadata and return it. - """ - if self.data_iterator is None: - # TP0 reads from data_iterator, others receive via broadcast. - return None, None - else: - batch = next(self.data_iterator) - subsample_seqlens = [] - for sample in batch: - subsample_seqlens.extend( - [ - int(sample["cu_seqlens"][i + 1] - sample["cu_seqlens"][i]) - for i in range(0, sample["cu_seqlens"].shape[0] - 1) - ] - ) - subsample_seqlens = torch.tensor(subsample_seqlens, dtype=torch.int32).cuda() - subsample_seqlens = subsample_seqlens[subsample_seqlens != 0] - - seqlens_gathered, offsets = self.get_global_seqlens(subsample_seqlens) - - global_id_seqlens, global_ids_this_rank = self.get_global_id_seqlens( - subsample_seqlens.shape[0], offsets, seqlens_gathered - ) - - groups, sample_id_groups = self.cp_balancing_scheduler.get_groups_and_subsamples( - global_id_seqlens, self.config - ) - - batch = self.unpack_batch(batch) - samples_this_rank_with_id = self.reroute_samples_to_hdp_ranks( - batch, global_ids_this_rank, global_id_seqlens, sample_id_groups, offsets - ) - return samples_this_rank_with_id, sample_id_groups - - class BasePackingScheduler: """Base class for sequence packing schedulers.""" @@ -382,6 +94,7 @@ class 
class DefaultDynamicCPScheduler(DpBalancedScheduler):
    """Dynamic CP scheduler that balances workload across variable CP sizes.

    Extends ``DpBalancedScheduler`` by setting ``is_dynamic_cp`` and by forming
    micro-batch groups with ``next_hdp_group``, which may give different
    samples a different number of ranks.

    Args:
        *args: Forwarded unchanged to ``DpBalancedScheduler.__init__``.
        min_cp_size: Lower clamp passed to the ``dcp_*`` workload helpers.
        max_cp_size: Upper clamp passed to the ``dcp_*`` workload helpers;
            ``None`` means "use ``self.cp_size``".
        **kwargs: Forwarded unchanged to ``DpBalancedScheduler.__init__``.
    """

    def __init__(self, *args, min_cp_size=1, max_cp_size=None, **kwargs):
        super().__init__(*args, **kwargs)
        # Must be set after super().__init__(), which resets the flag to False.
        self.is_dynamic_cp = True
        self.max_seq_len_per_rank = self.max_seqlen_per_dp_cp_rank
        # NOTE(review): dp_size / cp_size are presumably set by the base-class
        # constructor, which is not visible in this chunk.
        self.total_hdp_gpus = self.dp_size * self.cp_size
        self.min_cp_size = min_cp_size
        self.max_cp_size = max_cp_size if max_cp_size is not None else self.cp_size

    def get_groups_and_subsamples(self, sample_id_seqlens):
        """Iteratively form micro-batch groups until every sub-sample is scheduled.

        The input is sorted by descending sequence length, then
        ``next_hdp_group`` is called repeatedly; each call yields one group and
        the samples it could not place, which feed the next call.

        Args:
            sample_id_seqlens: Iterable of ``(sample_id, seq_len)`` tuples.

        Returns:
            A list with one entry per micro-batch group. Each entry is the
            per-rank list of sample ids returned by ``next_hdp_group``. When
            ``microbatch_group_size_per_vp_stage`` is greater than 1 the list is
            passed through ``align_sample_id_groups`` first.
        """
        mslpr = self.max_seq_len_per_rank
        min_cp = self.min_cp_size
        max_cp = self.max_cp_size

        # Named closures instead of lambdas bound to names (PEP 8).
        def workload_fn(seq_len, cp_size=None):
            return dcp_get_total_workload(seq_len, mslpr, cp_size, min_cp, max_cp)

        def gpus_fn(seq_len):
            return dcp_gpus_needed(seq_len, mslpr, min_cp, max_cp)

        def buckets_fn(sample_seqlens, compute_est):
            return dcp_make_buckets_equal(sample_seqlens, compute_est, mslpr, min_cp, max_cp)

        sample_id_groups = []
        remaining = sorted(sample_id_seqlens, key=lambda x: x[1], reverse=True)
        while remaining:
            # Only the leftovers and the per-rank sample ids are consumed here;
            # the per-rank sequence lengths and execution times are discarded.
            _, remaining, _, sample_ids = next_hdp_group(
                remaining,
                workload_fn,
                self.total_hdp_gpus,
                gpus_needed_fn=gpus_fn,
                make_buckets_equal_fn=buckets_fn,
                max_seq_len_per_rank=mslpr,
                get_total_workload_fn=workload_fn,
            )
            sample_id_groups.append(sample_ids)

        vp_group_size = self.microbatch_group_size_per_vp_stage
        if vp_group_size is not None and vp_group_size > 1:
            sample_id_groups = align_sample_id_groups(sample_id_groups, vp_group_size)

        return sample_id_groups
config.sequence_packing_scheduler + scheduler_kwargs = {} + if scheduler_type == 'default_dynamic_cp': + scheduler_kwargs['min_cp_size'] = config.min_dynamic_context_parallel_size + scheduler_kwargs['max_cp_size'] = cp_size + scheduler = scheduler_map[scheduler_type]( config.max_seqlen_per_dp_cp_rank, cp_size, dp_size, - # When VPP is enabled, align num_micro_batches to this multiple. ( None if config.virtual_pipeline_model_parallel_size is None else config.microbatch_group_size_per_vp_stage ), + **scheduler_kwargs, ) ( @@ -696,6 +470,7 @@ def get_batch_on_this_rank_for_sequence_packing( vpp_size: Optional[int] = None, mtp_on_this_rank: bool = False, vp_stage: Optional[int] = None, + dynamic_cp: bool = False, pg_collection: Optional[ProcessGroupCollection] = None, ): """ @@ -730,6 +505,8 @@ def get_batch_on_this_rank_for_sequence_packing( # data_iterator should return a batch including the following keys. batch_keys = ['cu_seqlens', 'cu_seqlens_padded', 'max_seqlen'] + if dynamic_cp: + batch_keys.append('local_cp_size') if is_first_stage or mtp_on_this_rank: batch_keys.append('tokens') batch_keys.append('position_ids') @@ -747,6 +524,15 @@ def get_batch_on_this_rank_for_sequence_packing( assert data_iterator is None, "Non TP 0 rank should not have data_iterator" batch = {} + # For dynamic CP, determine the correct cp_group from batch on TP rank 0. + if dynamic_cp and is_tp_rank_0: + local_cp_size_val = batch['local_cp_size'] + if isinstance(local_cp_size_val, torch.Tensor): + local_cp_size_val = local_cp_size_val.item() + cp_group = parallel_state.get_dynamic_data_context_parallel_groups( + group_size=local_cp_size_val + ) + # Partition tokens, position_ids, labels, loss_mask for context parallel. # Only TP rank 0 on stages that have data (first/last PP stage or MTP stage) needs this. 
if is_tp_rank_0 and (is_first_or_last_stage or mtp_on_this_rank): @@ -819,6 +605,21 @@ def get_batch_on_this_rank_for_sequence_packing( batch['cu_seqlens_padded'] = torch.empty([cu_seqlen_size], dtype=torch.int32, device=dev) batch['max_seqlen'] = torch.empty(1, dtype=torch.int32, device=dev) + # Step4: Prepare "local_cp_size" if dynamic context parallel is enabled. + if dynamic_cp: + if is_tp_rank_0: + if type(batch['local_cp_size']) == int: + batch['local_cp_size'] = torch.tensor( + batch['local_cp_size'], dtype=torch.int32, device=dev + ) + else: + assert batch['local_cp_size'].dtype == torch.int32 + assert batch['local_cp_size'].numel() == 1 + else: + batch['local_cp_size'] = torch.empty(1, dtype=torch.int32, device=dev) + else: + batch['local_cp_size'] = None + # Broadcast batch inside TP group. broadcast_tensor(batch['tokens'], tp_src_rank, tp_group) broadcast_tensor(batch['position_ids'], tp_src_rank, tp_group) @@ -827,6 +628,7 @@ def get_batch_on_this_rank_for_sequence_packing( broadcast_tensor(batch['cu_seqlens'], tp_src_rank, tp_group) broadcast_tensor(batch['cu_seqlens_padded'], tp_src_rank, tp_group) broadcast_tensor(batch['max_seqlen'], tp_src_rank, tp_group) + broadcast_tensor(batch['local_cp_size'], tp_src_rank, tp_group) # Extract the data from batch after broadcasting. tokens = batch['tokens'] @@ -836,6 +638,12 @@ def get_batch_on_this_rank_for_sequence_packing( cu_seqlens = batch['cu_seqlens'] cu_seqlens_padded = batch['cu_seqlens_padded'] max_seqlen = batch['max_seqlen'].item() + local_cp_size = batch['local_cp_size'].item() if dynamic_cp else None + cp_group = ( + parallel_state.get_dynamic_data_context_parallel_groups(group_size=local_cp_size) + if dynamic_cp + else None + ) # Transformer Engine has a bug of cu_seqlens, we must treat cu_seqlens_padded as cu_seqlens to # get the correct result. 
@@ -848,8 +656,8 @@ def get_batch_on_this_rank_for_sequence_packing( cu_seqlens_kv_padded=cu_seqlens_padded, max_seqlen_q=max_seqlen, max_seqlen_kv=max_seqlen, - local_cp_size=None, - cp_group=None, + local_cp_size=local_cp_size, + cp_group=cp_group, ) # "attention_mask" is not valid for sequence packing, so set it to None. diff --git a/megatron/core/datasets/data_schedule_utils.py b/megatron/core/datasets/data_schedule_utils.py index f3c637e4c79..4e504e54184 100644 --- a/megatron/core/datasets/data_schedule_utils.py +++ b/megatron/core/datasets/data_schedule_utils.py @@ -1,6 +1,9 @@ # Copyright (c) 2025 NVIDIA CORPORATION. All rights reserved. -from typing import Dict, List +from collections import deque +from functools import lru_cache +from math import ceil, log2 +from typing import Callable, Dict, List, Optional, Tuple import numpy as np import torch @@ -137,7 +140,11 @@ def _get_global_seqlens_and_ids(subsample_seqlens: torch.Tensor, dp_group): def _pack_sequences( - samples: List, padded_lengths: torch.Tensor, original_lengths: torch.Tensor, dev: torch.device + samples: List, + padded_lengths: torch.Tensor, + original_lengths: torch.Tensor, + local_cp_size: Optional[torch.Tensor], + dev: torch.device, ) -> Dict[str, torch.Tensor]: """Pack multiple samples into a single packed sample.""" @@ -172,6 +179,9 @@ def _pack_tensors(tensors): cu_seqlens[1:] = torch.cumsum(original_lengths, dim=0).reshape(-1) new_sample["cu_seqlens"] = cu_seqlens + if local_cp_size is not None: + new_sample["local_cp_size"] = local_cp_size + return new_sample @@ -188,6 +198,7 @@ def broadcast_to_pp_group( seqlen_squared_sum_this_global_batch, pp_group, dev, + is_dynamic_cp: bool = False, ): """ Broadcast num_micro_batches, seqlen_sum_this_global_batch, @@ -213,6 +224,11 @@ def broadcast_to_pp_group( ] for sample in new_samples: tensor_list.append(sample["max_seqlen"].unsqueeze(0)) + + if is_dynamic_cp: + for sample in new_samples: + 
tensor_list.append(sample["local_cp_size"].unsqueeze(0)) + for sample in new_samples: tensor_list.append(sample["cu_seqlens"]) tensor_list.append(sample["cu_seqlens_padded"]) @@ -232,6 +248,11 @@ def broadcast_to_pp_group( seqlen_sum_this_global_batch = info_numpy[1] seqlen_squared_sum_this_global_batch = info_numpy[2] max_seqlens = info_to_broadcast[3 : 3 + num_micro_batches] + local_cp_sizes = ( + info_to_broadcast[3 + num_micro_batches : 3 + 2 * num_micro_batches] + if is_dynamic_cp + else None + ) cu_seqlens_list = [] cu_seqlens_padded_list = [] # cu_seqlens always starts with 0, and the other metadata values @@ -255,6 +276,8 @@ def broadcast_to_pp_group( new_sample["max_seqlen"] = max_seqlens[i].to(torch.int32) new_sample["cu_seqlens"] = cu_seqlens_list[i].to(torch.int32) new_sample["cu_seqlens_padded"] = cu_seqlens_padded_list[i].to(torch.int32) + if is_dynamic_cp: + new_sample["local_cp_size"] = local_cp_sizes[i].to(torch.int32) new_samples.append(new_sample) return ( @@ -297,7 +320,9 @@ def broadcast_scalars(values: List, group, dev, dtype=torch.float32) -> List: return values -def create_data_iterator(new_samples, tp_group, config, vpp_has_data=None): +def create_data_iterator( + new_samples, tp_group, config, vpp_has_data=None, is_dynamic_cp: bool = False +): """Handle virtual pipeline parallelism. For VPP, each PP rank needs a list of data iterators (one per VPP stage). 
@@ -318,9 +343,11 @@ def create_data_iterator(new_samples, tp_group, config, vpp_has_data=None): ): vpp_size = config.virtual_pipeline_model_parallel_size if tp_group.rank() == 0: + metadata_keys = ["max_seqlen", "cu_seqlens", "cu_seqlens_padded"] + if is_dynamic_cp: + metadata_keys.append("local_cp_size") metadata = [ - {k: sample[k] for k in ["max_seqlen", "cu_seqlens", "cu_seqlens_padded"]} - for sample in new_samples + {k: sample[k] for k in metadata_keys if k in sample} for sample in new_samples ] new_data_iterator = [] for i in range(vpp_size): @@ -456,14 +483,51 @@ def _unpack_sample_by_key(key: str, recv_tensor: torch.Tensor): def build_packed_microbatches( - grouped_samples: List[List[Dict[str, torch.Tensor]]], dev: torch.device + samples_this_rank_with_id: Dict[int, Dict[str, torch.Tensor]], + sample_id_groups: List[List[List[int]]], + dcp_rank: int, + dev: torch.device, + is_dynamic_cp: bool = False, ) -> List[Dict[str, torch.Tensor]]: - """Build packed samples for each microbatch.""" - num_micro_batches = len(grouped_samples) + """Build packed samples for each microbatch. + + Args: + samples_this_rank_with_id: Mapping from global sample ID to sample dict, + as returned by reroute_samples_to_dcp_ranks. + sample_id_groups: Per-microbatch, per-rank lists of global sample IDs. + dcp_rank: This rank's index within the DP×CP group. + dev: Target device. + is_dynamic_cp: Whether dynamic context parallel is enabled. 
def next_hdp_group(
    sample_seqlens: List[Tuple[int, int]],
    compute_estimator: Callable[[int], float],
    total_gpus: int,
    gpus_needed_fn: Callable[[int], int],
    make_buckets_equal_fn: Callable,
    max_seq_len_per_rank: float,
    get_total_workload_fn: Callable,
    delta: float = 0.05,
    strategy: str = "dp",
    eps_bucket: float = 0.10,
) -> Tuple[List[List[int]], List[Tuple[int, int]], List[float], List[List[int]]]:
    """Form one balanced micro-batch group across DPxCP ranks.

    Samples are taken from the buckets produced by ``make_buckets_equal_fn``
    and placed either on an existing rank group of the required size (if the
    group still has room under ``max_seq_len_per_rank``) or on a new group
    built from the least-loaded free ranks. Ranks left empty at the end are
    filled by repeatedly doubling the smallest group.

    Args:
        sample_seqlens: ``(sample_id, seq_len)`` tuples to schedule.
        compute_estimator: ``callable(seq_len) -> float`` per-rank cost.
        total_gpus: Number of ranks to schedule onto.
        gpus_needed_fn: ``callable(seq_len) -> int`` group size for a sample.
        make_buckets_equal_fn: ``callable(sample_seqlens, compute_estimator)``
            returning a list of deque-like buckets of ``(sample_id, seq_len)``.
        max_seq_len_per_rank: Max tokens per rank for packing.
        get_total_workload_fn: ``callable(seq_len, cp_size) -> float``; used to
            recompute the cost of a group after it has been widened.
        delta: Relative spread of per-rank cost below which scheduling stops
            early, once a sample needing fewer ranks than the first scheduled
            sample has been placed.
        strategy: ``"dp"`` scans buckets in order; ``"pp"`` scans round-robin
            starting after the bucket used last.
        eps_bucket: Accepted for interface compatibility; unused in this body.

    Returns:
        ``(micro_batches, leftovers, exec_times, sample_ids_per_gpu)`` where
        ``micro_batches[r]`` lists the sequence lengths on rank ``r``,
        ``leftovers`` are the ``(sample_id, seq_len)`` tuples still in the
        buckets, ``exec_times[r]`` is the accumulated cost of rank ``r`` and
        ``sample_ids_per_gpu[r]`` lists the sample ids on rank ``r``.

    Raises:
        AssertionError: If ranks are empty but no group exists to widen, or if
            the filling step lost work.
    """
    if not sample_seqlens:
        return (
            [[] for _ in range(total_gpus)],
            [],
            [0.0 for _ in range(total_gpus)],
            [[] for _ in range(total_gpus)],
        )

    buckets = make_buckets_equal_fn(sample_seqlens, compute_estimator)

    micro_batches = [[] for _ in range(total_gpus)]
    exec_times = [0.0 for _ in range(total_gpus)]
    sample_ids_per_gpu = [[] for _ in range(total_gpus)]
    packing_sequence_len = {}

    gpu_group_id = [None] * total_gpus
    group_members = {}
    group_size = {}
    next_gid = 0

    pp_cursor = 0
    prev_needed = None
    check_balance = False

    while buckets:
        sample_seq_tuple = bucket_idx = None

        scan_order = (
            range(len(buckets))
            if strategy == "dp"
            else [(pp_cursor + i) % len(buckets) for i in range(len(buckets))]
        )

        # Neither the free ranks nor the group sizes change during the scan,
        # so compute them once rather than per candidate.
        sizes_in_use = set(group_size.values())
        free_count = sum(1 for gid in gpu_group_id if gid is None)

        for idx in scan_order:
            if not buckets[idx]:
                continue
            cand_tuple = buckets[idx][0]
            cand_needed = gpus_needed_fn(cand_tuple[1])
            if cand_needed in sizes_in_use or free_count >= cand_needed:
                sample_seq_tuple, bucket_idx = cand_tuple, idx
                break

        if sample_seq_tuple is None:
            break

        if strategy == "pp":
            pp_cursor = (bucket_idx + 1) % len(buckets)

        sample_id, seq_len = sample_seq_tuple
        needed = gpus_needed_fn(seq_len)
        if prev_needed is None:
            prev_needed = needed

        # Existing groups of the right size that still have packing room.
        candidate_gids = [
            gid
            for gid, sz in group_size.items()
            if sz == needed and packing_sequence_len[gid] + seq_len / needed <= max_seq_len_per_rank
        ]
        if candidate_gids:
            best_gid, best_load = min(
                ((gid, max(exec_times[r] for r in group_members[gid])) for gid in candidate_gids),
                key=lambda t: t[1],
            )
        else:
            best_gid, best_load = None, float("inf")

        free_ranks = [r for r, gid in enumerate(gpu_group_id) if gid is None]
        if len(free_ranks) >= needed:
            free_sorted = sorted(free_ranks, key=lambda r: exec_times[r])
            new_members = free_sorted[:needed]
            new_load = exec_times[new_members[-1]]

            if new_load < best_load:
                best_gid = None
                chosen_members = new_members
            else:
                chosen_members = group_members[best_gid]
        else:
            if best_gid is None:
                break
            chosen_members = group_members[best_gid]

        if best_gid is None:
            best_gid = next_gid
            next_gid += 1
            group_members[best_gid] = chosen_members
            group_size[best_gid] = needed
            for r in chosen_members:
                gpu_group_id[r] = best_gid

        per_gpu_cost = compute_estimator(seq_len)

        packing_sequence_len[best_gid] = packing_sequence_len.get(best_gid, 0) + seq_len / needed
        for r in chosen_members:
            micro_batches[r].append(seq_len)
            exec_times[r] += per_gpu_cost
            sample_ids_per_gpu[r].append(sample_id)

        buckets[bucket_idx].popleft()

        # Drop exhausted buckets at the front and keep the cursor in range.
        while buckets and not buckets[0]:
            buckets.pop(0)
            pp_cursor %= max(1, len(buckets))

        if needed < prev_needed:
            check_balance = True

        if (
            check_balance
            and buckets
            and max(exec_times) - min(exec_times) <= delta * max(exec_times)
        ):
            break

    leftovers = [sample_seq_tuple for b in buckets for sample_seq_tuple in b]

    def trim_overload():
        while True:
            cur_max = max(exec_times)
            cur_min = min(exec_times)
            cur_slack = cur_max - cur_min
            if cur_slack <= delta * cur_max:
                break
            if cur_min == 0:
                break

            max_r = exec_times.index(cur_max)
            gid = gpu_group_id[max_r]
            members = group_members[gid]

            if not micro_batches[max_r] or len(micro_batches[max_r]) <= 1:
                break

            seq = micro_batches[max_r][-1]
            per_gpu_cost = compute_estimator(seq)

            proj_times = exec_times[:]
            for r in members:
                proj_times[r] -= per_gpu_cost

            proj_slack = max(proj_times) - min(proj_times)

            if proj_slack < cur_slack:
                sample_id_to_remove = sample_ids_per_gpu[max_r][-1]
                for r in members:
                    micro_batches[r].pop()
                    exec_times[r] -= per_gpu_cost
                    sample_ids_per_gpu[r].pop()
                leftovers.append((sample_id_to_remove, seq))
            else:
                break

    # TODO(tailaim): uncomment this to support different ranks have different num_microbatches
    # trim_overload()

    total_work_before = sum(len(mb) for mb in micro_batches)

    def fill_empty_gpus(micro_batches, exec_times, sample_ids_per_gpu, group_members, group_size):
        """Widen one smallest group to the next size so it covers empty ranks."""
        first_empty = next((idx for idx, work in enumerate(micro_batches) if not work), None)
        if first_empty is None:
            return (micro_batches, exec_times, sample_ids_per_gpu, group_members, group_size)

        existing_group_sizes = set(group_size.values())
        if not existing_group_sizes:
            # Raised explicitly so the check survives ``python -O`` and the
            # whole hint, including the remedy, reaches the user.
            raise AssertionError(
                "There should be at least one group existing, cannot redistribute, "
                "try to increase 'max-seqlen-per-dp-cp-rank'."
            )

        min_group_size = min(existing_group_sizes)
        next_power = min(min_group_size * 2, total_gpus)
        needed_count = next_power - min_group_size

        for gid, size in group_size.items():
            if size != min_group_size:
                continue

            members = group_members[gid]
            group_start_gpu = members[0]
            group_end_gpu = members[-1]
            # NOTE(review): this passes whenever any rank in the window is
            # empty, and the window starts at an empty rank — confirm whether
            # "all ranks in the window are empty" was intended.
            assert not all(
                work for work in micro_batches[first_empty : first_empty + needed_count]
            ), "Empty GPUs were detected but not enough to expand."
            work_to_push = micro_batches[group_end_gpu + 1 : first_empty]
            exec_times_to_push = exec_times[group_end_gpu + 1 : first_empty]
            sample_ids_to_push = sample_ids_per_gpu[group_end_gpu + 1 : first_empty]

            # One fresh list per rank; ``[[]] * n`` would alias a single list.
            new_micro_batches = [[] for _ in micro_batches]
            new_exec_times = [0.0] * len(exec_times)
            new_sample_ids_per_gpu = [[] for _ in sample_ids_per_gpu]

            for i in range(group_start_gpu):
                new_micro_batches[i] = micro_batches[i]
                new_exec_times[i] = exec_times[i]
                new_sample_ids_per_gpu[i] = sample_ids_per_gpu[i]

            widened_cost = get_total_workload_fn(micro_batches[group_end_gpu][0], next_power)
            for i in range(group_start_gpu, group_end_gpu + needed_count + 1):
                # Copy so each rank of the widened group owns its own list.
                new_micro_batches[i] = list(micro_batches[group_end_gpu])
                new_exec_times[i] = widened_cost
                new_sample_ids_per_gpu[i] = list(sample_ids_per_gpu[group_end_gpu])

            shifted_start = group_end_gpu + needed_count + 1
            for i, work in enumerate(work_to_push):
                new_micro_batches[shifted_start + i] = work
                new_exec_times[shifted_start + i] = exec_times_to_push[i]
                new_sample_ids_per_gpu[shifted_start + i] = sample_ids_to_push[i]

            group_size[gid] = next_power
            group_members[gid] = list(range(members[0], members[-1] + needed_count + 1))
            for pushed_gid in group_size:
                if pushed_gid > gid:
                    group_members[pushed_gid] = [
                        x + needed_count for x in group_members[pushed_gid]
                    ]

            return (
                new_micro_batches,
                new_exec_times,
                new_sample_ids_per_gpu,
                group_members,
                group_size,
            )

    while any(not micro_batches[i] for i in range(total_gpus)):
        micro_batches, exec_times, sample_ids_per_gpu, group_members, group_size = fill_empty_gpus(
            micro_batches, exec_times, sample_ids_per_gpu, group_members, group_size
        )

    total_work_after = sum(len(mb) for mb in micro_batches)
    assert (
        total_work_after >= total_work_before
    ), f"Samples were removed: {total_work_before} -> {total_work_after}"

    return micro_batches, leftovers, exec_times, sample_ids_per_gpu
micro_batches, leftovers, exec_times, sample_ids_per_gpu + + +def align_sample_id_groups(sample_id_groups: List, microbatch_group_size_per_vp_stage: int) -> List: + """Align len(sample_id_groups) to microbatch_group_size_per_vp_stage when VPP is enabled. + + Standalone version extracted from DefaultDynamicCPScheduler. + """ + multiple = int(microbatch_group_size_per_vp_stage) + remainder = (-len(sample_id_groups)) % multiple + i = len(sample_id_groups) - 1 + + def split_group(sample_id_group): + total_hdp_ranks = len(sample_id_group) + cu_ranks = [0] + prev_cp_size = 0 + + while cu_ranks[-1] != total_hdp_ranks: + start_rank = cu_ranks[-1] + sid0 = sample_id_group[start_rank][0] + cp_size = 0 + for r in range(start_rank, total_hdp_ranks): + if sid0 in sample_id_group[r]: + cp_size += 1 + else: + break + assert ( + prev_cp_size == 0 or cp_size <= prev_cp_size + ), f"split_group: CP size is not decreasing: prev={prev_cp_size}, cur={cp_size}" + cu_ranks.append(start_rank + cp_size) + prev_cp_size = cp_size + if len(cu_ranks) == 2: + return None, None + + k = 0 + while cu_ranks[k] < total_hdp_ranks // 2: + k += 1 + + old_mb = sample_id_group[: cu_ranks[k]] + [[] for _ in range(total_hdp_ranks - cu_ranks[k])] + new_mb = sample_id_group[cu_ranks[k] :] + [[] for _ in range(cu_ranks[k])] + old_mb = fill_empty_by_expanding_cp(old_mb) + new_mb = fill_empty_by_expanding_cp(new_mb) + return new_mb, old_mb + + def fill_empty_by_expanding_cp(sample_id_group): + def fill_empty(sample_id_group): + empty_size = sum(1 for x in sample_id_group if len(x) == 0) + i = len(sample_id_group) - 1 - empty_size + prev_cp_size = 0 + while i >= 0: + sid0 = sample_id_group[i][0] + cp_size = 0 + while sid0 in sample_id_group[i] and i >= 0: + cp_size += 1 + i -= 1 + if cp_size > prev_cp_size and prev_cp_size != 0: + start_idx = i + 1 + cp_size + end_idx = -empty_size + prev_cp_size if -empty_size + prev_cp_size < 0 else None + sample_id_group[start_idx + 2 * prev_cp_size : end_idx] = 
sample_id_group[ + start_idx + prev_cp_size : -empty_size + ] + sample_id_group[start_idx + prev_cp_size : start_idx + 2 * prev_cp_size] = ( + sample_id_group[start_idx : start_idx + prev_cp_size] + ) + break + elif cp_size <= empty_size and i == -1: + end_idx = -empty_size + cp_size if -empty_size + cp_size < 0 else None + sample_id_group[2 * cp_size : end_idx] = sample_id_group[cp_size:-empty_size] + sample_id_group[cp_size : 2 * cp_size] = sample_id_group[0:cp_size] + break + prev_cp_size = cp_size + return sample_id_group + + while len(sample_id_group[-1]) == 0: + sample_id_group = fill_empty(sample_id_group) + return sample_id_group + + attempts_since_split = 0 + while remainder > 0: + if i < 0: + if attempts_since_split >= len(sample_id_groups): + assert False, 'align_sample_id_groups: no tail microbatch has enough ids to split' + i = len(sample_id_groups) - 1 + group1, group2 = split_group(sample_id_groups[i]) + if group1 is not None and group2 is not None: + sample_id_groups[i] = group1 + sample_id_groups.append(group2) + remainder -= 1 + attempts_since_split = 0 + else: + attempts_since_split += 1 + i -= 1 + + return sample_id_groups + + +# ============================================================================= +# Workload estimation helpers for dynamic CP scheduling +# ============================================================================= + + +@lru_cache(maxsize=128) +def dcp_gpus_needed( + seq_len: int, max_seq_len_per_rank: int, min_cp_size: int = 1, max_cp_size: Optional[int] = None +) -> int: + """Number of GPUs needed, rounded up to the next power of 2, clamped to [min_cp_size, max_cp_size].""" + raw = max(1, 2 ** ceil(log2(seq_len / max_seq_len_per_rank))) + clamped = max(min_cp_size, raw) + if max_cp_size is not None: + clamped = min(clamped, max_cp_size) + return clamped + + +@lru_cache(maxsize=128) +def dcp_get_total_workload( + seq_length: int, + max_seq_len_per_rank: int, + cp_size: Optional[int] = None, + min_cp_size: int = 1, + 
max_cp_size: Optional[int] = None, +) -> float: + """Estimate workload of a sub-sample for scheduling balance.""" + if cp_size is None: + cp_size = dcp_gpus_needed(seq_length, max_seq_len_per_rank, min_cp_size, max_cp_size) + return (seq_length * seq_length) / cp_size + + +def dcp_make_buckets_equal( + sample_seqlens: List[Tuple[int, int]], + compute_estimator: Callable, + max_seq_len_per_rank: int, + min_cp_size: int = 1, + max_cp_size: Optional[int] = None, +) -> List[deque]: + """Split samples into buckets of roughly equal work, one per unique CP size.""" + seqlens = [seq_len for _, seq_len in sample_seqlens] + k = len({dcp_gpus_needed(L, max_seq_len_per_rank, min_cp_size, max_cp_size) for L in seqlens}) + + work = [] + for _, s in sample_seqlens: + cp_size = dcp_gpus_needed(s, max_seq_len_per_rank, min_cp_size, max_cp_size) + work.append(compute_estimator(s, cp_size)) + total_work = sum(work) + target = total_work / k + buckets, cur, cur_work = [], [], 0.0 + remaining_k = k + + for i, (sample_id, seq_len) in enumerate(sample_seqlens): + w = compute_estimator(seq_len) + projected = cur_work + w + if cur and ( + projected > target * 1.1 or len(sample_seqlens) - i <= remaining_k - len(buckets) + ): + buckets.append(deque(cur)) + cur, cur_work = [], 0.0 + remaining_k -= 1 + cur.append((sample_id, seq_len)) + cur_work += w + + if cur: + buckets.append(deque(cur)) + return buckets diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index 943a72c531f..afc4b24ff0b 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -1371,6 +1371,8 @@ def forward( packed_seq_params.cp_group is not None ), "cp_group is not set in packed_seq_params for dynamic CP" self.cp_group = packed_seq_params.cp_group + if TEDotProductAttention.cp_stream is None: + TEDotProductAttention.cp_stream = torch.cuda.Stream() super().set_context_parallel_group( self.cp_group, 
torch.distributed.get_process_group_ranks(self.cp_group), diff --git a/megatron/core/model_parallel_config.py b/megatron/core/model_parallel_config.py index 075aa75c76a..9a015d07b1a 100644 --- a/megatron/core/model_parallel_config.py +++ b/megatron/core/model_parallel_config.py @@ -72,13 +72,18 @@ class ModelParallelConfig: Please set max_seqlen_per_dp_cp_rank when using dynamic_context_parallel. """ + min_dynamic_context_parallel_size: int = 1 + """Minimum CP group size for dynamic context parallel. Default 1 (no CP). + The maximum is always context_parallel_size.""" + hybrid_context_parallel: bool = False """Deprecated. Use ``dynamic_context_parallel`` instead.""" - sequence_packing_scheduler: Optional[Literal['dp_balanced']] = None + sequence_packing_scheduler: Optional[Literal['dp_balanced', 'default_dynamic_cp']] = None """ Scheduler for sequence packing and dynamic context parallel. dp_balanced: DP-balanced scheduler for sequence packing. + default_dynamic_cp: Dynamic-CP scheduler for packed sequence balancing. """ expert_model_parallel_size: int = 1 @@ -428,6 +433,37 @@ def __post_init__(self): ) self.dynamic_context_parallel = True + if self.dynamic_context_parallel: + if self.sequence_packing_scheduler is None: + self.sequence_packing_scheduler = 'default_dynamic_cp' + if self.sequence_packing_scheduler != 'default_dynamic_cp': + raise ValueError( + 'Dynamic context parallelism requires ' + 'sequence_packing_scheduler=default_dynamic_cp' + ) + + if self.min_dynamic_context_parallel_size < 1: + raise ValueError( + f"min_dynamic_context_parallel_size must be >= 1, " + f"got {self.min_dynamic_context_parallel_size}" + ) + + if self.min_dynamic_context_parallel_size > self.context_parallel_size: + raise ValueError( + f"min_dynamic_context_parallel_size ({self.min_dynamic_context_parallel_size}) " + f"must be <= context_parallel_size ({self.context_parallel_size}), " + f"since context_parallel_size is the maximum dynamic CP group size." 
+ ) + + if self.min_dynamic_context_parallel_size > 1: + warnings.warn( + f"min_dynamic_context_parallel_size is set to {self.min_dynamic_context_parallel_size}. " + f"Dynamic CP groups will range from {self.min_dynamic_context_parallel_size} " + f"to {self.context_parallel_size} (context_parallel_size). " + f"This may cause padding overhead for short sequences.", + UserWarning, + ) + if self.sequence_parallel: if self.tensor_model_parallel_size <= 1: raise ValueError("Cannot use sequence parallelism without tensor parallelism") diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py index a0e1b392b43..8ea8556b177 100644 --- a/megatron/core/parallel_state.py +++ b/megatron/core/parallel_state.py @@ -421,16 +421,20 @@ def create_hierarchical_groups( return hierarchical_groups, hierarchical_groups_gloo -def create_dynamic_dp_cp_groups(rank, ranks, pg_options): +def create_dynamic_dp_cp_groups(rank, ranks, pg_options, min_cp_size=1, max_cp_size=None): """ Creates groups required for dynamic DPxCP. - Creates a new group for every power of 2 up to the number of DPxCP ranks. + Creates a new group for every power of 2 from min_cp_size up to max_cp_size. + max_cp_size defaults to len(ranks) (the full DPxCP group size). Returns a dictionary indexed by group size. """ + if max_cp_size is None: + max_cp_size = len(ranks) dynamic_dp_cp_groups = {} - # Generate group for every power of 2 up to the number of CP ranks - # We limit the allowed group sizes in order to avoid excessive overhead. 
- group_sizes = [2**i for i in range(int(log2(len(ranks))))] + group_sizes = [ + 2**i for i in range(int(log2(len(ranks)))) + if 2**i >= min_cp_size and 2**i <= max_cp_size + ] for group_size in group_sizes: for i in range(0, len(ranks), group_size): group = create_group( @@ -556,6 +560,7 @@ def initialize_model_parallel( context_parallel_size: int = 1, hierarchical_context_parallel_sizes: Optional[List[int]] = None, dynamic_context_parallel: bool = False, + min_dynamic_context_parallel_size: int = 1, expert_model_parallel_size: int = 1, num_distributed_optimizer_instances: int = 1, expert_tensor_parallel_size: Optional[int] = None, @@ -946,16 +951,21 @@ def initialize_model_parallel( ), "Dynamic context parallel requires an even number of ranks" _DYNAMIC_DP_CP_GROUPS.update( create_dynamic_dp_cp_groups( - rank, ranks_with_cp, get_nccl_options("dp_cp", nccl_comm_cfgs) + rank, + ranks_with_cp, + get_nccl_options("dp_cp", nccl_comm_cfgs), + min_cp_size=min_dynamic_context_parallel_size, + max_cp_size=context_parallel_size, ) ) - # PyTorch is performing lazy initialization of the communicator group. - # Therefore, we need to perform a nccl call to ensure that the communicator group is created. 
data_parallel_size_with_cp = data_parallel_size * context_parallel_size - group_sizes = [2**i for i in range(0, int(log2(data_parallel_size_with_cp)))] - if group_sizes[-1] * 2 == data_parallel_size_with_cp: - group_sizes.append(data_parallel_size_with_cp) + group_sizes = [ + 2**i for i in range(int(log2(data_parallel_size_with_cp))) + if 2**i >= min_dynamic_context_parallel_size and 2**i <= context_parallel_size + ] + if context_parallel_size == data_parallel_size_with_cp: + group_sizes.append(context_parallel_size) for group_size in group_sizes: group = get_dynamic_data_context_parallel_groups(group_size=group_size) torch.distributed.barrier(group=group, device_ids=[torch.cuda.current_device()]) @@ -2101,6 +2111,9 @@ def destroy_model_parallel(): global _CONTEXT_PARALLEL_GLOBAL_RANKS _CONTEXT_PARALLEL_GLOBAL_RANKS = None + global _DYNAMIC_DP_CP_GROUPS + _DYNAMIC_DP_CP_GROUPS = {} + global _EMBEDDING_GROUP _EMBEDDING_GROUP = None diff --git a/megatron/core/pipeline_parallel/dynamic_cp_schedule.py b/megatron/core/pipeline_parallel/dynamic_cp_schedule.py deleted file mode 100644 index 48dd633aeba..00000000000 --- a/megatron/core/pipeline_parallel/dynamic_cp_schedule.py +++ /dev/null @@ -1,660 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. - -from collections import deque -from functools import lru_cache -from math import ceil, log2 -from typing import Callable, List, Optional, Tuple - -import torch - -from megatron.core import parallel_state -from megatron.core.rerun_state_machine import RerunDataIterator - - -class BalancedCPScheduler: - """ - This class provides the functionality to form groups of sub-samples - such that all DPxCP ranks have a roughly balanced workload in the group. 
- """ - - def __init__(self, max_seq_len_per_rank: int, dp_cp_group: torch.distributed.ProcessGroup): - self.max_seq_len_per_rank = max_seq_len_per_rank - self.num_subsamples = 0 - self.num_subsamples_processed = 0 - self.free_resources = [] - self.total_hdp_gpus = dp_cp_group.size() - - @lru_cache(maxsize=128) - def get_total_workload(self, seq_length: int, cp_size: Optional[int] = None): - """ - seq_length: sequence length of a sub-sample - cp_size: total number of CP ranks working on this sub-sample - - Note: - This function is used to estimate the relative workload intensity - of a sub-sample. This is not meant to be an accurate flops calculator. - - Returns: workload of a sub-sample - """ - if cp_size is None: - cp_size = self.gpus_needed(seq_length) - return (seq_length * seq_length) / cp_size - - @lru_cache(maxsize=128) - def gpus_needed(self, seq_len: int) -> int: - """ - Calculates the number of GPUs needed for a given sequence length - and max sequence length per CP rank. - This is used to determine the CP size of a sub-sample. - - The number is rounded up to the next power of 2 to match the available - dynamic context parallel process group sizes. - """ - return max(1, 2 ** ceil(log2((seq_len / self.max_seq_len_per_rank)))) - - def make_buckets_equal( - self, - sample_seqlens: List[Tuple[int, int]], # List of (sample_id, sequence_length) tuples - compute_estimator: Callable[[int], float], - ) -> List[deque]: - """ - Makes as many buckets as unique CP sizes needed. - This keeps sample IDs tethered to their sequence lengths throughout the bucketing process. 
- """ - # Extract just the sequence lengths for determining k - seqlens = [seq_len for _, seq_len in sample_seqlens] - - # Determine k based on unique GPU categories needed - k = len({self.gpus_needed(L) for L in seqlens}) - - # Create a work target for each bucket - # This is the total work divided by the number of buckets - work = [] - for _, s in sample_seqlens: - cp_size = self.gpus_needed(s) - work.append(compute_estimator(s, cp_size)) - total_work = sum(work) - target = total_work / k - buckets, cur, cur_work = [], [], 0.0 - remaining_work = total_work - remaining_k = k - - for i, (sample_id, seq_len) in enumerate(sample_seqlens): - work = compute_estimator(seq_len) - projected = cur_work + work - - # Check if we should close this bucket - if cur and ( - projected > target * 1.1 # Too much work - or len(sample_seqlens) - i <= remaining_k - len(buckets) - ): # Need to save sequences for remaining buckets - buckets.append(deque(cur)) - cur, cur_work = [], 0.0 - remaining_work -= sum(compute_estimator(seq_len) for _, seq_len in cur) - remaining_k -= 1 - - cur.append((sample_id, seq_len)) - cur_work += work - - if cur: - buckets.append(deque(cur)) - - return buckets - - def next_hdp_group( - self, - sample_seqlens: List[Tuple[int, int]], # List of (sample_id, sequence_length) tuples - compute_estimator: Callable[[int], float], - total_gpus: int, - delta: float = 0.05, # balance slack (e.g. 5 %) - strategy: str = "dp", # "dp" or "pp" - eps_bucket: float = 0.10, # ε target for bucket balance - ) -> Tuple[List[List[int]], List[Tuple[int, int]], List[float], List[List[int]]]: - """ - Given a list of (sample_id, sequence_length) tuples, this function aims to assign - sequences in a group such that all GPUs in the DPxCP group have a roughly balanced - workload. Once each group is roughly balanced, we exit and return the - group and the leftover sequences. - - The function performs the following passes in order to form a balanced microbatch: - 1. 
We create buckets of sequences that are roughly balanced. - We try to create as many buckets as possible CP sizes. - 2. Given a bucket has sequences available, we assign the sample - a. To a new set of GPUs if there are enough free GPUs. - b. To an existing set of GPUs with the lowest load. - 3. We check if the group is balanced whenever we need to move onto a new CP size - in the same set of GPUs. - 4. We trim the group if removing the last added sequence helps improve balance. - 5. If we run out of sequences to assign and there are empty GPUs, - we redistribute work to empty GPUs by recursively increasing the CP size of a - sample until no empty GPUs are left. - - Returns (micro_batches, leftover_sample_seqlens, exec_times, sample_ids_per_gpu). - """ - if not sample_seqlens: - return ( - [[] for _ in range(total_gpus)], - [], - [0.0 for _ in range(total_gpus)], - [[] for _ in range(total_gpus)], - ) - - # Get buckets of sequences with balanced work - buckets = self.make_buckets_equal(sample_seqlens, compute_estimator) - - # Initialize tracking structures - micro_batches = [[] for _ in range(total_gpus)] - exec_times = [0.0 for _ in range(total_gpus)] - sample_ids_per_gpu = [[] for _ in range(total_gpus)] - - gpu_group_id = [None] * total_gpus - group_members = {} - group_size = {} - next_gid = 0 - - pp_cursor = 0 - prev_needed = None - check_balance = False - - while buckets: - # ---- Step 1 – pick the next sequence we COULD place ------------------ - sample_seq_tuple = bucket_idx = None - needed = None - - scan_order = ( - range(len(buckets)) - if strategy == "dp" - else [(pp_cursor + i) % len(buckets) for i in range(len(buckets))] - ) - - for idx in scan_order: - if not buckets[idx]: - continue - cand_tuple = buckets[idx][0] # This is now (sample_id, seq_len) - cand_seq_len = cand_tuple[1] - needed = self.gpus_needed(cand_seq_len) - - # (a) Do we have an *existing* group of size `needed`? 
- candidate_gids = [gid for gid, sz in group_size.items() if sz == needed] - - # (b) Or enough completely free GPUs to start a new group? - free_ranks = [r for r, gid in enumerate(gpu_group_id) if gid is None] - if candidate_gids or len(free_ranks) >= needed: - sample_seq_tuple, bucket_idx = cand_tuple, idx - break - - # No place to put any remaining sequence – finish this micro‑batch - if sample_seq_tuple is None: - break - - # TODO[pmannan]: PP not yet supported. Add PP scheduling. - if strategy == "pp": - pp_cursor = (bucket_idx + 1) % len(buckets) - - sample_id, seq_len = sample_seq_tuple - needed = self.gpus_needed(seq_len) - if prev_needed is None: - prev_needed = needed - - # (a) Existing groups of exactly this size - candidate_gids = [gid for gid, sz in group_size.items() if sz == needed] - if candidate_gids: - best_gid, best_load = min( - ( - (gid, max(exec_times[r] for r in group_members[gid])) - for gid in candidate_gids - ), - key=lambda t: t[1], - ) - else: - best_gid, best_load = None, float("inf") - - # (b) Hypothetical **new** group from completely free GPUs - free_ranks = [r for r, gid in enumerate(gpu_group_id) if gid is None] - if len(free_ranks) >= needed: - free_sorted = sorted(free_ranks, key=lambda r: exec_times[r]) - new_members = free_sorted[:needed] - new_load = exec_times[new_members[-1]] - - if new_load < best_load: - best_gid = None - chosen_members = new_members - else: - chosen_members = group_members[best_gid] - else: - chosen_members = group_members[best_gid] - - # ---- Step 2 – if we decided to create a fresh group ---------------- - if best_gid is None: - best_gid = next_gid - next_gid += 1 - group_members[best_gid] = chosen_members - group_size[best_gid] = needed - for r in chosen_members: - gpu_group_id[r] = best_gid - - # ---- Step 3 – assign the sequence to every member of that group ------ - per_gpu_cost = compute_estimator(seq_len) - - for r in chosen_members: - micro_batches[r].append(seq_len) - exec_times[r] += 
per_gpu_cost - sample_ids_per_gpu[r].append(sample_id) - - # Remove the sequence definitively from its bucket - buckets[bucket_idx].popleft() - - # ---- Step 4 – tidy, balance‑check, maybe early‑exit ------------------ - while buckets and not buckets[0]: - buckets.pop(0) - pp_cursor %= max(1, len(buckets)) - - # TODO: Removing this helps reduce the number of groups when we have - # lots of samples with same CP size. - # But because we don't exit as soon as we get balanced, - # even if there is one group available that can take the next sample, - # we will keep adding samples to the same group. - # trim_overload() does not help because it only checks if removing the - # last added sample helps. - # We cannot check after adding every sample because there will always be imbalance - # if we don't wait for future scheduling. - - # IMPORTANT: So we need a solution here - if needed < prev_needed: - # When we get into a lower CP size in the same group, - # we can start checking for balance. There is still a gotcha here. - # Let's say we have a group of 3 GPU 0-2, then we move onto group of 2. - # We keep assigning group of 2 as we do in descending order but GPU 7/15 - # never sees a microbatch assigned to it - # until we run out of samples with CP2. - # This means we are never balanced as min(exec_times) will always be 0. - # We need a smart way of identifying that we have run out of big samples - # and if we are having to assign work to a GPU already working, - # is it because there are empty GPUs? - # Would assigning work to empty GPUs first by moving onto next CP bucket help? - # But we need to remember to come back to this CP size bucket and then - # check for balance. Maybe the scheduling algorithm should look at empty - # GPUs and find work rather than going sequence by sequence. 
- check_balance = True - - if ( - check_balance - and buckets - and max(exec_times) - min(exec_times) <= delta * max(exec_times) - ): - break - - # Gather leftovers (flatten remaining buckets, preserve order) - leftovers = [] - for b in buckets: - for sample_seq_tuple in b: - leftovers.append(sample_seq_tuple) - - # --------------------------------------------------------------------------- - def trim_overload(): - """ - Iteratively pop the most‑recent sequence from the *most‑loaded group* - whenever doing so reduces the global slack. - """ - while True: - cur_max = max(exec_times) - cur_min = min(exec_times) - cur_slack = cur_max - cur_min - if cur_slack <= delta * cur_max: - # Slack is already within limit. - break - if cur_min == 0: - # There are empty GPUs that will be - # handled in the next step. - break - - max_r = exec_times.index(cur_max) - gid = gpu_group_id[max_r] - members = group_members[gid] - - if not micro_batches[max_r] or len(micro_batches[max_r]) <= 1: - break - - seq = micro_batches[max_r][-1] - need = group_size[gid] - per_gpu_cost = compute_estimator(seq) - - proj_times = exec_times[:] - for r in members: - proj_times[r] -= per_gpu_cost - - proj_slack = max(proj_times) - min(proj_times) - - # Check if trimming the workload helps imbalance - if proj_slack < cur_slack: - sample_id_to_remove = sample_ids_per_gpu[max_r][-1] - for r in members: - micro_batches[r].pop() - exec_times[r] -= per_gpu_cost - sample_ids_per_gpu[r].pop() - leftovers.append((sample_id_to_remove, seq)) - else: - break - - trim_overload() - - # Track samples in this group before redistribution to empty GPUs - total_work_before = sum(len(mb) for mb in micro_batches) - - # Check for empty GPUs and redistribute work - def fill_empty_gpus( - micro_batches, exec_times, sample_ids_per_gpu, group_members, group_size - ): - """ - Recursively check for empty GPUs and redistribute work by increasing - the number of GPUs sharing samples. This ensures all GPUs have work. 
- GPUs must be allocated consecutively so we may need to push existing - work to other ranks in order to expand samples. - """ - # Find empty GPUs - empty_gpus = [i for i in range(total_gpus) if not micro_batches[i]] - if not empty_gpus: - return ( - micro_batches, - exec_times, - sample_ids_per_gpu, - group_members, - group_size, - ) # No empty GPUs, we're done - - # Find the smallest group size that exists - existing_group_sizes = set(group_size.values()) - assert ( - existing_group_sizes - ), "There should be at least one group existing, cannot reditribute, " - "try to increase 'max-seqlen-per-cp-rank'." - - min_group_size = min(existing_group_sizes) - # We have Dynamic DPxCP groups for every power of 2 of GPUs or the entire DPxCP group. - next_power = min(min_group_size * 2, total_gpus) - - # Find the first group of min_group_size that can be expanded - expandable_gid = None - expandable_members = None - expandable_new_gpus = None - - for gid, size in group_size.items(): - if size == min_group_size: - members = group_members[gid] - needed_count = next_power - min_group_size - group_start_gpu = members[0] - group_end_gpu = members[-1] - empty_gpu = [idx for idx, work in enumerate(micro_batches) if not work][0] - assert not all( - work for work in micro_batches[empty_gpu : empty_gpu + needed_count] - ), f"Empty GPUs were detected but not enough to expand." 
- work_to_push = micro_batches[ - group_end_gpu + 1 : empty_gpu - ] # This is work of all other subsequent sub-samples - exec_times_to_push = exec_times[group_end_gpu + 1 : empty_gpu] - sample_ids_to_push = sample_ids_per_gpu[group_end_gpu + 1 : empty_gpu] - - new_micro_batches = [[]] * len(micro_batches) - new_exec_times = [0.0] * len(exec_times) - new_sample_ids_per_gpu = [[]] * len(sample_ids_per_gpu) - - # No change in work until the group selected for expansion - for i in range(group_start_gpu): - new_micro_batches[i] = micro_batches[i] - new_exec_times[i] = exec_times[i] - new_sample_ids_per_gpu[i] = sample_ids_per_gpu[i] - - # The work is distributed across the expanded group - for i in range(group_start_gpu, group_end_gpu + needed_count + 1): - new_micro_batches[i] = micro_batches[group_end_gpu] - new_exec_times[i] = self.get_total_workload( - micro_batches[group_end_gpu][0], next_power - ) - new_sample_ids_per_gpu[i] = sample_ids_per_gpu[group_end_gpu] - - # Any assigned work on expanded GPUs is pushed - for i, work in enumerate(work_to_push): - new_micro_batches[group_end_gpu + needed_count + 1 + i] = work - new_exec_times[group_end_gpu + needed_count + 1 + i] = exec_times_to_push[i] - new_sample_ids_per_gpu[group_end_gpu + needed_count + 1 + i] = ( - sample_ids_to_push[i] - ) - - group_size[gid] = next_power - group_members[gid] = list(range(members[0], members[-1] + needed_count + 1)) - for pushed_gid in group_size.keys(): - if pushed_gid > gid: - group_members[pushed_gid] = [ - x + needed_count for x in group_members[pushed_gid] - ] - - return ( - new_micro_batches, - new_exec_times, - new_sample_ids_per_gpu, - group_members, - group_size, - ) - - empty_gpus = any([not micro_batches[i] for i in range(total_gpus)]) - while empty_gpus: - micro_batches, exec_times, sample_ids_per_gpu, group_members, group_size = ( - fill_empty_gpus( - micro_batches, exec_times, sample_ids_per_gpu, group_members, group_size - ) - ) - empty_gpus = any([not micro_batches[i] 
for i in range(total_gpus)]) - - # Assert that no sample has been completely removed - total_work_after = sum(len(mb) for mb in micro_batches) - assert ( - total_work_after >= total_work_before - ), f"Samples were removed: {total_work_before} -> {total_work_after}" - - return micro_batches, leftovers, exec_times, sample_ids_per_gpu - - def get_groups_and_subsamples(self, sample_id_seqlens, config): - """ - This function recursively forms groups of sub-samples such that all DPxCP ranks - have a roughly balanced workload in the group. - """ - groups = [] - sample_id_groups = [] - # We assign a sample_id to each sub-sample in order to track assignment to each GPU. - sample_id_seqlens = sorted(sample_id_seqlens, key=lambda x: x[1], reverse=True) - while sample_id_seqlens: - mb, sample_id_seqlens, exec_times, sample_ids = self.next_hdp_group( - sample_id_seqlens, self.get_total_workload, self.total_hdp_gpus - ) - groups.append(mb) - if len(sample_ids) < self.total_hdp_gpus: - sample_ids.extend([] * (self.total_hdp_gpus - len(sample_ids))) - sample_id_groups.append(sample_ids) - - return groups, sample_id_groups - - -def dynamic_context_parallel_forward_backward( - forward_step_func, - data_iterator, - model, - num_microbatches, - input_tensor, - output_tensor_grad, - forward_data_store, - config, - collect_non_loss_data, - first_val_step, - forward_only, - no_sync_func, - total_num_tokens, - check_first_val_step, - model_type, -): - """ - Scheduler for Dynamic Context Parallel. - - This function performs the packed sample scheduling and determines - 1. The number of microbatches to schedule for each CP rank - 2. The number of groups each CP rank should execute - 3. The number of sub-samples per group each CP rank should execute - - A group is defined by a set of samples that can run across the CP domain without any barrier. - There are many reasons why we may not be able to run endless samples within a single group. 
- For example, if we have 8 GPUs, - if GPU 0-5 are assigned a long sample that requires CP6, - GPU 6-7 are assigned a short sample that requires CP2, - The next sample which requires CP4 can be assigned GPU 4-7. - But GPU 6-7 will finish first and get deadlocked if GPU 4-5 are not participating in the group. - """ - from .schedules import backward_step, forward_step - - def _broadcast(item): - if item is not None: - torch.distributed.broadcast( - item, - parallel_state.get_tensor_model_parallel_src_rank(), - group=parallel_state.get_tensor_model_parallel_group(), - ) - - def _broadcast_num_samples_this_group(num_samples_this_group): - dev = torch.cuda.current_device() - torch.distributed.barrier() - - n = 0 if num_samples_this_group is None else int(num_samples_this_group.numel()) - n = torch.tensor([n], dtype=torch.int64, device=dev) - - _broadcast(n) - n = int(n.item()) - - assert n > 0, "there should be at least 1 sub samples in the group" - num_samples_this_group_broadcast = ( - torch.empty(n, dtype=torch.int32, device=dev) - if num_samples_this_group is None - else num_samples_this_group - ) - _broadcast(num_samples_this_group_broadcast) - return num_samples_this_group_broadcast - - def _get_new_data_iterator(sample_id_in_group, group_id): - if is_first_tp_rank: - sub_sample_id = sample_ids_this_group[sample_id_in_group] - sample = batch[sub_sample_id] - partner_cp_size = len( - [True for sample_ids in sample_id_groups[group_id] if sub_sample_id in sample_ids] - ) - sample["local_cp_size"] = torch.tensor(partner_cp_size, dtype=torch.int32) - new_data_iterator = RerunDataIterator(iter([sample])) - return new_data_iterator - else: - return None - - # We get data once per global batch and schedule the sub-samples. - # TODO(pmannan): Should we wrap the data_iterator here instead of the training.py file? 
- hdp_rank = parallel_state.get_data_parallel_rank(with_context_parallel=True) - is_first_tp_rank = parallel_state.get_tensor_model_parallel_rank() == 0 - - if is_first_tp_rank: - data = next(data_iterator) - sample_id_groups = data[1] - batch = data[0] - else: - data, sample_id_groups, batch = None, None, None - - num_samples_this_group = None - if is_first_tp_rank: - num_samples_this_group = torch.tensor( - [len(group[hdp_rank]) for group in sample_id_groups], dtype=torch.int32, device='cuda' - ) - - num_samples_this_group = _broadcast_num_samples_this_group(num_samples_this_group) - num_samples_this_group = num_samples_this_group.cpu().numpy() - num_total_groups = num_samples_this_group.shape[0] - - current_microbatch = 0 - - # Upto last group, we don't need any sync. - with no_sync_func(): - for j in range(num_total_groups - 1): - sample_ids_this_group = sample_id_groups[j][hdp_rank] if is_first_tp_rank else None - for i in range(num_samples_this_group[j]): - # Call forward step for each sub-sample - new_data_iterator = _get_new_data_iterator(i, j) - # TODO: Find the usage of current_microbatch and is_first_microbatch and - # how that may affect my usage. - output_tensor, num_tokens = forward_step( - forward_step_func, - new_data_iterator, - model, - num_microbatches, - input_tensor, - forward_data_store, - config, - collect_non_loss_data, - is_first_microbatch=check_first_val_step( - first_val_step, forward_only, current_microbatch == 0 - ), - current_microbatch=current_microbatch, - ) - current_microbatch += 1 - total_num_tokens += num_tokens.item() - if not forward_only: - backward_step( - input_tensor, output_tensor, output_tensor_grad, model_type, config - ) - - # Create a barrier at end of each group. - # This barrier ensures that all ranks are prepared to change assigned CP group sizes and - # no rank is starting a sub-sample ahead of it's partner ranks. 
- torch.distributed.barrier( - parallel_state.get_data_parallel_group(with_context_parallel=True) - ) - - # For the last group, we need to run the last sub-sample out of the context handler. - with no_sync_func(): - sample_ids_this_group = sample_id_groups[-1][hdp_rank] if is_first_tp_rank else None - for i in range(num_samples_this_group[-1] - 1): - new_data_iterator = _get_new_data_iterator(i, -1) - # Call forward step for each sub-sample - output_tensor, num_tokens = forward_step( - forward_step_func, - new_data_iterator, - model, - num_microbatches, - input_tensor, - forward_data_store, - config, - collect_non_loss_data, - is_first_microbatch=check_first_val_step( - first_val_step, forward_only, current_microbatch == 0 - ), - current_microbatch=current_microbatch, - ) - current_microbatch += 1 - total_num_tokens += num_tokens.item() - if not forward_only: - backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config) - - # The last sub-sample of the last group of the last microbatch is - # run out of the context handler. 
- new_data_iterator = _get_new_data_iterator(-1, -1) - # Call forward step for each sub-sample - output_tensor, num_tokens = forward_step( - forward_step_func, - new_data_iterator, - model, - num_microbatches, - input_tensor, - forward_data_store, - config, - collect_non_loss_data, - is_first_microbatch=check_first_val_step( - first_val_step, forward_only, current_microbatch == 0 - ), - current_microbatch=current_microbatch, - ) - total_num_tokens += num_tokens.item() - if not forward_only: - backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config) - - return forward_data_store, total_num_tokens diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index ed3794208f0..d60c3cd9ad5 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -36,7 +36,6 @@ combined_1f1b_schedule_for_interleaved_pipelining, combined_1f1b_schedule_for_no_pipelining, ) -from .dynamic_cp_schedule import dynamic_context_parallel_forward_backward # Types Shape = Union[List[int], torch.Size] @@ -617,24 +616,6 @@ def forward_backward_no_pipelining( total_num_tokens, partial(check_first_val_step, first_val_step, forward_only), ) - elif config.dynamic_context_parallel: - forward_data_store, total_num_tokens = dynamic_context_parallel_forward_backward( - forward_step_func, - data_iterator, - model, - num_microbatches, - input_tensor, - output_tensor_grad, - forward_data_store, - config, - collect_non_loss_data, - first_val_step, - forward_only, - no_sync_func, - total_num_tokens, - check_first_val_step, - model_type, - ) else: with no_sync_func(): for i in range(num_microbatches - 1): diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index 3b054ccc4b1..e3ca98d13d1 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -1050,6 +1050,7 @@ def forward( out = output.transpose(0, 
1).contiguous() context_layer = out.view(out.size(0), out.size(1), -1) output, bias = self.linear_proj(context_layer) + self.pg_collection.cp = _orig_cp_group return output, bias if ( diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index d055b7d96cb..681ff0896ff 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -2178,7 +2178,7 @@ def __post_init__(self): f"got '{self.moe_token_dispatcher_type}'" ) - supported_schedulers = ['dp_balanced'] + supported_schedulers = ['dp_balanced', 'default_dynamic_cp'] if ( self.sequence_packing_scheduler is not None and self.sequence_packing_scheduler not in supported_schedulers diff --git a/megatron/core/utils.py b/megatron/core/utils.py index 14c783ab0dc..250c1680c7f 100644 --- a/megatron/core/utils.py +++ b/megatron/core/utils.py @@ -2140,57 +2140,6 @@ def get_thd_batch_on_this_cp_rank( return batch, packed_seq_params -################################ -### dynamic context parallel ### -################################ - - -def get_batch_on_this_dynamic_cp_rank( - batch: Dict[str, Any], - local_cp_size: int, - cp_group: Optional[torch.distributed.ProcessGroup] = None, -): - """Slice batch input along sequence dimension into multiple chunks, - which are parallelized across GPUs in a context parallel group. 
- """ - assert local_cp_size is not None - if cp_group is None: - # Get the local cp group required for as defined by the DynamicCPDataLoaderWrapper - cp_group = parallel_state.get_dynamic_data_context_parallel_groups(group_size=local_cp_size) - else: - # If cp group is provided, it must match the local cp size - # as defined by the DynamicCPDataLoaderWrapper - assert cp_group.size() == local_cp_size - - # Convert [seqlen] to [1, seqlen] similar to default collate_fn - # as dynamic_context_parallel dataloader wrapper does not go through default collate_fn - for key, data in batch.items(): - if key in ['attention_mask']: - continue - batch[key] = torch.stack([data], 0) - sample_length = batch['tokens'].shape[1] - # TODO(pmannan): Take care of padding tokens here if not divisible by cp_size*2 - # Create packed_seq_params for SBHD format with cp group information. - packed_seq_params = PackedSeqParams( - qkv_format="sbhd", - cu_seqlens_q=torch.tensor([0, sample_length], device="cuda", pin_memory=True), - cu_seqlens_kv=torch.tensor([0, sample_length], device="cuda", pin_memory=True), - cu_seqlens_q_padded=torch.tensor([0, sample_length], device="cuda", pin_memory=True), - cu_seqlens_kv_padded=torch.tensor([0, sample_length], device="cuda", pin_memory=True), - max_seqlen_q=sample_length, - max_seqlen_kv=sample_length, - local_cp_size=local_cp_size, - cp_group=cp_group, - ) - - if cp_group.size() > 1: - # When using dynamic_context_parallel, each sub-sample of a packed sample is - # required to be divisible by CP*DP*2 or CP*DP*TP*2 (if using sequence parallel) - batch = get_batch_on_this_cp_rank(batch, cp_group=cp_group) - - return batch, packed_seq_params - - ###################### ### NVTX profiling ### ###################### diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index eb91fa11cc0..440df865bb2 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1048,14 +1048,33 @@ def validate_args(args, 
defaults={}): assert args.sequence_parallel == True, 'Tensor parallel communication/GEMM overlap can happen only when sequence parallelism is enabled' if args.dynamic_context_parallel: - assert not args.pipeline_model_parallel_size > 1, 'Dynamic context parallelism not supported with pipeline parallelism' assert not args.enable_cuda_graph, 'Dynamic context parallelism not supported with CUDA Graph' assert not args.use_megatron_fsdp, 'Dynamic context parallelism not supported with Megatron FSDP' assert args.dataloader_type == 'single', 'Dynamic context parallelism only supported with single dataloader type' assert args.calculate_per_token_loss, 'Dynamic context parallelism must be used with --calculate-per-token-loss' + if args.sequence_packing_scheduler is None: + args.sequence_packing_scheduler = 'default_dynamic_cp' + if args.sequence_packing_scheduler != 'default_dynamic_cp': + raise ValueError( + 'Dynamic context parallelism requires ' + 'sequence_packing_scheduler=default_dynamic_cp' + ) + + import warnings + warnings.warn( + f"Dynamic CP enabled: context_parallel_size={args.context_parallel_size} " + f"will be used as the maximum dynamic CP group size. " + f"Dynamic CP groups will range from " + f"min_dynamic_context_parallel_size={args.min_dynamic_context_parallel_size} " + f"to {args.context_parallel_size}." 
+ ) if args.sequence_packing_scheduler is not None: - assert args.context_parallel_size * args.max_seqlen_per_dp_cp_rank >= args.seq_length, \ + if args.sequence_packing_scheduler == 'dp_balanced': + max_cp_size = args.context_parallel_size + else: + max_cp_size = args.data_parallel_size * args.context_parallel_size + assert max_cp_size * args.max_seqlen_per_dp_cp_rank >= args.seq_length, \ f'Packed sequence buffer size ({args.context_parallel_size * args.max_seqlen_per_dp_cp_rank}) ' \ f'must be >= single sequence max length ({args.seq_length})' diff --git a/megatron/training/datasets/data_samplers.py b/megatron/training/datasets/data_samplers.py index 166d4597a97..4c8e81ff3b2 100644 --- a/megatron/training/datasets/data_samplers.py +++ b/megatron/training/datasets/data_samplers.py @@ -39,22 +39,13 @@ def build_pretraining_data_loader(dataset, consumed_samples): data_parallel_size=mpu.get_data_parallel_world_size(), ) elif args.dataloader_type == 'single': - if args.dynamic_context_parallel: - batch_sampler = DynamicCPMegatronPretrainingSampler( - total_samples=len(dataset), - consumed_samples=consumed_samples, - micro_batch_size=args.micro_batch_size, - global_batch_size=args.global_batch_size, - data_parallel_rank=mpu.get_data_parallel_rank(), - data_parallel_size=mpu.get_data_parallel_world_size()) - else: - # Megatron sampler - batch_sampler = MegatronPretrainingSampler( - total_samples=len(dataset), - consumed_samples=consumed_samples, - micro_batch_size=args.micro_batch_size, - data_parallel_rank=mpu.get_data_parallel_rank(), - data_parallel_size=mpu.get_data_parallel_world_size()) + # Megatron sampler + batch_sampler = MegatronPretrainingSampler( + total_samples=len(dataset), + consumed_samples=consumed_samples, + micro_batch_size=args.micro_batch_size, + data_parallel_rank=mpu.get_data_parallel_rank(), + data_parallel_size=mpu.get_data_parallel_world_size()) elif args.dataloader_type == 'cyclic': batch_sampler = MegatronPretrainingRandomSampler( dataset, 
@@ -162,50 +153,6 @@ def __iter__(self): start_idx, end_idx = self.get_start_end_idx() yield batch[start_idx:end_idx] -class DynamicCPMegatronPretrainingSampler(MegatronPretrainingSampler): - """ - Data sampler for dynamic context parallel (Dynamic CP) format. - This data sampler pulls in the entire global batch at once across all data parallel ranks. - This helps provide the Dynamic CP Dataloader Wrapper to schedule and load balance sub-samples - of the entire global batch. - """ - - def __init__(self, total_samples, consumed_samples, micro_batch_size, global_batch_size, - data_parallel_rank, data_parallel_size, drop_last=True): - super().__init__(total_samples, consumed_samples, micro_batch_size, data_parallel_rank, data_parallel_size, drop_last) - self.global_batch_size = global_batch_size - self.data_parallel_size = data_parallel_size - self.num_micro_batches = self.global_batch_size // self.micro_batch_times_data_parallel_size - - def __len__(self): - return self.total_samples - - def get_start_end_idx_global_batch(self): - start_idx = [self.data_parallel_rank * self.micro_batch_size + i * self.micro_batch_size * self.data_parallel_size for i in range(self.num_micro_batches)] - end_idx = [start_idx[i] + self.micro_batch_size for i in range(self.num_micro_batches)] - return start_idx, end_idx - - def __iter__(self): - batch = [] - # Last batch will be dropped if drop_last is not set False - for idx in range(self.consumed_samples, self.total_samples): - batch.append(idx) - if len(batch) == self.micro_batch_times_data_parallel_size * self.num_micro_batches: - start_idx, end_idx = self.get_start_end_idx_global_batch() - global_batch_idx = [] - for i in range(self.num_micro_batches): - global_batch_idx.extend(batch[start_idx[i]:end_idx[i]]) - yield global_batch_idx - batch = [] - - # Check the last partial batch and see drop_last is set - if len(batch) > 0 and not self.drop_last: - start_idx, end_idx = self.get_start_end_idx_global_batch() - global_batch_idx = [] - 
for i in range(self.num_micro_batches): - global_batch_idx.extend(batch[start_idx[i]:end_idx[i]]) - yield global_batch_idx - class RandomSeedDataset(Dataset): """ A dataset wrapper that resets the random seed before each sample. diff --git a/megatron/training/initialize.py b/megatron/training/initialize.py index a5c757ca41b..289a8dec4c7 100644 --- a/megatron/training/initialize.py +++ b/megatron/training/initialize.py @@ -376,6 +376,7 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks, s context_parallel_size=args.context_parallel_size, hierarchical_context_parallel_sizes=args.hierarchical_context_parallel_sizes, dynamic_context_parallel=args.dynamic_context_parallel, + min_dynamic_context_parallel_size=args.min_dynamic_context_parallel_size, expert_model_parallel_size=args.expert_model_parallel_size, num_distributed_optimizer_instances=args.num_distributed_optimizer_instances, expert_tensor_parallel_size=args.expert_tensor_parallel_size, diff --git a/megatron/training/training.py b/megatron/training/training.py index c5715e96aed..b3506f4bc4e 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -142,7 +142,6 @@ def set_startup_timestamps(program_start=None, main_entry=None): from megatron.training.initialize import set_jit_fusion_options from megatron.training.utils import get_batch_on_this_cp_rank, get_batch_on_this_tp_rank from megatron.training.datasets.data_samplers import build_pretraining_data_loader -from megatron.core.datasets.data_schedule import DynamicCPDataLoaderWrapper from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler from megatron.core.transformer.moe import upcycling_utils from megatron.core.transformer.moe.moe_logging import get_moe_metrics_tracker @@ -2589,9 +2588,6 @@ def train( energy_monitor = get_energy_monitor() one_logger = get_one_logger() - if args.dynamic_context_parallel: - train_data_iterator = iter(DynamicCPDataLoaderWrapper(train_data_iterator, 
config)) - if args.run_workload_inspector_server: try: from workload_inspector.utils.webserver import run_server diff --git a/pretrain_gpt.py b/pretrain_gpt.py index 6ca303386ed..d80985db641 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -30,7 +30,7 @@ from megatron.core.models.gpt import GPTModel from megatron.core.rerun_state_machine import get_rerun_state_machine from megatron.core.tokenizers.utils.build_tokenizer import build_tokenizer -from megatron.core.utils import get_attr_wrapped_model, get_thd_batch_on_this_cp_rank, get_batch_on_this_dynamic_cp_rank, StragglerDetector +from megatron.core.utils import get_attr_wrapped_model, get_thd_batch_on_this_cp_rank, StragglerDetector from megatron.training import ( get_args, get_timers, @@ -75,6 +75,7 @@ def get_batch(data_iterator, vp_stage: Optional[int] = None): vpp_size=config.virtual_pipeline_model_parallel_size, mtp_on_this_rank=mtp_on_this_rank(config, ignore_virtual=False, vp_stage=vp_stage), vp_stage=vp_stage, + dynamic_cp=args.dynamic_context_parallel, ) # TODO: this is pretty hacky, find a better way @@ -91,19 +92,14 @@ def get_batch(data_iterator, vp_stage: Optional[int] = None): cu_seqlens = batch.pop('cu_seqlens', None) cu_seqlens_padded = batch.pop('cu_seqlens_padded', None) max_seqlen = batch.pop('max_seqlen', None) - local_cp_size = batch.pop('local_cp_size', None) - if local_cp_size is not None: - local_cp_size = int(local_cp_size.item()) - if cu_seqlens is None and local_cp_size is None: + if cu_seqlens is None: # slice batch along sequence dimension for context parallelism batch = get_batch_on_this_cp_rank(batch) # The implementation of this function is in MCore packed_seq_params = None - elif local_cp_size is None: # Packed THD format + else: # Packed THD format assert max_seqlen.dim() == 1 batch, packed_seq_params = get_thd_batch_on_this_cp_rank(batch, cu_seqlens, cu_seqlens_padded, max_seqlen) - else: # Dynamic CP format - batch, packed_seq_params = 
get_batch_on_this_dynamic_cp_rank(batch, local_cp_size) return (*batch.values(), packed_seq_params)