@@ -569,117 +569,163 @@ def _pt_pad_to_multiple_of(batch: dict[str, Any], pad_to_multiple_of: int, token
 # TODO(@jomitchell): Once this gets merged: https://github.com/NVIDIA/TransformerEngine/pull/2387
 # we can replace this with the one in TransformerEngine.
 def _split_batch_by_cp_rank(
-    cu_seqlens_padded: torch.Tensor,
+    cu_seqlens_padded: torch.Tensor | None,
     input_ids_padded: torch.Tensor,
     labels_padded: torch.Tensor,
     cp_group: torch.distributed.ProcessGroup | None = None,
     qvk_format: str = "thd",
     cp_rank: int | None = None,
     cp_world_size: int | None = None,
 ):
580- """Slice batch input along sequence dimension into multiple chunks for THD format.
580+ """Slice batch input along sequence dimension into multiple chunks for THD or BSHD format.
581581
582- This function is inteded for use in self attention. It will not work for cross attention because
582+ This function is intended for use in self attention. It will not work for cross attention because
583583 it does not handle the case where the sequence length of the query and key are different.
584584 Which are parallelized across GPUs in a context parallel group.
585- This version works with variable-length sequences using cumulative sequence lengths.
585+ This version works with variable-length sequences using cumulative sequence lengths for THD format,
586+ and with padded sequences for BSHD format.
586587
587588 Args:
588- cu_seqlens_padded: Cumulative sequence length.
589+ cu_seqlens_padded: Cumulative sequence length. Required for THD format, optional for BSHD format.
589590 input_ids_padded: Input IDs.
590591 labels_padded: Labels.
591592 cp_group: Context parallel group.
592- qvk_format: Format of the input data.
593+ qvk_format: Format of the input data ("thd" or "bshd") .
593594 cp_world_size: The size of the context parallelism group. If provided, the function will use this value to determine the rank.
594595 cp_rank: Optional manual CP rank index. When provided, the function shards tensors as if it
595596 were executing on that rank without querying `torch.distributed.get_rank`.
596597 """
     if qvk_format not in ["thd", "bshd", "sbhd"]:
         raise ValueError(f"Unsupported qvk_format: {qvk_format}!")
+
+    if cp_world_size is None or cp_world_size <= 1:
+        # No splitting needed
+        return input_ids_padded, labels_padded
+
+    if cp_rank is None:
+        cp_rank = torch.distributed.get_rank(group=cp_group)
+    elif not (0 <= cp_rank < cp_world_size):
+        raise ValueError(f"cp_rank must be in [0, {cp_world_size}), but received {cp_rank}.")
+
     if qvk_format == "thd":
-        # Get context parallel size and rank
-        if cp_world_size > 1:
-            if cp_rank is None:
-                cp_rank = torch.distributed.get_rank(group=cp_group)
-            elif not (0 <= cp_rank < cp_world_size):
-                raise ValueError(f"cp_rank must be in [0, {cp_world_size}), but received {cp_rank}.")
-
-            # Calculate the chunk sizes for each sequence
-            total_slices_of_any_sequence = 2 * cp_world_size
-            slice_sizes = (cu_seqlens_padded[1:] - cu_seqlens_padded[:-1]) // total_slices_of_any_sequence
-
-            # Process each tensor directly instead of using keys_to_change loop
-            def process_tensor(val):
-                if val is None:
-                    return val
-                # Determine which dimension is the sequence dimension
-                # Ensure cu_seqlens_padded[-1] is a Python int, not a 0-dim tensor
-                if isinstance(cu_seqlens_padded[-1], torch.Tensor):
-                    seq_len_val = cu_seqlens_padded[-1].item()
-                else:
-                    seq_len_val = cu_seqlens_padded[-1]
-
-                # Handle 1D tensors (like position_ids that don't have batch dimension)
-                if val.ndim == 1:
-                    if val.shape[0] == seq_len_val:
-                        current_seq_dim = 0
-                    else:
-                        raise ValueError(
-                            "1D tensor shape doesn't match expected sequence length. Make sure the"
-                            " inputs are in THD format and padded correctly."
-                        )
-                elif val.ndim >= 2:
-                    if val.shape[1] == seq_len_val:
-                        current_seq_dim = 1
-                    elif val.shape[0] == seq_len_val:
-                        current_seq_dim = 0
-                    else:
-                        raise ValueError("Make sure the inputs are in THD format and padded correctly.")
-                else:
-                    raise ValueError("Tensor must be at least 1D")
-
-                # On this particular rank, for each sequence, get two slices, one from the beginning
-                # and one from the end.
-                cp_rank_slices = []
-                for slice_size, seq_start in zip(slice_sizes, cu_seqlens_padded[:-1]):
-                    # 1st segment
-                    cp_rank_slices.append(
-                        torch.arange(
-                            seq_start + (cp_rank * slice_size),
-                            seq_start + ((cp_rank + 1) * slice_size),
-                            device=val.device,
-                        )
-                    )
-
-                    # 2nd segment
-                    cp_rank_slices.append(
-                        torch.arange(
-                            seq_start + ((total_slices_of_any_sequence - cp_rank - 1) * slice_size),
-                            seq_start + ((total_slices_of_any_sequence - cp_rank) * slice_size),
-                            device=val.device,
-                        )
-                    )
-
-                return val.index_select(current_seq_dim, torch.cat(cp_rank_slices))
-
-            # Process each tensor directly
-            input_ids_padded = process_tensor(input_ids_padded)
-            labels_padded = process_tensor(labels_padded)
+        if cu_seqlens_padded is None:
+            raise ValueError("cu_seqlens_padded is required for THD format")
+
+        # Calculate the chunk sizes for each sequence
+        total_slices_of_any_sequence = 2 * cp_world_size
+        slice_sizes = (cu_seqlens_padded[1:] - cu_seqlens_padded[:-1]) // total_slices_of_any_sequence
+
+        # Process each tensor directly instead of using keys_to_change loop
+        def process_tensor(val):
+            if val is None:
+                return val
+            # Determine which dimension is the sequence dimension
+            # Ensure cu_seqlens_padded[-1] is a Python int, not a 0-dim tensor
+            if isinstance(cu_seqlens_padded[-1], torch.Tensor):
+                seq_len_val = cu_seqlens_padded[-1].item()
+            else:
+                seq_len_val = cu_seqlens_padded[-1]
+
+            # Handle 1D tensors (like position_ids that don't have batch dimension)
+            if val.ndim == 1:
+                if val.shape[0] == seq_len_val:
+                    current_seq_dim = 0
+                else:
+                    raise ValueError(
+                        "1D tensor shape doesn't match expected sequence length. Make sure the"
+                        " inputs are in THD format and padded correctly."
+                    )
+            elif val.ndim >= 2:
+                if val.shape[1] == seq_len_val:
+                    current_seq_dim = 1
+                elif val.shape[0] == seq_len_val:
+                    current_seq_dim = 0
+                else:
+                    raise ValueError("Make sure the inputs are in THD format and padded correctly.")
+            else:
+                raise ValueError("Tensor must be at least 1D")
+
+            # On this particular rank, for each sequence, get two slices, one from the beginning
+            # and one from the end.
+            cp_rank_slices = []
+            for slice_size, seq_start in zip(slice_sizes, cu_seqlens_padded[:-1]):
+                # 1st segment
+                cp_rank_slices.append(
+                    torch.arange(
+                        seq_start + (cp_rank * slice_size),
+                        seq_start + ((cp_rank + 1) * slice_size),
+                        device=val.device,
+                    )
+                )
+
+                # 2nd segment
+                cp_rank_slices.append(
+                    torch.arange(
+                        seq_start + ((total_slices_of_any_sequence - cp_rank - 1) * slice_size),
+                        seq_start + ((total_slices_of_any_sequence - cp_rank) * slice_size),
+                        device=val.device,
+                    )
+                )
+
+            return val.index_select(current_seq_dim, torch.cat(cp_rank_slices))
+
+        # Process each tensor directly
+        input_ids_padded = process_tensor(input_ids_padded)
+        labels_padded = process_tensor(labels_padded)
+
+    elif qvk_format == "bshd":
+        # BSHD format: [batch, seq_len, ...]
+        # Split along the sequence dimension (dim=1)
+        # Each sequence is split into 2*cp_world_size chunks
+        # Each rank gets chunks at positions: [cp_rank, 2*cp_world_size - cp_rank - 1]
+
+        def process_tensor_bshd(val):
+            if val is None:
+                return val
+
+            if val.ndim < 2:
+                raise ValueError(f"BSHD format requires at least 2D tensors, got {val.ndim}D")
+
+            seq_len = val.shape[1]
+
+            # Calculate chunk size
+            total_chunks = 2 * cp_world_size
+            chunk_size = seq_len // total_chunks
+
+            if seq_len % total_chunks != 0:
+                raise ValueError(
+                    f"Sequence length {seq_len} must be divisible by {total_chunks} "
+                    f"(2 * cp_world_size) for BSHD context parallelism"
+                )
+
+            # Determine which chunks this rank should get
+            # Rank 0 gets chunks [0, total_chunks-1]
+            # Rank 1 gets chunks [1, total_chunks-2]
+            # Rank k gets chunks [k, total_chunks-k-1]
+            chunk_indices = [cp_rank, total_chunks - cp_rank - 1]
+
+            # Collect slices for this rank
+            rank_slices = []
+            for chunk_idx in chunk_indices:
+                start_idx = chunk_idx * chunk_size
+                end_idx = start_idx + chunk_size
+                rank_slices.append(torch.arange(start_idx, end_idx, device=val.device))
+
+            # Concatenate indices for all chunks this rank should get
+            indices = torch.cat(rank_slices)
+
+            # Select along sequence dimension (dim=1)
+            return val.index_select(1, indices)
+
+        input_ids_padded = process_tensor_bshd(input_ids_padded)
+        labels_padded = process_tensor_bshd(labels_padded)
+
     else:
         raise ValueError(f"Support not implemented yet for qvk_format: {qvk_format}!")
 
     return input_ids_padded, labels_padded
 
 
-def _get_group_local_rank(group: torch.distributed.ProcessGroup | None = None) -> int:
-    """Rank of the current process within `group`."""
-    if group is None:
-        # default group; this is just the global rank
-        return torch.distributed.get_rank()
-    global_rank = torch.distributed.get_rank()
-    return torch.distributed.get_group_rank(group, global_rank)
-
-
 class BatchType(TypedDict):
     """The fields in the batch dictionary for context parallel."""
 
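
A minimal sketch of the zig-zag THD split the new code performs, assuming the post-change `_split_batch_by_cp_rank` is in scope. Passing `cp_rank` and `cp_world_size` explicitly means `torch.distributed` never has to be initialized; the toy tensors and printed values below are illustrative only.

```python
# Illustrative only: exercises the zig-zag THD split with explicit cp_rank /
# cp_world_size, so no process group is required.
import torch

# Two packed sequences of (padded) lengths 8 and 4 in THD layout.
cu_seqlens_padded = torch.tensor([0, 8, 12])
input_ids = torch.arange(12)       # stand-in token ids 0..11
labels = torch.arange(100, 112)    # stand-in labels

cp_world_size = 2                  # each sequence is cut into 2 * 2 = 4 slices
for cp_rank in range(cp_world_size):
    ids, _ = _split_batch_by_cp_rank(
        cu_seqlens_padded=cu_seqlens_padded,
        input_ids_padded=input_ids,
        labels_padded=labels,
        qvk_format="thd",
        cp_rank=cp_rank,
        cp_world_size=cp_world_size,
    )
    print(cp_rank, ids.tolist())

# Each rank keeps slice k and slice (2 * cp_world_size - 1 - k) of every sequence:
# 0 [0, 1, 6, 7, 8, 11]
# 1 [2, 3, 4, 5, 9, 10]
```

Together the two ranks cover every token exactly once, and pairing an early slice with a late one keeps per-rank work balanced.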
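A matching sketch for the BSHD path, under the same assumption that the updated function is in scope. Here `cu_seqlens_padded` can be `None`, since BSHD splitting only needs the sequence dimension.

```python
# Illustrative only: BSHD tensors are batch-first with the sequence on dim 1.
import torch

cp_world_size = 2                            # seq_len must be divisible by 2 * 2 = 4
input_ids = torch.arange(16).reshape(2, 8)   # [batch=2, seq_len=8]
labels = input_ids + 100

for cp_rank in range(cp_world_size):
    ids, _ = _split_batch_by_cp_rank(
        cu_seqlens_padded=None,              # unused for BSHD
        input_ids_padded=input_ids,
        labels_padded=labels,
        qvk_format="bshd",
        cp_rank=cp_rank,
        cp_world_size=cp_world_size,
    )
    print(cp_rank, ids.tolist())

# chunk_size = 8 // 4 = 2, and rank k keeps chunks [k, 3 - k]:
# 0 [[0, 1, 6, 7], [8, 9, 14, 15]]
# 1 [[2, 3, 4, 5], [10, 11, 12, 13]]
```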