Skip to content

Commit 983e5f3

Browse files
committed
support tp
Signed-off-by: tailaim <tailaim@nvidia.com>
1 parent 48e91d2 commit 983e5f3

File tree

5 files changed

+44
-22
lines changed

5 files changed

+44
-22
lines changed

megatron/core/pipeline_parallel/data_schedule.py

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -307,10 +307,9 @@ def _broadcast(item):
307307
total_hdp_gpus = dp_cp_group.size()
308308
dev = torch.cuda.current_device()
309309

310-
# TODO(tailaim): handle the case when data_iterator is None
311310
if data_iterator is None:
312311
# TP-0 reads from data_iterator, others receive via broadcast.
313-
sample_id_groups, batch = None, None, None
312+
sample_id_groups, batch = None, None
314313
num_total_groups_broadcast = torch.tensor([0], dtype=torch.int32, device=dev)
315314
_broadcast(num_total_groups_broadcast)
316315
num_micro_batches = int(num_total_groups_broadcast.item())
@@ -352,6 +351,10 @@ def _broadcast(item):
352351
)
353352
_broadcast(num_total_groups_broadcast)
354353

354+
# TODO(tailaim): calculate these two values properly
355+
# num_total_tokens_this_GA = losses_reduced.pop(0)
356+
# sequence_square_sum_this_GA = losses_reduced.pop(0)
357+
355358
# pack sequences in the same group and create a new data iterator
356359
new_samples = []
357360
for i in range(num_micro_batches):
@@ -375,10 +378,10 @@ def _pack_tensors(tensors):
375378

376379
# TODO(tailaim): do we need attention_mask for sequence packing?
377380
new_sample = {}
378-
new_sample["tokens"] = tokens.unsqueeze(0)
379-
new_sample["labels"] = labels.unsqueeze(0)
380-
new_sample["loss_mask"] = loss_mask.unsqueeze(0)
381-
new_sample["position_ids"] = position_ids.unsqueeze(0)
381+
new_sample["tokens"] = tokens
382+
new_sample["labels"] = labels
383+
new_sample["loss_mask"] = loss_mask
384+
new_sample["position_ids"] = position_ids
382385
new_sample["local_cp_size"] = torch.tensor(
383386
partner_cp_size, dtype=torch.int32, device=dev
384387
)
@@ -442,6 +445,7 @@ def __init__(self, config):
442445
super().__init__(config)
443446
self.max_seq_len_all_ranks = config.max_seqlen_per_dp_cp_rank * config.context_parallel_size
444447
self.dp_size = parallel_state.get_data_parallel_world_size()
448+
self.cp_size = parallel_state.get_context_parallel_world_size()
445449

446450
def get_groups_and_subsamples(self, sample_id_seqlens, config):
447451
"""
@@ -451,35 +455,44 @@ def get_groups_and_subsamples(self, sample_id_seqlens, config):
451455
"""
452456
groups = []
453457
sample_id_groups = []
458+
packed_id_groups = []
454459
sum_seqlen = 0
455460
single_microbatch = []
456461

457462
for i in range(len(sample_id_seqlens)):
458-
if sum_seqlen + sample_id_seqlens[i] <= self.max_seq_len_all_ranks:
463+
if sum_seqlen + sample_id_seqlens[i][1] <= self.max_seq_len_all_ranks:
459464
single_microbatch.append(i)
460465
sum_seqlen += sample_id_seqlens[i][1]
461466
else:
462467
groups.append(single_microbatch)
463-
sample_id_groups.append(single_microbatch)
468+
packed_id_groups.append(single_microbatch)
464469
single_microbatch = [i]
465470
sum_seqlen = sample_id_seqlens[i][1]
466471

467-
# we want the number of microbatches to be multiple of dp_size
472+
# we want the number of packed sequences to be a multiple of dp_size
468473
# so we move a few samples from the previous microbatch
469474
# to the end of the microbatches if needed
470-
num_microbatches_before = len(sample_id_groups)
471-
if num_microbatches_before % self.dp_size != 0:
472-
remainder = num_microbatches_before % self.dp_size
475+
num_packed_sequence = len(packed_id_groups)
476+
if num_packed_sequence % self.dp_size != 0:
477+
remainder = num_packed_sequence % self.dp_size
473478
num_to_move = self.dp_size - remainder
474-
i = num_microbatches_before - 1
479+
i = num_packed_sequence - 1
475480
while num_to_move > 0:
476481
assert i > 0, "Not enough samples to move"
477-
if len(sample_id_groups[i]) > 1:
478-
seq_id = sample_id_groups[i].pop()
479-
sample_id_groups[i].append(seq_id)
482+
if len(packed_id_groups[i]) > 1:
483+
seq_id = packed_id_groups[i].pop()
484+
packed_id_groups[i].append(seq_id)
480485
num_to_move -= 1
481486
else:
482487
i -= 1
488+
489+
num_micro_batches = int(len(packed_id_groups) / self.dp_size)
490+
for i in range(num_micro_batches):
491+
sample_id_groups.append([])
492+
for j in range(self.cp_size * self.dp_size):
493+
seq_id = int(i * self.dp_size + j / self.cp_size)
494+
sample_id_groups[i].append(packed_id_groups[seq_id])
495+
483496
return groups, sample_id_groups
484497

