71 changes: 44 additions & 27 deletions bionemo-recipes/models/esm2/src/esm/collator.py
@@ -275,6 +275,7 @@ def set_epoch(self, epoch: int):
self.dataset.set_epoch(epoch)


@dataclass
class DataCollatorForContextParallel:
"""A collator that is aware of context parallelism.

@@ -285,15 +286,10 @@ class DataCollatorForContextParallel:
appropriate GPUs.
"""

def __init__(self, collator: DataCollator, cp_world_size: int):
"""Initialize the DataCollatorForContextParallel.

Args:
collator: The collator to use for masking tokens.
cp_world_size: The size of the context parallelism group.
"""
self.collator = collator
self.cp_world_size = cp_world_size
collator: DataCollator
cp_world_size: int
tp_world_size: int | None = None
qkv_format: str = "thd"

def __call__(self, features) -> list[dict[str, Any]]:
"""Process batches of data and create shards for each context parallelism rank.
@@ -309,23 +305,36 @@ def __call__(self, features) -> list[dict[str, Any]]:
combined_batch = []
for cp_rank in range(self.cp_world_size):
input_ids_sharded, labels_sharded = _split_batch_by_cp_rank(
cu_seqlens_padded=batch["cu_seq_lens_q_padded"],
cu_seqlens_padded=batch.get("cu_seq_lens_q_padded", None), # This will be None for BSHD format.
input_ids_padded=batch["input_ids"],
labels_padded=batch["labels"],
qvk_format="thd",
qvk_format=self.qkv_format,
cp_rank=cp_rank,
cp_world_size=self.cp_world_size,
)
batch_shard = dict(batch)
batch_shard["input_ids"] = input_ids_sharded
batch_shard["labels"] = labels_sharded
# Now determine the max length of the sequence.
seqlens_q = batch_shard["cu_seq_lens_q_padded"][1:] - batch_shard["cu_seq_lens_q_padded"][:-1]
batch_shard["max_length_q"] = int((seqlens_q.max().item() + 63) // 64 * 64)
batch_shard["max_length_k"] = batch_shard["max_length_q"]
batch_shard["pad_between_seqs"] = True
if self.qkv_format == "thd":
seqlens_q = batch_shard["cu_seq_lens_q_padded"][1:] - batch_shard["cu_seq_lens_q_padded"][:-1]
max_length = seqlens_q.max().item()
batch_shard["pad_between_seqs"] = True
elif self.qkv_format == "bshd":
max_length = batch["input_ids"].shape[1]
# For BSHD context parallelism, we can't handle padding, so we remove the attention mask.
del batch_shard["attention_mask"]
else:
raise ValueError(f"Unsupported qvk_format: {self.qkv_format}!")

batch_shard["max_length_k"] = batch_shard["max_length_q"] = max_length * round(max_length / 64)
combined_batch.append(batch_shard)

if self.tp_world_size is not None:
# If we're using tensor parallelism, replicate each CP shard for every TP rank by repeating it in the
# flattened output list.
combined_batch = [batch for batch in combined_batch for _ in range(self.tp_world_size)]

return combined_batch
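
For reference, a minimal sketch of how the dataclass-based collator is constructed and called, mirroring the new tests below; the public ESM-2 checkpoint used for the tokenizer and the example features are placeholders rather than anything fixed by this change:

```python
from transformers import AutoTokenizer, DataCollatorForLanguageModeling

from esm.collator import DataCollatorForContextParallel, DataCollatorWithFlattening

cp_world_size, tp_world_size = 2, 2

# Any masked-LM tokenizer works; a public ESM-2 checkpoint stands in for the test fixture here.
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")

base_collator = DataCollatorWithFlattening(
    collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15),
    pad_sequences_to_be_divisible_by=2 * cp_world_size,
)
cp_collator = DataCollatorForContextParallel(
    collator=base_collator, cp_world_size=cp_world_size, tp_world_size=tp_world_size
)

features = [
    {"input_ids": [0, 5, 6, 7, 8, 9, 10, 2]},  # 8 tokens
    {"input_ids": [0, 11, 12, 13, 14, 15, 16, 17, 2]},  # 9 tokens
]
shards = cp_collator(features)

# Flattened ordering with TP enabled: [cp0, cp0, cp1, cp1].
assert len(shards) == cp_world_size * tp_world_size
```

With `tp_world_size=None` (the default), the returned list has exactly `cp_world_size` entries, one shard per CP rank.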


@@ -335,29 +344,32 @@ class ContextParallelDataLoaderWrapper:
def __init__(
self,
dataloader: torch.utils.data.DataLoader | None,
cp_mesh: torch.distributed.device_mesh.DeviceMesh,
cp_tp_mesh: torch.distributed.device_mesh.DeviceMesh,
):
"""A dataloader wrapper that distributes the data across the context parallelism group.

This class will get the batch from the dataloader on CP rank 0, and then determine the shards for all the
different CP group members. Then it will scatter the shards to the different CP group members. The shards are
then returned to the caller for the current CP rank.

If tensor parallelism is also being used, the data will be replicated across the TP dimension for each CP rank.
This should be provided using a flattened cp/tp mesh.

Args:
dataloader: The dataloader to use.
cp_mesh: The context parallel mesh.
cp_rank: The rank of the current context parallel process.
cp_tp_mesh: The context parallel mesh, or combined context parallel and tensor parallel mesh.

"""
if cp_mesh.get_local_rank() == 0:
if cp_tp_mesh.get_local_rank() == 0:
assert dataloader is not None, "dataloader must be provided on rank 0"
self.dataloader = dataloader

else:
assert dataloader is None, "Dataloader on non-rank 0 will not be used"

self.cp_rank = cp_mesh.get_local_rank()
self.cp_group = cp_mesh.get_group()
self.num_cp_ranks = cp_mesh.size()
self.cp_rank = cp_tp_mesh.get_local_rank()
self.cp_group = cp_tp_mesh.get_group()
self.num_cp_ranks = cp_tp_mesh.size()
self._iterator = None

logger.debug(
@@ -387,9 +399,13 @@ def _send_data_to_cp_ranks(self):
The shards are then combined into a single batch and returned to the caller
for the current CP rank.

If tensor parallelism is also being used, the combined batch will look like:
combined_batch = [<cp_rank_0_shard>, <cp_rank_0_shard>, ..., <cp_rank_1_shard>, ...]
where each shard is replicated self.tp_world_size times.

Scalability:
Rank 0's work grows linearly with CP size, but the other ranks do not need to store all the shards so they do not
grow linearly with CP size.
Rank 0's work grows linearly with CP size, but the other ranks do not need to store all the shards so they
do not grow linearly with CP size.

Args:
None
Expand Down Expand Up @@ -727,7 +743,7 @@ def process_tensor_bshd(val):


class BatchType(TypedDict):
"""The fields in the batch dictionary for context parallel."""
"""The fields in the batch dictionary fo THD context parallel."""

input_ids: torch.Tensor
labels: torch.Tensor
@@ -737,17 +753,18 @@ class BatchType(TypedDict):
cu_seq_lens_k_padded: torch.Tensor
max_length_q: int
max_length_k: int
pad_between_seqs: bool


def _scatter_batch_to_cp_ranks(
batch: list[BatchType] | list[StopIteration], cp_group: torch.distributed.ProcessGroup | None = None
all_batches: list[BatchType] | list[StopIteration], cp_group: torch.distributed.ProcessGroup | None = None
) -> BatchType | StopIteration:
"""Scatter a batch to all the CP ranks."""
scatter_object_output_list = [None]
# Note: This does not provide an async_op handle, so the call is blocking.
torch.distributed.scatter_object_list(
scatter_object_output_list=scatter_object_output_list,
scatter_object_input_list=batch,
scatter_object_input_list=all_batches,
group=cp_group,
group_src=0,
)
165 changes: 165 additions & 0 deletions bionemo-recipes/models/esm2/tests/test_collator_context_parallel.py
@@ -17,12 +17,15 @@
from typing import Dict, Iterator, List
from unittest import mock

import pytest
import torch
from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import pad_thd_sequences_for_cp
from transformers import DataCollatorForLanguageModeling

from esm.collator import (
BatchType,
ContextParallelDataLoaderWrapper,
DataCollatorForContextParallel,
DataCollatorWithFlattening,
_split_batch_by_cp_rank,
)
@@ -887,3 +890,165 @@ def test_bshd_and_thd_equivalence(tokenizer):
torch.sort(batch_bshd["input_ids"][1])[0],
msg="Reconstructed sequence 2 doesn't match original",
)


@pytest.mark.parametrize("cp_world_size", [2, 4])
def test_data_collator_for_context_parallel_returns_correct_list_size(tokenizer, cp_world_size):
"""Test that DataCollatorForContextParallel returns a list of the correct size."""
divisibility_factor = 2 * cp_world_size

# Create the wrapped collator that produces padded THD batches
base_collator = DataCollatorWithFlattening(
collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15),
pad_sequences_to_be_divisible_by=divisibility_factor,
)

# Create the context parallel collator
cp_collator = DataCollatorForContextParallel(collator=base_collator, cp_world_size=cp_world_size)

# Create test sequences
features = [
{"input_ids": [0, 5, 6, 7, 8, 9, 10, 2]}, # 8 tokens
{"input_ids": [0, 11, 12, 13, 14, 15, 16, 17, 2]}, # 9 tokens
]

# Call the collator
result = cp_collator(features)

# Assert that the result is a list of the correct size
assert isinstance(result, list), f"Expected list, got {type(result)}"
assert len(result) == cp_world_size, f"Expected list of size {cp_world_size}, got {len(result)}"


def test_data_collator_for_context_parallel_thd(tokenizer):
"""Test that each shard from DataCollatorForContextParallel has all required keys from BatchType."""

cp_world_size = 2
divisibility_factor = 2 * cp_world_size

# Create the wrapped collator that produces padded THD batches
base_collator = DataCollatorWithFlattening(
collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15),
pad_sequences_to_be_divisible_by=divisibility_factor,
)

# Create the context parallel collator
cp_collator = DataCollatorForContextParallel(collator=base_collator, cp_world_size=cp_world_size)

# Create test sequences
features = [
{"input_ids": [0, 5, 6, 7, 8, 9, 10, 2]}, # 8 tokens
{"input_ids": [0, 11, 12, 13, 14, 15, 16, 17, 2]}, # 9 tokens
]

# Call the collator
result = cp_collator(features)

assert len(result) == cp_world_size, f"Expected list of size {cp_world_size}, got {len(result)}"

# Define the required keys from BatchType
required_keys = set(BatchType.__annotations__.keys())

# Assert each shard has all required keys
for cp_rank, shard in enumerate(result):
assert set(shard.keys()) == required_keys, (
f"CP rank {cp_rank}: difference: {set(shard.keys()) - required_keys}"
)


def test_data_collator_for_context_parallel_bshd(tokenizer):
"""Test that each shard from DataCollatorForContextParallel has all required keys from BatchType."""

cp_world_size = 2
divisibility_factor = 2 * cp_world_size

# Create the base collator that produces padded BSHD batches
base_collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer,
mlm_probability=0.15,
pad_to_multiple_of=divisibility_factor,
)

# Create the context parallel collator
cp_collator = DataCollatorForContextParallel(
collator=base_collator, cp_world_size=cp_world_size, qkv_format="bshd"
)

# Create test sequences
features = [
{"input_ids": [0, 5, 6, 7, 8, 9, 10, 2]}, # 8 tokens
{"input_ids": [0, 11, 12, 13, 14, 15, 16, 17, 2]}, # 9 tokens
]

# Call the collator
result = cp_collator(features)

assert len(result) == cp_world_size, f"Expected list of size {cp_world_size}, got {len(result)}"

# Define the expected keys for BSHD shards
required_keys = {"input_ids", "labels", "max_length_q", "max_length_k"}

# Assert each shard has all required keys
for cp_rank, shard in enumerate(result):
assert set(shard.keys()) == required_keys, (
f"CP rank {cp_rank}: expected keys {required_keys}, got {set(shard.keys())}"
)


def test_data_collator_for_context_parallel_with_tp(tokenizer):
"""Test that DataCollatorForContextParallel duplicates batches for TP ranks when tp_world_size is set."""
cp_world_size = 2
tp_world_size = 2
divisibility_factor = 2 * cp_world_size

# Create the wrapped collator that produces padded THD batches
base_collator = DataCollatorWithFlattening(
collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15),
pad_sequences_to_be_divisible_by=divisibility_factor,
)

# Create the context parallel collator with TP
cp_collator = DataCollatorForContextParallel(
collator=base_collator, cp_world_size=cp_world_size, tp_world_size=tp_world_size
)

# Create test sequences
features = [
{"input_ids": [0, 5, 6, 7, 8, 9, 10, 2]}, # 8 tokens
{"input_ids": [0, 11, 12, 13, 14, 15, 16, 17, 2]}, # 9 tokens
]

# Call the collator
result = cp_collator(features)

# Assert that the result list has length cp_world_size * tp_world_size
expected_length = cp_world_size * tp_world_size
assert len(result) == expected_length, f"Expected list of size {expected_length}, got {len(result)}"

# Assert that batches are duplicated for TP ranks
# The structure should be: [cp0_tp0, cp0_tp1, cp1_tp0, cp1_tp1]
# So consecutive pairs should be identical for the same CP rank
for cp_rank in range(cp_world_size):
base_idx = cp_rank * tp_world_size
for tp_rank in range(1, tp_world_size):
# Compare each TP rank's batch with the first TP rank's batch for this CP rank
first_batch = result[base_idx]
current_batch = result[base_idx + tp_rank]

# Check that all keys are the same
assert set(first_batch.keys()) == set(current_batch.keys()), (
f"CP rank {cp_rank}, TP rank {tp_rank}: keys don't match"
)

# Check that tensor values are identical
for key in first_batch.keys():
if isinstance(first_batch[key], torch.Tensor):
torch.testing.assert_close(
first_batch[key],
current_batch[key],
msg=f"CP rank {cp_rank}, TP rank {tp_rank}: '{key}' tensors don't match",
)
else:
assert first_batch[key] == current_batch[key], (
f"CP rank {cp_rank}, TP rank {tp_rank}: '{key}' values don't match"
)