Commit 29241e4

Add BSHD tests for context parallel (#1410)
Since padded-THD context parallelism is only supported on datacenter hardware, we should add BSHD CP examples to be able to test this code on workstations and in CI.

This branch includes #1400. Closes BIO-13.

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent 07d6f33 commit 29241e4

File tree

17 files changed: 1179 additions & 416 deletions

bionemo-recipes/models/esm2/src/esm/collator.py

Lines changed: 123 additions & 77 deletions
@@ -569,117 +569,163 @@ def _pt_pad_to_multiple_of(batch: dict[str, Any], pad_to_multiple_of: int, token
 # TODO(@jomitchell): Once this gets merged: https://github.com/NVIDIA/TransformerEngine/pull/2387
 # we can replace this with the one in TransformerEngine.
 def _split_batch_by_cp_rank(
-    cu_seqlens_padded: torch.Tensor,
+    cu_seqlens_padded: torch.Tensor | None,
     input_ids_padded: torch.Tensor,
     labels_padded: torch.Tensor,
     cp_group: torch.distributed.ProcessGroup | None = None,
     qvk_format: str = "thd",
     cp_rank: int | None = None,
     cp_world_size: int | None = None,
 ):
-    """Slice batch input along sequence dimension into multiple chunks for THD format.
+    """Slice batch input along sequence dimension into multiple chunks for THD or BSHD format.

-    This function is inteded for use in self attention. It will not work for cross attention because
+    This function is intended for use in self attention. It will not work for cross attention because
     it does not handle the case where the sequence length of the query and key are different.
     Which are parallelized across GPUs in a context parallel group.
-    This version works with variable-length sequences using cumulative sequence lengths.
+    This version works with variable-length sequences using cumulative sequence lengths for THD format,
+    and with padded sequences for BSHD format.

     Args:
-        cu_seqlens_padded: Cumulative sequence length.
+        cu_seqlens_padded: Cumulative sequence length. Required for THD format, optional for BSHD format.
         input_ids_padded: Input IDs.
         labels_padded: Labels.
         cp_group: Context parallel group.
-        qvk_format: Format of the input data.
+        qvk_format: Format of the input data ("thd" or "bshd").
         cp_world_size: The size of the context parallelism group. If provided, the function will use this value to determine the rank.
         cp_rank: Optional manual CP rank index. When provided, the function shards tensors as if it
            were executing on that rank without querying `torch.distributed.get_rank`.
     """
     if qvk_format not in ["thd", "bshd", "sbhd"]:
         raise ValueError(f"Unsupported qvk_format: {qvk_format}!")
+
+    if cp_world_size is None or cp_world_size <= 1:
+        # No splitting needed
+        return input_ids_padded, labels_padded
+
+    if cp_rank is None:
+        cp_rank = torch.distributed.get_rank(group=cp_group)
+    elif not (0 <= cp_rank < cp_world_size):
+        raise ValueError(f"cp_rank must be in [0, {cp_world_size}), but received {cp_rank}.")
+
     if qvk_format == "thd":
-        # Get context parallel size and rank
-        if cp_world_size > 1:
-            if cp_rank is None:
-                cp_rank = torch.distributed.get_rank(group=cp_group)
-            elif not (0 <= cp_rank < cp_world_size):
-                raise ValueError(f"cp_rank must be in [0, {cp_world_size}), but received {cp_rank}.")
-
-            # Calculate the chunk sizes for each sequence
-            total_slices_of_any_sequence = 2 * cp_world_size
-            slice_sizes = (cu_seqlens_padded[1:] - cu_seqlens_padded[:-1]) // total_slices_of_any_sequence
-
-            # Process each tensor directly instead of using keys_to_change loop
-            def process_tensor(val):
-                if val is None:
-                    return val
-                # Determine which dimension is the sequence dimension
-                # Ensure cu_seqlens_padded[-1] is a Python int, not a 0-dim tensor
-                if isinstance(cu_seqlens_padded[-1], torch.Tensor):
-                    seq_len_val = cu_seqlens_padded[-1].item()
+        if cu_seqlens_padded is None:
+            raise ValueError("cu_seqlens_padded is required for THD format")
+
+        # Calculate the chunk sizes for each sequence
+        total_slices_of_any_sequence = 2 * cp_world_size
+        slice_sizes = (cu_seqlens_padded[1:] - cu_seqlens_padded[:-1]) // total_slices_of_any_sequence
+
+        # Process each tensor directly instead of using keys_to_change loop
+        def process_tensor(val):
+            if val is None:
+                return val
+            # Determine which dimension is the sequence dimension
+            # Ensure cu_seqlens_padded[-1] is a Python int, not a 0-dim tensor
+            if isinstance(cu_seqlens_padded[-1], torch.Tensor):
+                seq_len_val = cu_seqlens_padded[-1].item()
+            else:
+                seq_len_val = cu_seqlens_padded[-1]
+
+            # Handle 1D tensors (like position_ids that don't have batch dimension)
+            if val.ndim == 1:
+                if val.shape[0] == seq_len_val:
+                    current_seq_dim = 0
                 else:
-                    seq_len_val = cu_seqlens_padded[-1]
-
-                # Handle 1D tensors (like position_ids that don't have batch dimension)
-                if val.ndim == 1:
-                    if val.shape[0] == seq_len_val:
-                        current_seq_dim = 0
-                    else:
-                        raise ValueError(
-                            "1D tensor shape doesn't match expected sequence length. Make sure the"
-                            " inputs are in THD format and padded correctly."
-                        )
-                elif val.ndim >= 2:
-                    if val.shape[1] == seq_len_val:
-                        current_seq_dim = 1
-                    elif val.shape[0] == seq_len_val:
-                        current_seq_dim = 0
-                    else:
-                        raise ValueError("Make sure the inputs are in THD format and padded correctly.")
+                    raise ValueError(
+                        "1D tensor shape doesn't match expected sequence length. Make sure the"
+                        " inputs are in THD format and padded correctly."
+                    )
+            elif val.ndim >= 2:
+                if val.shape[1] == seq_len_val:
+                    current_seq_dim = 1
+                elif val.shape[0] == seq_len_val:
+                    current_seq_dim = 0
                 else:
-                    raise ValueError("Tensor must be at least 1D")
-
-                # On this particular rank, for each sequence, get two slices, one from the beginning
-                # and one from the end.
-                cp_rank_slices = []
-                for slice_size, seq_start in zip(slice_sizes, cu_seqlens_padded[:-1]):
-                    # 1st segment
-                    cp_rank_slices.append(
-                        torch.arange(
-                            seq_start + (cp_rank * slice_size),
-                            seq_start + ((cp_rank + 1) * slice_size),
-                            device=val.device,
-                        )
+                    raise ValueError("Make sure the inputs are in THD format and padded correctly.")
+            else:
+                raise ValueError("Tensor must be at least 1D")
+
+            # On this particular rank, for each sequence, get two slices, one from the beginning
+            # and one from the end.
+            cp_rank_slices = []
+            for slice_size, seq_start in zip(slice_sizes, cu_seqlens_padded[:-1]):
+                # 1st segment
+                cp_rank_slices.append(
+                    torch.arange(
+                        seq_start + (cp_rank * slice_size),
+                        seq_start + ((cp_rank + 1) * slice_size),
+                        device=val.device,
                     )
-
-                    # 2nd segment
-                    cp_rank_slices.append(
-                        torch.arange(
-                            seq_start + ((total_slices_of_any_sequence - cp_rank - 1) * slice_size),
-                            seq_start + ((total_slices_of_any_sequence - cp_rank) * slice_size),
-                            device=val.device,
-                        )
+                )
+
+                # 2nd segment
+                cp_rank_slices.append(
+                    torch.arange(
+                        seq_start + ((total_slices_of_any_sequence - cp_rank - 1) * slice_size),
+                        seq_start + ((total_slices_of_any_sequence - cp_rank) * slice_size),
+                        device=val.device,
                     )
+                )
+
+            return val.index_select(current_seq_dim, torch.cat(cp_rank_slices))
+
+        # Process each tensor directly
+        input_ids_padded = process_tensor(input_ids_padded)
+        labels_padded = process_tensor(labels_padded)
+
+    elif qvk_format == "bshd":
+        # BSHD format: [batch, seq_len, ...]
+        # Split along sequence dimension (dim=1)
+        # Each sequence is split into 2*cp_world_size chunks
+        # Each rank gets chunks at positions: [cp_rank, 2*cp_world_size - cp_rank - 1]

-                return val.index_select(current_seq_dim, torch.cat(cp_rank_slices))
+        def process_tensor_bshd(val):
+            if val is None:
+                return val
+
+            if val.ndim < 2:
+                raise ValueError(f"BSHD format requires at least 2D tensors, got {val.ndim}D")
+
+            seq_len = val.shape[1]
+
+            # Calculate chunk size
+            total_chunks = 2 * cp_world_size
+            chunk_size = seq_len // total_chunks
+
+            if chunk_size == 0:
+                raise ValueError(
+                    f"Sequence length {seq_len} must be divisible by {total_chunks} "
+                    f"(2 * cp_world_size) for BSHD context parallelism"
+                )
+
+            # Determine which chunks this rank should get
+            # Rank 0 gets chunks [0, total_chunks-1]
+            # Rank 1 gets chunks [1, total_chunks-2]
+            # Rank k gets chunks [k, total_chunks-k-1]
+            chunk_indices = [cp_rank, total_chunks - cp_rank - 1]
+
+            # Collect slices for this rank
+            rank_slices = []
+            for chunk_idx in chunk_indices:
+                start_idx = chunk_idx * chunk_size
+                end_idx = start_idx + chunk_size
+                rank_slices.append(torch.arange(start_idx, end_idx, device=val.device))
+
+            # Concatenate indices for all chunks this rank should get
+            indices = torch.cat(rank_slices)
+
+            # Select along sequence dimension (dim=1)
+            return val.index_select(1, indices)
+
+        input_ids_padded = process_tensor_bshd(input_ids_padded)
+        labels_padded = process_tensor_bshd(labels_padded)

-        # Process each tensor directly
-        input_ids_padded = process_tensor(input_ids_padded)
-        labels_padded = process_tensor(labels_padded)
     else:
         raise ValueError(f"Support not implemented yet for qvk_format: {qvk_format}!")

     return input_ids_padded, labels_padded


-def _get_group_local_rank(group: torch.distributed.ProcessGroup | None = None) -> int:
-    """Rank of the current process within `group`."""
-    if group is None:
-        # default group; this is just the global rank
-        return torch.distributed.get_rank()
-    global_rank = torch.distributed.get_rank()
-    return torch.distributed.get_group_rank(group, global_rank)
-
-
 class BatchType(TypedDict):
     """The fields in the batch dictionary for context parallel."""

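For reference, the BSHD branch above gives each context-parallel rank two chunks of the sequence dimension: chunk cp_rank and its mirror chunk 2 * cp_world_size - cp_rank - 1, so together the ranks cover the whole sequence while balancing the attention workload. The following standalone sketch reproduces that chunk assignment; it does not import the collator, and cp_world_size=2 and seq_len=8 are illustrative values, not taken from the tests in this commit.

import torch

# Mirrors the chunk assignment in process_tensor_bshd above (illustrative values only).
cp_world_size = 2
seq_len = 8  # must be divisible by 2 * cp_world_size
batch = torch.arange(seq_len).unsqueeze(0)  # BSHD-style [batch=1, seq_len]

total_chunks = 2 * cp_world_size
chunk_size = seq_len // total_chunks

for cp_rank in range(cp_world_size):
    # Each rank takes a leading chunk and its mirrored trailing chunk.
    chunk_indices = [cp_rank, total_chunks - cp_rank - 1]
    indices = torch.cat(
        [torch.arange(i * chunk_size, (i + 1) * chunk_size) for i in chunk_indices]
    )
    print(cp_rank, batch.index_select(1, indices).tolist())
# rank 0 -> [[0, 1, 6, 7]] (chunks 0 and 3); rank 1 -> [[2, 3, 4, 5]] (chunks 1 and 2)

Pairing a leading chunk with its mirrored trailing chunk is the same layout the THD branch builds per sequence via slice_sizes, so both formats hand each rank two equally sized shards.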
bionemo-recipes/models/esm2/src/esm/modeling_esm_te.py

Lines changed: 4 additions & 32 deletions
@@ -188,43 +188,15 @@ def forward(
             **kwargs: Additional arguments, see TransformersKwargs for more details.
         """
         all_hidden_states: tuple[torch.Tensor, ...] = ()
-        has_thd_input = [
-            x is not None
-            for x in [
-                kwargs.get("cu_seq_lens_q", None),
-                kwargs.get("cu_seq_lens_k", None),
-                kwargs.get("max_length_q", None),
-                kwargs.get("max_length_k", None),
-            ]
-        ]

-        if self.config.attn_input_format == "thd":
-            if not all(has_thd_input):
-                raise ValueError(
-                    "cu_seq_lens_q, cu_seq_lens_k, max_length_q, and max_length_k must be provided when using THD inputs."
-                )
-            assert hidden_states.dim() == 3 and hidden_states.size(0) == 1, (
-                "THD expects embeddings shaped [1, total_tokens, hidden_size]."
-            )
+        if self.config.attn_input_format == "thd" and hidden_states.dim() == 3 and hidden_states.size(0) == 1:
+            # For THD, the embedding output is a 3-dimensional tensor with shape [1, total_tokens, hidden_size], but TE
+            # expects a 2-dimensional tensor with shape [total_tokens, hidden_size].
             hidden_states = hidden_states.squeeze(0)
-            attention_mask = None
-
-        elif self.config.attn_input_format == "bshd" and any(has_thd_input):
-            raise ValueError(
-                "cu_seq_lens_q, cu_seq_lens_k, max_length_q, and max_length_k are not allowed when using BSHD inputs."
-            )

         # Ensure that rotary embeddings are computed with at a higher precision outside the torch autocast context.
         with torch.autocast(device_type="cuda", enabled=False):
-            if self.config.position_embedding_type == "rotary":
-                if self.config.attn_input_format == "bshd":
-                    te_rope_emb = self.rotary_embeddings(max_seq_len=hidden_states.shape[1])
-                elif self.config.attn_input_format == "thd":
-                    te_rope_emb = self.rotary_embeddings(
-                        max_seq_len=kwargs["cu_seq_lens_q_padded"][-1]
-                        if "cu_seq_lens_q_padded" in kwargs
-                        else kwargs["cu_seq_lens_q"][-1]
-                    )
+            te_rope_emb = self.rotary_embeddings(max_seq_len=self.config.max_position_embeddings)
             te_rope_emb = te_rope_emb.to(hidden_states.device, non_blocking=True)

         for layer_module in self.layers:

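As a shape-level illustration of the simplified forward path above: THD embeddings arrive as [1, total_tokens, hidden_size] and are squeezed to the 2-D layout TE expects, BSHD embeddings stay [batch, seq_len, hidden_size], and the rotary table is now always built for config.max_position_embeddings, which covers either layout as long as sequences stay within that limit. A minimal sketch with hypothetical sizes (total_tokens=10, batch=2, seq_len=5, hidden_size=4 are not taken from the model config):

import torch

# Hypothetical sizes, for illustration only.
total_tokens, batch, seq_len, hidden_size = 10, 2, 5, 4

# THD: [1, total_tokens, hidden_size] -> [total_tokens, hidden_size]
thd_hidden = torch.randn(1, total_tokens, hidden_size)
if thd_hidden.dim() == 3 and thd_hidden.size(0) == 1:
    thd_hidden = thd_hidden.squeeze(0)
assert thd_hidden.shape == (total_tokens, hidden_size)

# BSHD: [batch, seq_len, hidden_size] is passed through unchanged.
bshd_hidden = torch.randn(batch, seq_len, hidden_size)
assert bshd_hidden.shape == (batch, seq_len, hidden_size)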