Skip to content

Commit f707492

Browse files
Merge branch 'main' into helenn-dev-rl-training-graphs-test
2 parents 44b5b4d + 517dfd4 commit f707492

File tree

52 files changed

+3175
-581
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

52 files changed

+3175
-581
lines changed

.github/actions/action.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,8 @@ runs:
126126
IS_CI_WORKLOAD: ${{ inputs.is_ci_workload }}
127127
run: |
128128
PR_NUMBER=${{ fromJSON(steps.get-pr-info.outputs.pr-info || '{}').number }}
129-
HAS_RUN_FUNCTIONAL_TESTS_LABEL=$(gh pr view $PR_NUMBER --json labels | jq '[.labels[].name] | any(. == "Run functional tests")') || echo "$IS_CI_WORKLOAD"
129+
HAS_RUN_FUNCTIONAL_TESTS_LABEL=$(gh pr view $PR_NUMBER --json labels | jq '[.labels[].name] | any(. == "Run functional tests")')
130+
HAS_RUN_FUNCTIONAL_TESTS_LABEL=${HAS_RUN_FUNCTIONAL_TESTS_LABEL:-$IS_CI_WORKLOAD}
130131
echo "main=$HAS_RUN_FUNCTIONAL_TESTS_LABEL" | tee -a $GITHUB_OUTPUT
131132
132133
- name: Create run-script (e2e test)

.github/workflows/cicd-main.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -449,8 +449,8 @@ jobs:
449449
IS_CI_WORKLOAD: ${{ needs.pre-flight.outputs.is_ci_workload }}
450450
run: |
451451
PR_NUMBER=${{ fromJSON(steps.get-pr-info.outputs.pr-info || '{}').number }}
452-
HAS_RUN_FUNCTIONAL_TESTS_LABEL=$(gh pr view $PR_NUMBER --json labels | jq '[.labels[].name] | any(. == "Run functional tests")') || echo "$IS_CI_WORKLOAD"
453-
452+
HAS_RUN_FUNCTIONAL_TESTS_LABEL=$(gh pr view $PR_NUMBER --json labels | jq '[.labels[].name] | any(. == "Run functional tests")')
453+
HAS_RUN_FUNCTIONAL_TESTS_LABEL=${HAS_RUN_FUNCTIONAL_TESTS_LABEL:-$IS_CI_WORKLOAD}
454454
echo "main=$HAS_RUN_FUNCTIONAL_TESTS_LABEL" | tee -a $GITHUB_OUTPUT
455455
456456
- name: Parse functional tests