485498

megatron/core/utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1910,11 +1910,16 @@ def get_thd_batch_on_this_cp_rank(
19101910
cp_group=cp_group,
19111911
)
19121912

1913+
for key in ['tokens', 'position_ids', 'labels', 'loss_mask']:
1914+
if key in batch:
1915+
batch[key] = batch[key].unsqueeze(0)
1916+
19131917
if cp_size > 1: # slice batch along sequence dimension for context parallelism
19141918
assert tex is not None and is_te_min_version("1.10.0"), (
19151919
"Please update Transformer Engine to >= 1.10 to use "
19161920
"Context Parallel with THD format data"
19171921
)
1922+
# print(f"tokens shape before cp slice: {batch['tokens'].shape}")
19181923
index = tex.thd_get_partitioned_indices(
19191924
cu_seqlens_padded, batch['tokens'].size(1), cp_size, cp_rank
19201925
)

megatron/training/arguments.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -932,7 +932,6 @@ def validate_args(args, defaults={}):
932932
if args.hybrid_context_parallel:
933933
assert not args.pipeline_model_parallel_size > 1, 'Hybrid context parallelism not supported with pipeline parallelism'
934934
assert not args.enable_cuda_graph, 'Hybrid context parallelism not supported with CUDA Graph'
935-
assert not args.use_megatron_fsdp, 'Hybrid context parallelism not supported with Megatron FSDP'
936935
assert args.dataloader_type == 'single', 'Hybrid context parallelism only supported with single dataloader type'
937936
assert args.calculate_per_token_loss, 'Hybrid context parallelism must be used with --calculate-per-token-loss'
938937

megatron/training/datasets/sft_dataset.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,11 +81,14 @@ def _calculate_padding_divisor(self) -> int:
8181
cp_pad = cp_pad * dp_size if hybrid_cp else cp_pad
8282
divisor = cp_pad * tp_pad
8383
"""
84-
cp_pad = self.config.context_parallel_size * 2 if self.config.context_parallel_size > 1 else 1
85-
cp_pad = cp_pad * self.config.data_parallel_size if self.config.hybrid_context_parallel else cp_pad
84+
if self.config.hybrid_context_parallel:
85+
# Hybrid CP: consider both CP and DP
86+
cp_pad = self.config.data_parallel_size * self.config.context_parallel_size * 2
87+
else:
88+
# Standard CP: only consider CP
89+
cp_pad = self.context_parallel_size * 2 if self.context_parallel_size > 1 else 1
8690
tp_pad = self.config.sequence_parallel_size if self.config.sequence_parallel_size > 0 else 1
8791
divisor = cp_pad * tp_pad
88-
8992
return divisor
9093

9194
def get_padding_size(

megatron/training/utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -563,8 +563,6 @@ def _broadcast_cu_seqlens(cu_seqlens):
563563
else:
564564
assert isinstance(cu_seqlens, torch.Tensor)
565565
assert cu_seqlens.dtype == torch.int32
566-
#TODO(tailaim): verify the shape for this tensor
567-
# assert cu_seqlens.shape[0] == 1, "micro-batch-size must be 1 for packing"
568566
buf = cu_seqlens.to(device=dev, non_blocking=True).contiguous()
569567
_broadcast(buf)
570568

@@ -732,6 +730,10 @@ def _broadcast_cu_seqlens():
732730
'local_cp_size': local_cp_size,
733731
}
734732

733+
if not args.sft_sequence_packing:
734+
keys_to_keep = ['tokens', 'labels', 'loss_mask', 'attention_mask', 'position_ids']
735+
batch = {k: v for k, v in batch.items() if k in keys_to_keep}
736+
735737
return batch
736738

737739

0 commit comments

Comments
 (0)