|
| 1 | +# Copyright (c) 2025 NVIDIA CORPORATION. All rights reserved. |
| 2 | + |
from typing import Any, List, Optional, Tuple

import torch

from megatron.core import parallel_state
from megatron.core.pipeline_parallel.hybrid_cp_schedule import BalancedCPScheduler
from megatron.core.process_groups_config import ProcessGroupCollection
| 10 | + |
| 11 | + |
class HybridCPDataLoaderWrapper:
    """
    A wrapper class that wraps around an existing data_iterator.

    For every __next__ call,
    1. Each DP rank pulls a batch of packed samples.
    2. Extracts the sequence lengths of each sub-sample and all-gathers across the DP group.
    3. Schedules the sub-samples to the DPxCP ranks using the BalancedCPScheduler.
    4. Based on the schedule, reroutes the sub-samples to the correct rank using all-to-all.
    5. Returns the assigned sub-samples to this rank.

    Args:
        data_iterator: The original data_iterator to wrap around.
        config: The config object containing the max_seqlen_per_dp_cp_rank.
        pg_collection: Optional collection providing dp_cp, dp and tp process
            groups. Falls back to megatron.core.parallel_state when None.
    """

    def __init__(
        self, data_iterator, config, pg_collection: Optional[ProcessGroupCollection] = None
    ):
        self.data_iterator = data_iterator
        self.config = config
        if pg_collection is None:
            self.dp_cp_group = parallel_state.get_data_parallel_group(with_context_parallel=True)
            self.dp_group = parallel_state.get_data_parallel_group()
            self.tp_group = parallel_state.get_tensor_model_parallel_group()
        else:
            self.dp_cp_group = pg_collection.dp_cp
            self.dp_group = pg_collection.dp
            self.tp_group = pg_collection.tp
        assert (
            self.dp_cp_group is not None and self.dp_group is not None and self.tp_group is not None
        ), "dp_cp_group, dp_group, tp_group must not be None when using hybrid context parallel"

        self.cp_balancing_scheduler = BalancedCPScheduler(
            max_seq_len_per_rank=self.config.max_seqlen_per_dp_cp_rank, dp_cp_group=self.dp_cp_group
        )

        # Total number of GPUs in the hybrid DPxCP (HDP) domain.
        self.total_hdp_gpus = self.dp_cp_group.size()

    def __iter__(self):
        """Return self as an iterator."""
        return self

    def get_global_seqlens(
        self, subsample_seqlens: torch.Tensor
    ) -> Tuple[List[int], torch.Tensor]:
        """
        Gathers the sequence lengths of all subsamples from all DP ranks.

        Each DP rank loads the same number of microbatches but each microbatch
        may have a different number of subsamples. We first gather the number
        of subsamples each rank holds, pad the local seqlens to the maximum
        count so all_gather sees equal shapes, then gather the seqlens.

        Args:
            subsample_seqlens: 1-D int32 CUDA tensor of local subsample lengths.

        Returns:
            seqlens_gathered: flat list of subsample lengths across all DP
                ranks, ordered by DP rank.
            offsets: int32 CPU tensor; offsets[r] is the global index of DP
                rank r's first subsample (used to assign unique global IDs).
        """
        # Collect the number of subsamples from all ranks.
        local_count = int(subsample_seqlens.shape[0])
        local_len = torch.tensor([local_count], dtype=torch.int32).cuda()
        dp_subsample_count = [torch.zeros_like(local_len) for _ in range(self.dp_group.size())]
        torch.distributed.all_gather(dp_subsample_count, local_len, group=self.dp_group)

        # Find the max number of subsamples across all ranks and pad
        # subsample_seqlens to that length.
        dp_subsample_counts = torch.stack(dp_subsample_count, dim=0).cpu().view(-1)
        max_sub_samples = int(dp_subsample_counts.max().item())

        if local_count < max_sub_samples:
            subsample_seqlens_padded = torch.cat(
                [
                    subsample_seqlens,
                    torch.zeros(max_sub_samples - local_count, dtype=torch.int32).cuda(),
                ],
                dim=0,
            )
        else:
            subsample_seqlens_padded = subsample_seqlens

        # Gather the (padded) subsample_seqlens from all ranks.
        seqlens_gathered = [
            torch.empty_like(subsample_seqlens_padded) for _ in range(self.dp_group.size())
        ]
        torch.distributed.all_gather(
            seqlens_gathered, subsample_seqlens_padded, group=self.dp_group
        )

        # Trim each gathered tensor back to that rank's true subsample count.
        for dp_rank, seqlen in enumerate(seqlens_gathered):
            seqlens_gathered[dp_rank] = seqlen[: dp_subsample_counts[dp_rank]]

        seqlens_gathered = torch.cat(seqlens_gathered, dim=0)
        seqlens_gathered = seqlens_gathered.cpu().tolist()

        # Exclusive prefix sum of per-rank counts: the offset at which each
        # DP rank's global subsample IDs start.
        csum = torch.cumsum(dp_subsample_counts, dim=0, dtype=torch.int32)
        offsets = torch.cat([torch.zeros(1, dtype=torch.int32), csum[:-1]], dim=0)

        return seqlens_gathered, offsets

    def get_global_id_seqlens(self, num_local_subsamples, offsets, seqlens_gathered):
        """
        Calculates the global ID for each subsample.

        We assign a unique global ID to each subsample: IDs are consecutive
        integers ordered by DP rank, with this rank's IDs starting at
        offsets[dp_rank].

        Returns:
            global_id_seqlens: list of (global_id, seqlen) tuples for scheduling.
            global_ids_this_rank: tensor of global IDs locally present on this rank.
        """
        dp_rank = self.dp_group.rank()
        global_ids = torch.arange(len(seqlens_gathered), dtype=torch.int32).cuda()
        # Create a list of (global_id, seqlen) tuples for scheduling.
        global_id_seqlens = [(i, seqlens_gathered[i]) for i in range(len(global_ids))]
        # Get the global IDs locally present on this rank.
        global_ids_this_rank = global_ids[
            offsets[dp_rank] : offsets[dp_rank] + num_local_subsamples
        ]

        return global_id_seqlens, global_ids_this_rank

    def _gid_to_src_rank(self, gid: int, offsets: torch.Tensor) -> int:
        """Map a global subsample ID to the HDP rank that holds its data.

        ``offsets[1:] - 1`` are the last global IDs owned by each DP rank, so
        bucketize yields the owning DP rank.
        """
        dp_src_rank = int(torch.bucketize(gid, offsets[1:] - 1))
        # torch.distributed.get_process_group_ranks provides global ranks, so
        # we divide by the TP size to express the rank within the HDP domain.
        hdp_rank = (
            torch.distributed.get_process_group_ranks(self.dp_group)[dp_src_rank]
            // self.tp_group.size()
        )
        return int(hdp_rank)

    def reroute_samples_to_hdp_ranks(
        self, batch, global_ids_this_rank, global_id_seqlens, sample_id_groups, offsets
    ):
        """
        Reroutes the sub-samples to the correct rank after scheduling.

        For each key in the batch dict, we perform an all-to-all communication
        to transfer the data to the correct ranks.
        Since all CP ranks within a DP group have the same data, we only need
        to transfer data between matching CP ranks (send splits to all other
        ranks are zero).

        Returns:
            Dict mapping global subsample ID -> per-key tensors for the
            subsamples assigned to this rank.
        """
        gid2local_id = {int(gid): i for i, gid in enumerate(global_ids_this_rank)}
        # Dict-keys view gives O(1) membership tests; the original tensor
        # membership test (`gid in cuda_tensor`) scanned the tensor and forced
        # a GPU sync for every candidate gid.
        local_gids = gid2local_id.keys()
        hdp_rank = self.dp_cp_group.rank()
        dp_ranks = torch.distributed.get_process_group_ranks(self.dp_group)
        # Here we actually want the DP group's ranks within the HDP group, so
        # we need to divide the global ranks by the TP size.
        dp_ranks = [r // self.tp_group.size() for r in dp_ranks]

        data_keys = batch[0].keys()

        # Create the send plan: merge all scheduling groups into one sorted
        # list of global IDs per destination HDP rank.
        combined_sample_id_groups: List[List[int]] = [[] for _ in range(self.total_hdp_gpus)]

        for d in range(self.total_hdp_gpus):
            for sample_id_group in sample_id_groups:
                combined_sample_id_groups[d].extend(sample_id_group[d])

        for dest_rank in range(self.total_hdp_gpus):
            combined_sample_id_groups[dest_rank].sort()

        # Filter out samples that are not present on this rank.
        send_ids_sorted = [
            gid
            for d in dp_ranks
            for gid in combined_sample_id_groups[d]
            if gid in local_gids
        ]

        send_lens_split = [0] * self.total_hdp_gpus
        for dest_rank in range(self.total_hdp_gpus):
            if dest_rank in dp_ranks:
                send_lens_split[dest_rank] = sum(
                    global_id_seqlens[gid][1]
                    for gid in combined_sample_id_groups[dest_rank]
                    if gid in local_gids
                )
            else:
                # We only need to share local data with DP ranks that have
                # different data.
                send_lens_split[dest_rank] = 0

        # Create the recv plan: for every subsample scheduled onto this rank,
        # find which HDP rank currently holds its data.
        recv_sample_id_groups = [[] for _ in range(self.total_hdp_gpus)]
        for gid in combined_sample_id_groups[hdp_rank]:
            src_rank = self._gid_to_src_rank(gid, offsets)
            recv_sample_id_groups[src_rank].append(gid)

        recv_lens_split = [0] * self.total_hdp_gpus
        for src_rank in range(self.total_hdp_gpus):
            recv_lens_split[src_rank] = sum(
                global_id_seqlens[gid][1] for gid in recv_sample_id_groups[src_rank]
            )

        recv_ids_sorted = [
            gid for d in range(self.total_hdp_gpus) for gid in recv_sample_id_groups[d]
        ]
        recv_counts = [len(recv_sample_id_groups[d]) for d in range(self.total_hdp_gpus)]

        recv_samples = [{k: None for k in data_keys} for _ in range(sum(recv_counts))]

        def _pack_sample_by_key(key: str) -> torch.Tensor:
            """Concatenate this rank's outgoing subsamples for one batch key."""
            flattened_tensors = []
            for gid in send_ids_sorted:
                t = batch[gid2local_id[gid]][key].to(torch.cuda.current_device(), non_blocking=True)
                flattened_tensors.append(t)
            return (
                torch.cat(flattened_tensors, dim=0)
                if flattened_tensors
                else torch.empty(0, device=torch.cuda.current_device(), dtype=batch[0][key].dtype)
            )

        def _unpack_sample_by_key(key: str, recv_tensor: torch.Tensor):
            """Slice the received flat tensor back into per-subsample views."""
            cursor = 0
            for i, gid in enumerate(recv_ids_sorted):
                sample_len = global_id_seqlens[gid][1]
                recv_samples[i][key] = recv_tensor[cursor : cursor + sample_len]
                cursor += sample_len

        for key in data_keys:
            send_tensor = _pack_sample_by_key(key)
            recv_tensor = torch.empty(
                sum(recv_lens_split), device=torch.cuda.current_device(), dtype=send_tensor.dtype
            )
            torch.distributed.all_to_all_single(
                output=recv_tensor,
                input=send_tensor,
                output_split_sizes=recv_lens_split,
                input_split_sizes=send_lens_split,
                group=self.dp_cp_group,
            )
            _unpack_sample_by_key(key, recv_tensor)

        recv_sample_with_id = {
            recv_id: recv_samples[i] for i, recv_id in enumerate(recv_ids_sorted)
        }
        return recv_sample_with_id

    def unpack_batch(self, batch):
        """
        Unpacks the packed samples into a list of sub-samples.

        Since each sub-sample may be routed to different DPxCP ranks,
        we unpack the sample here to avoid unnecessarily transferring
        the entire packed sample. Zero-length sub-samples (padding) are
        dropped; packing metadata keys are not carried over.
        """
        batch_unpacked = []
        for sample in batch:
            for sub_sample in range(sample["cu_seqlens"].shape[0] - 1):
                sub_sample_dict = {}
                start_idx = sample["cu_seqlens"][sub_sample]
                end_idx = sample["cu_seqlens"][sub_sample + 1]
                if end_idx - start_idx == 0:
                    continue
                for key in sample.keys():
                    if key in ["cu_seqlens", "batch_idx", "max_seqlen"]:
                        continue
                    sub_sample_dict[key] = sample[key][start_idx:end_idx]
                batch_unpacked.append(sub_sample_dict)
        return batch_unpacked

    def __next__(self) -> Any:
        """
        Get the next item from the dataset, pull scheduling metadata and return it.

        Returns:
            (samples_this_rank_with_id, sample_id_groups), or (None, None) when
            this rank has no data iterator (non-TP0 ranks receive via broadcast).
        """
        if self.data_iterator is None:
            # TP0 reads from data_iterator, others receive via broadcast.
            return None, None
        else:
            batch = next(self.data_iterator)
            # Extract every sub-sample length from the packed cu_seqlens and
            # drop zero-length (padding) entries.
            subsample_seqlens = []
            for sample in batch:
                subsample_seqlens.extend(
                    [
                        int(sample["cu_seqlens"][i + 1] - sample["cu_seqlens"][i])
                        for i in range(0, sample["cu_seqlens"].shape[0] - 1)
                    ]
                )
            subsample_seqlens = torch.tensor(subsample_seqlens, dtype=torch.int32).cuda()
            subsample_seqlens = subsample_seqlens[subsample_seqlens != 0]

            seqlens_gathered, offsets = self.get_global_seqlens(subsample_seqlens)

            global_id_seqlens, global_ids_this_rank = self.get_global_id_seqlens(
                subsample_seqlens.shape[0], offsets, seqlens_gathered
            )

            # Only the per-rank sample ID assignment is needed here.
            _, sample_id_groups = self.cp_balancing_scheduler.get_groups_and_subsamples(
                global_id_seqlens, self.config
            )

            batch = self.unpack_batch(batch)
            samples_this_rank_with_id = self.reroute_samples_to_hdp_ranks(
                batch, global_ids_this_rank, global_id_seqlens, sample_id_groups, offsets
            )
            return samples_this_rank_with_id, sample_id_groups