.github/workflows/oncall-rotation.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,7 @@ jobs:
4646
SLACK_TOKEN: ${{ secrets.ONCALL_SLACK_TOKEN }}
4747
run: |
4848
pip install --no-cache-dir uv
49-
uv pip install slack-sdk
50-
uv run python .github/scripts/oncall_manager.py rotate
49+
uv run --with slack-sdk python .github/scripts/oncall_manager.py rotate
5150
5251
- name: Commit and Push changes
5352
run: |
Lines changed: 301 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,301 @@
1+
# Copyright (c) 2025 NVIDIA CORPORATION. All rights reserved.
2+
3+
from typing import Any, List, Optional
4+
5+
import torch
6+
7+
from megatron.core import parallel_state
8+
from megatron.core.pipeline_parallel.hybrid_cp_schedule import BalancedCPScheduler
9+
from megatron.core.process_groups_config import ProcessGroupCollection
10+
11+
12+
class HybridCPDataLoaderWrapper:
    """
    A wrapper class that wraps around an existing data_iterator.
    For every __next__ call,
    1. Each DP rank pulls a batch of packed samples.
    2. Extracts the sequence lengths of each sub-sample and all-gathers across the DP group.
    3. Schedules the sub-samples to the DPxCP ranks using the BalancedCPScheduler.
    4. Based on the schedule, reroutes the sub-samples to the correct rank using all-to-all.
    5. Returns the assigned sub-samples to this rank.

    Args:
        data_iterator: The original data_iterator to wrap around. May be None on ranks
            that receive data via broadcast instead of reading it (see __next__).
        config: The config object containing the max_seqlen_per_dp_cp_rank
        pg_collection: Optional collection providing dp_cp / dp / tp process groups.
            When None, the groups are taken from megatron.core.parallel_state.
    """

    def __init__(
        self, data_iterator, config, pg_collection: Optional[ProcessGroupCollection] = None
    ):
        self.data_iterator = data_iterator
        self.config = config
        if pg_collection is None:
            # Fall back to the globally-initialized Megatron parallel state.
            self.dp_cp_group = parallel_state.get_data_parallel_group(with_context_parallel=True)
            self.dp_group = parallel_state.get_data_parallel_group()
            self.tp_group = parallel_state.get_tensor_model_parallel_group()
        else:
            self.dp_cp_group = pg_collection.dp_cp
            self.dp_group = pg_collection.dp
            self.tp_group = pg_collection.tp
        assert (
            self.dp_cp_group is not None and self.dp_group is not None and self.tp_group is not None
        ), "dp_cp_group, dp_group, tp_group must not be None when using hybrid context parallel"

        # Scheduler that packs sub-samples onto DPxCP ranks under the
        # per-rank sequence-length budget.
        self.cp_balancing_scheduler = BalancedCPScheduler(
            max_seq_len_per_rank=self.config.max_seqlen_per_dp_cp_rank, dp_cp_group=self.dp_cp_group
        )

        # Number of ranks in the combined (hybrid) DP x CP group.
        self.total_hdp_gpus = self.dp_cp_group.size()

    def __iter__(self):
        """Return self as an iterator."""
        return self

    def get_global_seqlens(self, subsample_seqlens: torch.Tensor) -> "tuple[list[int], torch.Tensor]":
        """
        Gathers the sequence lengths of all subsamples from all DP ranks.
        Each DP rank loads the same number of microbatches but each microbatch
        may have a different number of subsamples.

        We find the number of subsamples each rank holds and then gather the
        sequence lengths of all subsamples from all ranks.

        Args:
            subsample_seqlens: 1-D int32 CUDA tensor of this rank's sub-sample lengths.

        Returns:
            seqlens_gathered: Python list of every rank's sub-sample lengths,
                concatenated in DP-rank order.
            offsets: 1-D int32 CPU tensor; offsets[r] is the index into
                seqlens_gathered where DP rank r's sub-samples begin (exclusive
                cumulative sum of per-rank counts). Used to assign global IDs.
        """
        # Collect the number of subsamples from all ranks
        local_len = torch.tensor([subsample_seqlens.shape[0]], dtype=torch.int32).cuda()
        dp_subsample_count = [torch.zeros_like(local_len) for _ in range(self.dp_group.size())]
        torch.distributed.all_gather(dp_subsample_count, local_len, group=self.dp_group)

        # Find the max number of subsamples across all ranks and pad subsample_seqlens to max length
        # (all_gather requires equal-sized tensors on every rank).
        dp_subsample_counts = torch.stack(dp_subsample_count, dim=0).cpu().view(-1)
        max_sub_samples = int(dp_subsample_counts.max().item())

        if local_len.item() < max_sub_samples:
            subsample_seqlens_padded = torch.cat(
                [
                    subsample_seqlens,
                    torch.zeros(max_sub_samples - local_len.item(), dtype=torch.int32).cuda(),
                ],
                dim=0,
            )
        else:
            subsample_seqlens_padded = subsample_seqlens

        # Gather the subsample_seqlens from all ranks
        seqlens_gathered = [
            torch.empty_like(subsample_seqlens_padded) for _ in range(self.dp_group.size())
        ]
        torch.distributed.all_gather(
            seqlens_gathered, subsample_seqlens_padded, group=self.dp_group
        )

        # Trim each seqlens_gathered to the length of the correct sample
        # (drop the zero padding added above).
        for dp_rank, seqlen in enumerate(seqlens_gathered):
            seqlens_gathered[dp_rank] = seqlen[: dp_subsample_counts[dp_rank]]

        seqlens_gathered = torch.cat(seqlens_gathered, dim=0)
        seqlens_gathered = seqlens_gathered.cpu().tolist()

        # Calculate the offsets to assign unique global ID to each subsample.
        csum = torch.cumsum(dp_subsample_counts, dim=0, dtype=torch.int32)
        offsets = torch.cat([torch.zeros(1, dtype=torch.int32), csum[:-1]], dim=0)

        return seqlens_gathered, offsets

    def get_global_id_seqlens(self, num_local_subsamples, offsets, seqlens_gathered):
        """
        Calculates the global ID for each subsample.

        We assign a unique global ID to each subsample. IDs are consecutive
        integers in DP-rank order, so rank r owns the ID range
        [offsets[r], offsets[r] + num_local_subsamples).

        Args:
            num_local_subsamples: Number of sub-samples held by this DP rank.
            offsets: Per-DP-rank start offsets from get_global_seqlens.
            seqlens_gathered: Concatenated sub-sample lengths from all DP ranks.

        Returns:
            global_id_seqlens: list of (global_id, seqlen) tuples for scheduling.
            global_ids_this_rank: list of global IDs locally present on this rank.
        """
        dp_rank = self.dp_group.rank()
        global_ids = torch.arange(len(seqlens_gathered), dtype=torch.int32).cuda()
        # Create a list of (global_id, seqlen) tuples for scheduling
        global_id_seqlens = [(i, seqlens_gathered[i]) for i in range(len(global_ids))]
        # Get the global IDs locally present on this rank
        global_ids_this_rank = global_ids[
            offsets[dp_rank] : offsets[dp_rank] + num_local_subsamples
        ]

        return global_id_seqlens, global_ids_this_rank

    def _gid_to_src_rank(self, gid: int, offsets: torch.Tensor) -> int:
        """Map a global sub-sample ID to the HDP (DPxCP) rank that holds its data.

        bucketize against the DP-rank offset boundaries finds the owning DP rank;
        the global rank from the DP group is then divided by the TP size to get
        the rank index within the flat DPxCP group.
        """
        dp_src_rank = torch.bucketize(gid, offsets[1:] - 1)
        # Since the torch.distributed.get_process_group_ranks
        # provides the global rank, we need to consider TP
        hdp_rank = (
            torch.distributed.get_process_group_ranks(self.dp_group)[dp_src_rank]
            // self.tp_group.size()
        )
        return hdp_rank

    def reroute_samples_to_hdp_ranks(
        self, batch, global_ids_this_rank, global_id_seqlens, sample_id_groups, offsets
    ):
        """
        Reroutes the sub-samples to the correct rank after scheduling.

        For each key in the batch dict, we perform an all-to-all communication
        to transfer the data to the correct ranks.
        Since all CP ranks within a DP group have the same data, we only need
        to transfer data between matching CP ranks.

        Args:
            batch: List of per-sub-sample dicts (output of unpack_batch), local to this rank.
            global_ids_this_rank: Global IDs of the sub-samples in `batch`.
            global_id_seqlens: (global_id, seqlen) tuples for all sub-samples.
            sample_id_groups: Scheduler output; per group, the global IDs assigned
                to each HDP rank.
            offsets: Per-DP-rank global-ID start offsets.

        Returns:
            Dict mapping global ID -> sub-sample dict for the sub-samples
            assigned to this rank.
        """
        gid2local_id = {int(gid): i for i, gid in enumerate(global_ids_this_rank)}
        hdp_rank = self.dp_cp_group.rank()
        dp_ranks = torch.distributed.get_process_group_ranks(self.dp_group)
        # Here we actually want to get the DP group's rank within the HDP group,
        # we need to consider TP
        dp_ranks = [r // self.tp_group.size() for r in dp_ranks]

        # NOTE(review): batch[0] assumes at least one local sub-sample — confirm
        # an empty local batch cannot reach this point.
        data_keys = batch[0].keys()

        # Create the send plan: flatten the per-group assignments into one
        # ID list per destination HDP rank.
        combined_sample_id_groups: List[List[int]] = [[] for _ in range(self.total_hdp_gpus)]

        for d in range(self.total_hdp_gpus):
            for sample_id_group in sample_id_groups:
                combined_sample_id_groups[d].extend(sample_id_group[d])

        # Sort so both sender and receiver enumerate IDs in the same order.
        for dest_rank in range(self.total_hdp_gpus):
            combined_sample_id_groups[dest_rank].sort()

        # Filter out samples that are not present on this rank
        # NOTE(review): `gid in global_ids_this_rank` is an element-wise test
        # against a CUDA tensor (one sync per gid); `gid in gid2local_id` looks
        # equivalent and cheaper — confirm.
        send_ids_sorted = [
            gid
            for d in dp_ranks
            for gid in combined_sample_id_groups[d]
            if gid in global_ids_this_rank
        ]
        # send_counts = [len(combined_sample_id_groups[d]) for d in range(self.total_hdp_gpus)]

        # Per-destination element counts (token counts) for all_to_all_single.
        send_lens_split = [0] * self.total_hdp_gpus
        for dest_rank in range(self.total_hdp_gpus):
            if dest_rank in dp_ranks:
                send_lens_split[dest_rank] = sum(
                    [
                        global_id_seqlens[gid][1]
                        for gid in combined_sample_id_groups[dest_rank]
                        if gid in global_ids_this_rank
                    ]
                )
            else:
                # We only need to share local data with DP ranks that have different data.
                send_lens_split[dest_rank] = 0

        # Create the recv plan: for each ID assigned to this rank, find the
        # source HDP rank that currently holds it.
        recv_sample_id_groups = [[] for _ in range(self.total_hdp_gpus)]
        for gid in combined_sample_id_groups[hdp_rank]:
            src_rank = self._gid_to_src_rank(gid, offsets)
            recv_sample_id_groups[src_rank].append(gid)

        recv_lens_split = [0] * self.total_hdp_gpus
        for src_rank in range(self.total_hdp_gpus):
            recv_lens_split[src_rank] = sum(
                [global_id_seqlens[gid][1] for gid in recv_sample_id_groups[src_rank]]
            )

        # Receive order: by source rank, then sorted ID within each source —
        # this mirrors the senders' sorted send order.
        recv_ids_sorted = [
            gid for d in range(self.total_hdp_gpus) for gid in recv_sample_id_groups[d]
        ]
        recv_counts = [len(recv_sample_id_groups[d]) for d in range(self.total_hdp_gpus)]

        recv_samples = [{k: None for k in data_keys} for _ in range(sum(recv_counts))]

        def _pack_sample_by_key(key: str) -> torch.Tensor:
            # Concatenate this rank's outgoing sub-samples for one key into a
            # single flat tensor, in send order.
            flattened_tensors = []
            for gid in send_ids_sorted:
                t = batch[gid2local_id[gid]][key].to(torch.cuda.current_device(), non_blocking=True)
                flattened_tensors.append(t)
            return (
                torch.cat(flattened_tensors, dim=0)
                if flattened_tensors
                else torch.empty(0, device=torch.cuda.current_device(), dtype=batch[0][key].dtype)
            )

        def _unpack_sample_by_key(key: str, recv_tensor: torch.Tensor):
            # Slice the flat received tensor back into per-sub-sample views
            # using the known per-ID sequence lengths.
            cursor = 0
            for i, gid in enumerate(recv_ids_sorted):
                sample_len = global_id_seqlens[gid][1]
                recv_samples[i][key] = recv_tensor[cursor : cursor + sample_len]
                cursor += sample_len

        # One all-to-all per data key; split sizes are symmetric across ranks
        # by construction of the send/recv plans above.
        for key in data_keys:
            send_tensor = _pack_sample_by_key(key)
            recv_tensor = torch.empty(
                sum(recv_lens_split), device=torch.cuda.current_device(), dtype=send_tensor.dtype
            )
            torch.distributed.all_to_all_single(
                output=recv_tensor,
                input=send_tensor,
                output_split_sizes=recv_lens_split,
                input_split_sizes=send_lens_split,
                group=self.dp_cp_group,
            )
            _unpack_sample_by_key(key, recv_tensor)

        recv_sample_with_id = {
            recv_id: recv_samples[i] for i, recv_id in enumerate(recv_ids_sorted)
        }
        return recv_sample_with_id

    def unpack_batch(self, batch):
        """
        Unpacks the packed samples into a list of sub-samples.
        Since each sub-sample may be routed to different DPxCP ranks,
        we unpack the sample here to avoid unnecessarily transferring
        the entire packed sample.

        Zero-length sub-samples are skipped; packing metadata keys
        ("cu_seqlens", "batch_idx", "max_seqlen") are dropped from the
        per-sub-sample dicts.
        """
        batch_unpacked = []
        for sample in batch:
            for sub_sample in range(sample["cu_seqlens"].shape[0] - 1):
                sub_sample_dict = {}
                start_idx = sample["cu_seqlens"][sub_sample]
                end_idx = sample["cu_seqlens"][sub_sample + 1]
                if end_idx - start_idx == 0:
                    continue
                for key in sample.keys():
                    if key in ["cu_seqlens", "batch_idx", "max_seqlen"]:
                        continue
                    sub_sample_dict[key] = sample[key][start_idx:end_idx]
                batch_unpacked.append(sub_sample_dict)
        return batch_unpacked

    def __next__(self) -> Any:
        """
        Get the next item from the dataset, pull scheduling metadata and return it.

        Returns:
            (samples_this_rank_with_id, sample_id_groups) — the sub-samples
            assigned to this rank keyed by global ID, and the scheduler's
            per-group assignment; (None, None) when this rank has no iterator.
        """
        if self.data_iterator is None:
            # TP0 reads from data_iterator, others receive via broadcast.
            return None, None
        else:
            batch = next(self.data_iterator)
            # Per-sub-sample lengths from the packed samples' cu_seqlens deltas.
            subsample_seqlens = []
            for sample in batch:
                subsample_seqlens.extend(
                    [
                        int(sample["cu_seqlens"][i + 1] - sample["cu_seqlens"][i])
                        for i in range(0, sample["cu_seqlens"].shape[0] - 1)
                    ]
                )
            subsample_seqlens = torch.tensor(subsample_seqlens, dtype=torch.int32).cuda()
            # Drop zero-length entries to stay consistent with unpack_batch,
            # which skips empty sub-samples.
            subsample_seqlens = subsample_seqlens[subsample_seqlens != 0]

            seqlens_gathered, offsets = self.get_global_seqlens(subsample_seqlens)

            global_id_seqlens, global_ids_this_rank = self.get_global_id_seqlens(
                subsample_seqlens.shape[0], offsets, seqlens_gathered
            )

            groups, sample_id_groups = self.cp_balancing_scheduler.get_groups_and_subsamples(
                global_id_seqlens, self.config
            )

            batch = self.unpack_batch(batch)
            samples_this_rank_with_id = self.reroute_samples_to_hdp_ranks(
                batch, global_ids_this_rank, global_id_seqlens, sample_id_groups, offsets
            )
            return samples_this_rank_with_id, sample_id_groups

megatron/core/datasets/gpt_dataset.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import logging
44
import os
55
import time
6-
from dataclasses import dataclass
6+
from dataclasses import dataclass, field
77
from math import ceil
88
from typing import Dict, Optional, Tuple
99

@@ -50,11 +50,32 @@ class GPTDatasetConfig(BlendedMegatronDatasetConfig):
5050
object_storage_cache_path: Optional[str] = None
5151
"""Path for caching indices for s3 or msc dataloading."""
5252

53+
context_parallel_size: int = 1
54+
"""Option to enable context parallelism"""
55+
56+
data_parallel_size: int = 1
57+
"""Option to enable data parallelism"""
58+
59+
sequence_parallel_size: int = 0
60+
"""Option to indicate the sequence parallelism size when using TP
61+
Set to 0 if sequence parallel is not enabled regardless of TP size.
62+
"""
63+
64+
hybrid_context_parallel: bool = False
65+
"""Option to enable hybrid context parallelism. When setting this to True,
66+
each sample should be divisible by the data parallel size * context parallel size * 2.
67+
If sequence parallel is enabled, it should be divisible by the
68+
data parallel size * context parallel size * sequence parallel size * 2.
69+
"""
70+
5371
sequences_per_dataset: Optional[Dict[str, int]] = None
5472
"""If provided, the sequence and document counts for each dataset.
5573
Check --per-dataset-sequences-path
5674
"""
5775

76+
token_dtype_code: Optional[int] = field(init=False, default=None)
77+
"""The dtype code for the token ids. 4 for int32, 8 for uint16."""
78+
5879
context_parallel_size: Optional[int] = None
5980
"""The size of the context parallel group. Needed for padding in packed sequences."""
6081

0 commit comments

Comments
 (0)