Skip to content

Commit 5d47936

Browse files
asolergi-nv authored and BoxiangW committed
Fix SFT Pipeline when TP>1 (NVIDIA#3268)
1 parent 9afa139 commit 5d47936

File tree

1 file changed

+2
-45
lines changed

1 file changed

+2
-45
lines changed

megatron/training/utils.py

Lines changed: 2 additions & 45 deletions
Original file line number | Diff line number | Diff line change
@@ -600,25 +600,6 @@ def _broadcast_cu_seqlens(cu_seqlens):
600600
_broadcast(batch['loss_mask'])
601601
_broadcast(batch['attention_mask'])
602602

603-
def _broadcast_cu_seqlens(cu_seqlens):
604-
dev = torch.cuda.current_device()
605-
606-
n = 0 if cu_seqlens is None else int(cu_seqlens.numel())
607-
n_tensor = torch.tensor(n, dtype=torch.int64, device=dev)
608-
_broadcast(n_tensor)
609-
610-
if n == 0:
611-
buf = torch.empty(0, dtype=torch.int32, device=dev)
612-
else:
613-
assert isinstance(cu_seqlens, torch.Tensor)
614-
assert cu_seqlens.dtype == torch.int32
615-
assert cu_seqlens.shape[0] == 1, "micro-batch-size must be 1 for packing"
616-
buf = cu_seqlens.to(device=dev, non_blocking=True).contiguous()
617-
_broadcast(buf)
618-
619-
_broadcast_cu_seqlens(batch['cu_seqlens'])
620-
_broadcast(batch['max_seqlen'])
621-
622603
else:
623604
if args.hybrid_context_parallel:
624605
seq_len = torch.tensor(0, dtype=torch.int32, device=torch.cuda.current_device())
@@ -657,21 +638,15 @@ def _broadcast_cu_seqlens(cu_seqlens):
657638
device=torch.cuda.current_device(),
658639
)
659640
cu_seqlens = None
660-
if args.sft:
641+
if args.hybrid_context_parallel or args.sft:
661642
max_seqlen = torch.empty(
662643
1,
663644
dtype=torch.int32,
664645
device=torch.cuda.current_device(),
665646
)
666647
else:
667648
max_seqlen = None
668-
669-
cu_seqlens = None
670-
max_seqlen = torch.empty(
671-
1,
672-
dtype=torch.int32,
673-
device=torch.cuda.current_device(),
674-
) if args.hybrid_context_parallel else None
649+
675650
local_cp_size = torch.empty(
676651
1,
677652
dtype=torch.int32,
@@ -726,24 +701,6 @@ def _broadcast_cu_seqlens():
726701
_broadcast(loss_mask)
727702
_broadcast(attention_mask)
728703

729-
def _broadcast_cu_seqlens():
730-
dev = torch.cuda.current_device()
731-
732-
n = torch.empty((), dtype=torch.int64, device=dev)
733-
_broadcast(n)
734-
n = int(n.item())
735-
736-
if n == 0:
737-
cu_seqlens = torch.empty(0, dtype=torch.int32, device=dev)
738-
else:
739-
cu_seqlens = torch.empty((args.micro_batch_size, n), dtype=torch.int32, device=dev)
740-
_broadcast(cu_seqlens)
741-
742-
return cu_seqlens if n > 0 else None
743-
744-
cu_seqlens = _broadcast_cu_seqlens()
745-
_broadcast(max_seqlen)
746-
747704
batch = {
748705
'tokens': tokens,
749706
'labels': labels,

0 commit comments

Comments
 (0)