
Commit 48e91d2

Fix NaN bugs; the FLOPs calculation still needs fixing.
Signed-off-by: tailaim <tailaim@nvidia.com>
1 parent 31fe2ea commit 48e91d2

File tree

13 files changed: +333 -1143 lines changed


examples/run_hybrid_cp.sh

Lines changed: 4 additions & 3 deletions
@@ -24,7 +24,8 @@ USER=$SLURM_JOB_USER
 
 # Auto-detect batch or interactive mode.
 which srun
-BATCH=$((1-$?))
+# BATCH=$((1-$?))
+BATCH=0
 
 DEBUG=0
 USE_TILING=1
@@ -39,7 +40,7 @@ USE_MOCK_DATA=1
 if [[ $BATCH -eq 0 ]]; then
     DATETIME=`date +'%y-%m-%d-%H-%M-%S'`
     MODEL_NAME="interactive_hybrid_cp"
-    DEBUG=1
+    # DEBUG=1
 else
     MODEL_NAME="interactive_hybrid_cp"
 fi
@@ -58,7 +59,7 @@ export HF_DATASETS_CACHE="${OUTPUT}/hf_datasets_cache"
 
 DATA_TRAIN="/lustre/fs1/portfolios/llmservice/users/adithyare/sft/nano_v2_fake_packed_131072_10000_rndm//stage1_stage2_multiling_128k_seq_packed.empty_assist_filtered.shuf.jsonl"
 
-SEQ_LEN=131072 #131072 #81920 #65536
+SEQ_LEN=1024 #131072 #81920 #65536
 
 if [[ $DEBUG -eq 1 ]]; then
     MBZ=1

megatron/core/datasets/blended_megatron_dataset_builder.py

Lines changed: 0 additions & 2 deletions
@@ -146,8 +146,6 @@ def _build_blended_dataset_splits(self) -> List[Optional[TopLevelDataset]]:
         ##
         if self.config.mock:
             split = self.config.split_matrix
-            # debugmtl
-            # print("for debug, building mock datasets, size is {self.sizes},split is {split}")
             try:
                 return self._build_megatron_dataset_splits(None, split, self.sizes)
             except Exception as error:

megatron/core/extensions/transformer_engine.py

Lines changed: 7 additions & 0 deletions
@@ -935,6 +935,13 @@ def __init__(
         else:
             extra_kwargs["cp_comm_type"] = cp_comm_type
 
+        # Create a single shared stream so hybrid context parallelism also works when cp_size == 1.
+        if (
+            self.config.hybrid_context_parallel
+            and getattr(TEDotProductAttention, "cp_stream", None) is None
+        ):
+            TEDotProductAttention.cp_stream = torch.cuda.Stream()
+
         if self.config.deterministic_mode:
             if int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")) != 0:
                 raise RuntimeError(
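The guard added here lazily creates one class-level CUDA stream that every attention instance shares when hybrid context parallelism is enabled, even if the local context-parallel size is 1. A standalone sketch of the same lazy class-attribute pattern (the class name below is illustrative, not the TE API, and creating the stream requires a CUDA device):

import torch

class AttentionWithSharedCPStream:
    # Class-level handle shared by every instance; stays None until first needed.
    cp_stream = None

    def __init__(self, hybrid_context_parallel: bool):
        # Create the stream exactly once, then reuse it across all instances.
        if hybrid_context_parallel and AttentionWithSharedCPStream.cp_stream is None:
            AttentionWithSharedCPStream.cp_stream = torch.cuda.Stream()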

megatron/core/model_parallel_config.py

Lines changed: 8 additions & 0 deletions
@@ -78,6 +78,14 @@ class ModelParallelConfig:
     Else, it would be controlled by the maximum sequence length / context parallel size.
     """
 
+    balanced_sequence_packing: bool = False
+    """
+    If True, enables balanced sequence packing.
+    Samples with variable sequence lengths are packed into single samples
+    such that each packed sample has a similar total sequence length.
+    This improves the efficiency of sequence packing.
+    """
+
     expert_model_parallel_size: int = 1
     """Distributes Moe Experts across sub data parallel dimension."""
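The schedulers that act on this flag are added in data_schedule.py below. As a rough illustration of what "similar total sequence lengths" means, a greedy longest-first heuristic balances packs like this (standalone sketch, not the Megatron implementation):

import heapq
from typing import List

def balanced_packs(seq_lens: List[int], num_packs: int) -> List[List[int]]:
    # Longest-processing-time heuristic: place the next-longest sequence
    # into the pack with the smallest running total.
    heap = [(0, pack_id) for pack_id in range(num_packs)]
    heapq.heapify(heap)
    packs = [[] for _ in range(num_packs)]
    for length in sorted(seq_lens, reverse=True):
        total, pack_id = heapq.heappop(heap)
        packs[pack_id].append(length)
        heapq.heappush(heap, (total + length, pack_id))
    return packs

print(balanced_packs([8, 7, 5, 3, 2, 1], 2))  # [[8, 3, 2], [7, 5, 1]] -> totals 13 and 13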

megatron/core/pipeline_parallel/data_schedule.py

Lines changed: 144 additions & 51 deletions
@@ -1,9 +1,10 @@
 # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 
+import enum
 from collections import deque
 from functools import lru_cache
 from math import ceil, log2
-from typing import Callable, List, Optional, Tuple
+from typing import Callable, Dict, List, Optional, Tuple, Type, Union
 
 import numpy as np
 import torch
@@ -13,8 +14,18 @@
 from megatron.core.rerun_state_machine import RerunDataIterator
 
 
-def wrap_hybrid_cp_dataloader(
-    data_iterator, config, pg_collection: Optional[ProcessGroupCollection] = None
+class PackingScheduler(enum.Enum):
+    """Enum for supported sequence packing algorithms."""
+
+    HYBRID_CP = "hybrid_cp"
+    NAIVE_SEQUENCE_PACKING = "naive_sequence_packing"
+
+
+def wrap_dataloader(
+    data_iterator,
+    config,
+    scheduler_type: Union[PackingScheduler, str],
+    pg_collection: Optional[ProcessGroupCollection] = None,
 ):
     """
     A wrapper function that wraps around an existing data_iterator
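With the new signature, callers select the packing algorithm by passing either a PackingScheduler member or its lowercase name. A hypothetical call site, assuming the module path follows the file location; the train_iterator and config objects are placeholders, not names from this commit:

from megatron.core.pipeline_parallel.data_schedule import PackingScheduler, wrap_dataloader

packed_iterator, num_micro_batches = wrap_dataloader(
    data_iterator=train_iterator,               # placeholder data iterator
    config=config,                              # placeholder config carrying max_seqlen_per_dp_cp_rank, etc.
    scheduler_type=PackingScheduler.HYBRID_CP,  # or the string "hybrid_cp"
)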
@@ -26,6 +37,13 @@ def wrap_hybrid_cp_dataloader(
         dp_cp_group: Data parallel context parallel group.
     """
 
+    scheduler_map: Dict[PackingScheduler, Type[BaseScheduler]] = {
+        PackingScheduler.HYBRID_CP: BalancedHybridCPscheduler,
+        PackingScheduler.NAIVE_SEQUENCE_PACKING: NaiveSequencePackingScheduler,
+    }
+
     def _get_global_seqlens(subsample_seqlens: torch.Tensor, dp_group) -> List[int]:
         """
         Gathers the sequence lengths of all subsamples from all DP ranks.
@@ -151,16 +169,17 @@ def _reroute_samples_to_hdp_ranks(
     ]
     # send_counts = [len(combined_sample_id_groups[d]) for d in range(total_hdp_gpus)]
 
+    send_num_split = [0] * total_hdp_gpus
     send_lens_split = [0] * total_hdp_gpus
     for dest_rank in range(total_hdp_gpus):
         if dest_rank in dp_ranks:
-            send_lens_split[dest_rank] = sum(
-                [
-                    global_id_seqlens[gid][1]
-                    for gid in combined_sample_id_groups[dest_rank]
-                    if gid in global_ids_this_rank
-                ]
-            )
+            send_seq_lens = [
+                global_id_seqlens[gid][1]
+                for gid in combined_sample_id_groups[dest_rank]
+                if gid in global_ids_this_rank
+            ]
+            send_num_split[dest_rank] = len(send_seq_lens)
+            send_lens_split[dest_rank] = sum(send_seq_lens)
        else:
             # We only need to share local data with DP ranks that have different data.
             send_lens_split[dest_rank] = 0
@@ -197,20 +216,30 @@ def _pack_sample_by_key(key: str) -> torch.Tensor:
     def _unpack_sample_by_key(key: str, recv_tensor: torch.Tensor):
         cursor = 0
         for i, gid in enumerate(recv_ids_sorted):
-            sample_len = global_id_seqlens[gid][1]
+            sample_len = 1 if key in ["original_seq_len"] else global_id_seqlens[gid][1]
             recv_samples[i][key] = recv_tensor[cursor : cursor + sample_len]
             cursor += sample_len
 
     for key in data_keys:
+        output_split_sizes, input_split_sizes = (
+            (recv_counts, send_num_split)
+            if key in ["original_seq_len"]
+            else (recv_lens_split, send_lens_split)
+        )
         send_tensor = _pack_sample_by_key(key)
+        recv_tensor_size = sum(output_split_sizes)
         recv_tensor = torch.empty(
-            sum(recv_lens_split), device=torch.cuda.current_device(), dtype=send_tensor.dtype
+            recv_tensor_size, device=torch.cuda.current_device(), dtype=send_tensor.dtype
         )
+        # debugmtl
+        # print(f"ready to all to all for key:{key}, output_split_sizes:{output_split_sizes},
+        # input_split_sizes:{input_split_sizes}, recv_tensor_size:
+        # {tensor_size},send_tensor_size:{send_tensor.size(0)}")
         torch.distributed.all_to_all_single(
             output=recv_tensor,
             input=send_tensor,
-            output_split_sizes=recv_lens_split,
-            input_split_sizes=send_lens_split,
+            output_split_sizes=output_split_sizes,
+            input_split_sizes=input_split_sizes,
             group=dp_cp_group,
         )
         _unpack_sample_by_key(key, recv_tensor)
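The per-key switch matters because original_seq_len contributes one element per sample while token tensors contribute one element per token, so the two kinds of keys need different split sizes. A minimal illustration of all_to_all_single with uneven splits (assumes a process group has already been initialized, e.g. via torchrun):

import torch
import torch.distributed as dist

def exchange_uneven(send_tensor, input_split_sizes, output_split_sizes, group=None):
    # Size the receive buffer from the per-rank receive splits, mirroring how
    # recv_tensor is allocated in the hunk above.
    recv_tensor = torch.empty(sum(output_split_sizes), dtype=send_tensor.dtype)
    dist.all_to_all_single(
        output=recv_tensor,
        input=send_tensor,
        output_split_sizes=output_split_sizes,
        input_split_sizes=input_split_sizes,
        group=group,
    )
    return recv_tensor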
@@ -245,29 +274,24 @@ def _broadcast(item):
                 group=parallel_state.get_tensor_model_parallel_group(),
             )
 
-    def _broadcast_num_samples_this_group(num_samples_this_group):
-        dev = torch.cuda.current_device()
-        # TODO(tailaim) do we need this barrier?
-        torch.distributed.barrier()
-
-        n = 0 if num_samples_this_group is None else int(num_samples_this_group.numel())
-        n = torch.tensor([n], dtype=torch.int64, device=dev)
-
-        _broadcast(n)
-        n = int(n.item())
+    # Convert string to enum if needed
+    if isinstance(scheduler_type, str):
+        try:
+            scheduler_type = PackingScheduler[scheduler_type.upper()]
+        except KeyError:
+            available_scheduler = ", ".join([scheduler.name for scheduler in PackingScheduler])
+            raise ValueError(
+                f"Unknown packing scheduler: {scheduler_type}. "
+                f"Available schedulers: {available_scheduler}"
+            )
 
-        assert n > 0, "there should be at least 1 sub samples in the group"
-        num_samples_this_group_broadcast = (
-            torch.empty(n, dtype=torch.int32, device=dev)
-            if num_samples_this_group is None
-            else num_samples_this_group
+    if scheduler_type not in scheduler_map:
+        available_scheduler = ", ".join([scheduler.name for scheduler in PackingScheduler])
+        raise ValueError(
+            f"Unknown scheduler: {scheduler_type}. " f"Available schedulers: {available_scheduler}"
         )
-        _broadcast(num_samples_this_group_broadcast)
-        return num_samples_this_group_broadcast
 
-    cp_balancing_scheduler = BalancedHybridCPScheduler(
-        max_seq_len_per_rank=config.max_seqlen_per_dp_cp_rank
-    )
+    scheduler = scheduler_map[scheduler_type](config)
     if pg_collection is None:
         dp_cp_group = parallel_state.get_data_parallel_group(with_context_parallel=True)
         dp_group = parallel_state.get_data_parallel_group()
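The conversion relies on Enum lookup by member name, which raises KeyError for unknown names; the added try/except turns that into a ValueError listing the valid choices. A quick standalone illustration of the lookup semantics:

import enum

class PackingScheduler(enum.Enum):
    HYBRID_CP = "hybrid_cp"
    NAIVE_SEQUENCE_PACKING = "naive_sequence_packing"

print(PackingScheduler["hybrid_cp".upper()])  # PackingScheduler.HYBRID_CP
try:
    PackingScheduler["does_not_exist".upper()]
except KeyError:
    print("unknown scheduler name")  # reached for names that are not members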
@@ -305,9 +329,7 @@ def _broadcast_num_samples_this_group(num_samples_this_group):
         subsample_seqlens.shape[0], offsets, seqlens_gathered, dp_group
     )
 
-    groups, sample_id_groups = cp_balancing_scheduler.get_groups_and_subsamples(
-        global_id_seqlens, config
-    )
+    groups, sample_id_groups = scheduler.get_groups_and_subsamples(global_id_seqlens, config)
 
     batch = _unpack_batch(batch)
     samples_this_rank_with_id = _reroute_samples_to_hdp_ranks(
@@ -353,10 +375,10 @@ def _pack_tensors(tensors):
 
         # TODO(tailaim): do we need attention_mask for sequence packing?
         new_sample = {}
-        new_sample["tokens"] = tokens
-        new_sample["labels"] = labels
-        new_sample["loss_mask"] = loss_mask
-        new_sample["position_ids"] = position_ids
+        new_sample["tokens"] = tokens.unsqueeze(0)
+        new_sample["labels"] = labels.unsqueeze(0)
+        new_sample["loss_mask"] = loss_mask.unsqueeze(0)
+        new_sample["position_ids"] = position_ids.unsqueeze(0)
         new_sample["local_cp_size"] = torch.tensor(
             partner_cp_size, dtype=torch.int32, device=dev
         )
@@ -367,9 +389,11 @@
         )
         cu_seqlens_padded = np.empty(len(samples) + 1, dtype=np.int32)
         cu_seqlens_padded[0] = 0
-        np.cumsum(lengths_padding, out=cu_seqlens_padded[1:])
-        cu_seqlens_padded = torch.from_numpy(cu_seqlens_padded).to(
-            device=dev, non_blocking=True
+        cu_seqlens_padded[1:] = np.cumsum(lengths_padding, out=cu_seqlens_padded[1:])
+        cu_seqlens_padded = (
+            torch.from_numpy(cu_seqlens_padded)
+            .to(device=dev, non_blocking=True, dtype=torch.int32)
+            .reshape(-1)
         )
         new_sample["cu_seqlens_padded"] = cu_seqlens_padded
 
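cu_seqlens-style tensors are just a prefix sum of per-sample lengths with a leading zero, which is what both the old and new code compute; a tiny worked example in NumPy with made-up lengths:

import numpy as np

lengths_padded = np.array([128, 256, 64], dtype=np.int32)
cu_seqlens_padded = np.empty(len(lengths_padded) + 1, dtype=np.int32)
cu_seqlens_padded[0] = 0
np.cumsum(lengths_padded, out=cu_seqlens_padded[1:])
print(cu_seqlens_padded)  # [  0 128 384 448]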
@@ -379,27 +403,95 @@
         new_sample["max_seqlen"] = max_seqlen
 
         # create cu_seqlens without padding
-        lengths = torch.stack([s["original_seq_len"] for s in samples], dim=0)
+        lengths = torch.stack([s["original_seq_len"] for s in samples], dim=0).reshape(-1)
         cu_seqlens = torch.empty(lengths.numel() + 1, device=dev, dtype=torch.int32)
         cu_seqlens[0] = 0
-        cu_seqlens[1:] = torch.cumsum(lengths, dim=0)
+        cu_seqlens[1:] = torch.cumsum(lengths, dim=0).reshape(-1)
         new_sample["cu_seqlens"] = cu_seqlens
 
         new_samples.append(new_sample)
-
+    # debugmtl
+    # print(f"new_samples type: {type(new_samples)}, new_sample type: {type(new_samples[0])}")
     new_data_iterator = RerunDataIterator(iter(new_samples))
 
+    # debugmtl
+    # data = next(new_data_iterator)
+    # print(f"data type: {type(data)}")
+    # print(data)
+
     return new_data_iterator, num_micro_batches
 
 
-class BalancedHybridCPScheduler:
+class BaseScheduler:
+    """
+    Base class for sequence packing schedulers.
+    """
+
+    def __init__(self, config):
+        pass
+
+
+class NaiveSequencePackingScheduler(BaseScheduler):
+    """
+    This scheduler simply packs sequences in their original order
+    until reaching the max sequence length.
+    It does not reorder sequences nor perform any load balancing.
+    """
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.max_seq_len_all_ranks = config.max_seqlen_per_dp_cp_rank * config.context_parallel_size
+        self.dp_size = parallel_state.get_data_parallel_world_size()
+
+    def get_groups_and_subsamples(self, sample_id_seqlens, config):
+        """Pack sequences in their original order until reaching the max sequence length."""
+        groups = []
+        sample_id_groups = []
+        sum_seqlen = 0
+        single_microbatch = []
+
+        for i in range(len(sample_id_seqlens)):
+            if sum_seqlen + sample_id_seqlens[i][1] <= self.max_seq_len_all_ranks:
+                single_microbatch.append(i)
+                sum_seqlen += sample_id_seqlens[i][1]
+            else:
+                groups.append(single_microbatch)
+                sample_id_groups.append(single_microbatch)
+                single_microbatch = [i]
+                sum_seqlen = sample_id_seqlens[i][1]
+        if single_microbatch:
+            groups.append(single_microbatch)
+            sample_id_groups.append(single_microbatch)
+
+        # We want the number of microbatches to be a multiple of dp_size,
+        # so move a few samples out of earlier microbatches into new
+        # microbatches at the end if needed.
+        num_microbatches_before = len(sample_id_groups)
+        if num_microbatches_before % self.dp_size != 0:
+            remainder = num_microbatches_before % self.dp_size
+            num_to_move = self.dp_size - remainder
+            i = num_microbatches_before - 1
+            while num_to_move > 0:
+                assert i > 0, "Not enough samples to move"
+                if len(sample_id_groups[i]) > 1:
+                    seq_id = sample_id_groups[i].pop()
+                    new_microbatch = [seq_id]
+                    groups.append(new_microbatch)
+                    sample_id_groups.append(new_microbatch)
+                    num_to_move -= 1
+                else:
+                    i -= 1
+        return groups, sample_id_groups
+
+
+class BalancedHybridCPscheduler(BaseScheduler):
     """
     This class provides the functionality to form groups of sub-samples
     such that all DPxCP ranks have a roughly balanced workload in the group.
     """
 
-    def __init__(self, max_seq_len_per_rank: int):
-        self.max_seq_len_per_rank = max_seq_len_per_rank
+    def __init__(self, config):
+        super().__init__(config)
+        self.max_seq_len_per_rank = config.max_seqlen_per_dp_cp_rank
         self.num_subsamples = 0
         self.num_subsamples_processed = 0
         self.free_resources = []
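Stripped of the Megatron config and parallel-state dependencies, the naive scheduler's core loop is a first-fit-in-order pass over (id, seqlen) pairs; a standalone sketch on toy data (this mirrors, but is not, the class above):

def naive_pack(sample_id_seqlens, max_seq_len):
    # First-fit in original order: start a new microbatch whenever the
    # running total would exceed the cap.
    microbatches, current, total = [], [], 0
    for idx, (_, seqlen) in enumerate(sample_id_seqlens):
        if current and total + seqlen > max_seq_len:
            microbatches.append(current)
            current, total = [], 0
        current.append(idx)
        total += seqlen
    if current:
        microbatches.append(current)
    return microbatches

print(naive_pack([(0, 5), (1, 2), (2, 4), (3, 3)], 8))  # [[0, 1], [2, 3]]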
@@ -614,6 +706,8 @@ def next_hdp_group(
             else:
                 chosen_members = group_members[best_gid]
         else:
+            if best_gid is None:
+                break
             chosen_members = group_members[best_gid]
 
         # ---- Step 2 – if we decided to create a fresh group ----------------
@@ -731,7 +825,6 @@ def trim_overload():
                 else:
                     break
 
-        # debugmtl make sure total_seq_len after packing smaller than max_seq_len
         # trim_overload()
 
         # Track samples in this group before redistribution to empty GPUs
