Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions bionemo-recipes/models/esm2/src/esm/collator.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,40 @@ def _send_data_to_cp_ranks(self):

return batch_on_this_rank

def state_dict(self):
    """Return the wrapped dataloader's state under a single ``"dataloader"`` key.

    Returns:
        An empty dict when ``self.cp_rank`` is non-zero. Otherwise a dict whose ``"dataloader"``
        entry is ``self.dataloader.state_dict()``, or an empty dict (after a warning is logged)
        when the dataloader has no ``state_dict`` attribute.
    """
    # Only CP rank 0 reports any dataloader state.
    if self.cp_rank != 0:
        return {}

    if not hasattr(self.dataloader, "state_dict"):
        logger.warning(
            "Attempting to get the state dict of the dataloader, but the dataloader does not support state_dict, "
            "returning empty dict"
        )
        return {"dataloader": {}}

    return {"dataloader": self.dataloader.state_dict()}

def load_state_dict(self, state_dict):
    """Restore the wrapped dataloader from ``state_dict["dataloader"]``.

    Args:
        state_dict: Mapping with a ``"dataloader"`` entry. It is only read when ``self.cp_rank``
            is 0 and the dataloader exposes ``load_state_dict``.
    """
    # Non-zero CP ranks have nothing to restore.
    if self.cp_rank != 0:
        return

    if not hasattr(self.dataloader, "load_state_dict"):
        logger.warning(
            "Attempting to load the state dict of the dataloader, but the dataloader does not support "
            "load_state_dict, returning without loading the state dict."
        )
        return

    self.dataloader.load_state_dict(state_dict["dataloader"])

@property
def num_workers(self):
    """Worker count of the wrapped dataloader, or 0 when ``self.cp_rank`` is non-zero."""
    return 0 if self.cp_rank != 0 else self.dataloader.num_workers


def _split_sample_by_num_tokens(sample: dict[str, Any], num_tokens: int) -> tuple[dict[str, Any], dict[str, Any]]:
"""Split a sample dictionary at a specified number of tokens.
Expand Down
34 changes: 34 additions & 0 deletions bionemo-recipes/models/llama3/collator.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,40 @@ def _send_data_to_cp_ranks(self):

return batch_on_this_rank

def state_dict(self):
    """Return the wrapped dataloader's state under a single ``"dataloader"`` key.

    Returns:
        An empty dict when ``self.cp_rank`` is non-zero. Otherwise a dict whose ``"dataloader"``
        entry is ``self.dataloader.state_dict()``, or an empty dict (after a warning is logged)
        when the dataloader has no ``state_dict`` attribute.
    """
    # Only CP rank 0 reports any dataloader state.
    if self.cp_rank != 0:
        return {}

    if not hasattr(self.dataloader, "state_dict"):
        logger.warning(
            "Attempting to get the state dict of the dataloader, but the dataloader does not support state_dict, "
            "returning empty dict"
        )
        return {"dataloader": {}}

    return {"dataloader": self.dataloader.state_dict()}

def load_state_dict(self, state_dict):
    """Restore the wrapped dataloader from ``state_dict["dataloader"]``.

    Args:
        state_dict: Mapping with a ``"dataloader"`` entry. It is only read when ``self.cp_rank``
            is 0 and the dataloader exposes ``load_state_dict``.
    """
    # Non-zero CP ranks have nothing to restore.
    if self.cp_rank != 0:
        return

    if not hasattr(self.dataloader, "load_state_dict"):
        logger.warning(
            "Attempting to load the state dict of the dataloader, but the dataloader does not support "
            "load_state_dict, returning without loading the state dict."
        )
        return

    self.dataloader.load_state_dict(state_dict["dataloader"])

@property
def num_workers(self):
    """Worker count of the wrapped dataloader, or 0 when ``self.cp_rank`` is non-zero."""
    return 0 if self.cp_rank != 0 else self.dataloader.num_workers


def _split_sample_by_num_tokens(sample: dict[str, Any], num_tokens: int) -> tuple[dict[str, Any], dict[str, Any]]:
"""Split a sample dictionary at a specified number of tokens.
Expand Down
34 changes: 34 additions & 0 deletions bionemo-recipes/recipes/esm2_native_te/collator.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,40 @@ def _send_data_to_cp_ranks(self):

return batch_on_this_rank

def state_dict(self):
    """Return the wrapped dataloader's state under a single ``"dataloader"`` key.

    Returns:
        An empty dict when ``self.cp_rank`` is non-zero. Otherwise a dict whose ``"dataloader"``
        entry is ``self.dataloader.state_dict()``, or an empty dict (after a warning is logged)
        when the dataloader has no ``state_dict`` attribute.
    """
    # Only CP rank 0 reports any dataloader state.
    if self.cp_rank != 0:
        return {}

    if not hasattr(self.dataloader, "state_dict"):
        logger.warning(
            "Attempting to get the state dict of the dataloader, but the dataloader does not support state_dict, "
            "returning empty dict"
        )
        return {"dataloader": {}}

    return {"dataloader": self.dataloader.state_dict()}

def load_state_dict(self, state_dict):
    """Restore the wrapped dataloader from ``state_dict["dataloader"]``.

    Args:
        state_dict: Mapping with a ``"dataloader"`` entry. It is only read when ``self.cp_rank``
            is 0 and the dataloader exposes ``load_state_dict``.
    """
    # Non-zero CP ranks have nothing to restore.
    if self.cp_rank != 0:
        return

    if not hasattr(self.dataloader, "load_state_dict"):
        logger.warning(
            "Attempting to load the state dict of the dataloader, but the dataloader does not support "
            "load_state_dict, returning without loading the state dict."
        )
        return

    self.dataloader.load_state_dict(state_dict["dataloader"])

@property
def num_workers(self):
    """Worker count of the wrapped dataloader, or 0 when ``self.cp_rank`` is non-zero."""
    return 0 if self.cp_rank != 0 else self.dataloader.num_workers


def _split_sample_by_num_tokens(sample: dict[str, Any], num_tokens: int) -> tuple[dict[str, Any], dict[str, Any]]:
"""Split a sample dictionary at a specified number of tokens.
Expand Down
34 changes: 34 additions & 0 deletions bionemo-recipes/recipes/llama3_native_te/collator.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,40 @@ def _send_data_to_cp_ranks(self):

return batch_on_this_rank

def state_dict(self):
    """Return the wrapped dataloader's state under a single ``"dataloader"`` key.

    Returns:
        An empty dict when ``self.cp_rank`` is non-zero. Otherwise a dict whose ``"dataloader"``
        entry is ``self.dataloader.state_dict()``, or an empty dict (after a warning is logged)
        when the dataloader has no ``state_dict`` attribute.
    """
    # Only CP rank 0 reports any dataloader state.
    if self.cp_rank != 0:
        return {}

    if not hasattr(self.dataloader, "state_dict"):
        logger.warning(
            "Attempting to get the state dict of the dataloader, but the dataloader does not support state_dict, "
            "returning empty dict"
        )
        return {"dataloader": {}}

    return {"dataloader": self.dataloader.state_dict()}

def load_state_dict(self, state_dict):
    """Restore the wrapped dataloader from ``state_dict["dataloader"]``.

    Args:
        state_dict: Mapping with a ``"dataloader"`` entry. It is only read when ``self.cp_rank``
            is 0 and the dataloader exposes ``load_state_dict``.
    """
    # Non-zero CP ranks have nothing to restore.
    if self.cp_rank != 0:
        return

    if not hasattr(self.dataloader, "load_state_dict"):
        logger.warning(
            "Attempting to load the state dict of the dataloader, but the dataloader does not support "
            "load_state_dict, returning without loading the state dict."
        )
        return

    self.dataloader.load_state_dict(state_dict["dataloader"])

@property
def num_workers(self):
    """Worker count of the wrapped dataloader, or 0 when ``self.cp_rank`` is non-zero."""
    return 0 if self.cp_rank != 0 else self.dataloader.num_workers


def _split_sample_by_num_tokens(sample: dict[str, Any], num_tokens: int) -> tuple[dict[str, Any], dict[str, Any]]:
"""Split a sample dictionary at a specified number of tokens.
Expand Down
52 changes: 9 additions & 43 deletions bionemo-recipes/recipes/llama3_native_te/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,12 @@

import datasets
import datasets.distributed
import torch
from torch.utils.data import DataLoader, DistributedSampler
from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import AutoTokenizer
from transformers.data.data_collator import DataCollatorForLanguageModeling

from collator import (
ContextParallelDataLoaderWrapper,
DataCollatorForContextParallel,
DataCollatorWithFlattening,
TokenPackingDataset,
)
Expand Down Expand Up @@ -102,6 +99,11 @@ def tokenize_with_windowing(examples):
remove_columns=[text_column],
)

# Even in THD mode, we use a base MLM collator that requires a padding token to be set.
if tokenizer.pad_token is None:
logger.warning(f"Tokenizer does not have a padding token. Setting it to the EOS token: {tokenizer.eos_token}")
tokenizer.pad_token = tokenizer.eos_token

return tokenized_dataset, tokenizer


Expand All @@ -120,7 +122,7 @@ def create_bshd_dataloader(
text_column: str = "text",
uppercase_labels: bool = False,
mask_degenerate_bases: bool = False,
pad_to_multiple_of: int | None = None,
pad_sequences_to_be_divisible_by: int | None = None,
):
"""Create a BSHD dataloader for llama3 pre-training.

Expand All @@ -139,7 +141,8 @@ def create_bshd_dataloader(
text_column: Name of the column containing text sequences (default: "text").
uppercase_labels: Whether to uppercase labels (genomic masking). Default: False.
mask_degenerate_bases: Whether to mask non-ACGT bases (genomic masking). Default: False.
pad_to_multiple_of: The number to pad sequences to be divisible by, required for FP8 training. Default: 16.
pad_sequences_to_be_divisible_by: The number to pad sequences to be divisible by, required for FP8 training.
Default: None.

Returns:
A tuple of (dataloader, dataset_or_sampler).
Expand Down Expand Up @@ -169,7 +172,7 @@ def create_bshd_dataloader(
base_collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer,
mlm=False, # Causal language modeling
pad_to_multiple_of=pad_to_multiple_of,
pad_to_multiple_of=pad_sequences_to_be_divisible_by,
)

# Wrap with genomic collator if masking options are enabled
Expand Down Expand Up @@ -300,40 +303,3 @@ def create_thd_dataloader(
)

return train_dataloader, tokenized_dataset


def create_cp_dataloader(
    *args,
    cp_mesh: torch.distributed.device_mesh.DeviceMesh,
    **kwargs,
):
    """Build a THD dataloader on local CP rank 0 and wrap it for context-parallel use.

    All positional and keyword arguments are forwarded to `create_thd_dataloader`, except that
    `pad_sequences_to_be_divisible_by` falls back to `cp_mesh.size() * 2` when it is missing or None.

    Args:
        *args: Positional arguments for `create_thd_dataloader`.
        cp_mesh: The context parallel mesh.
        **kwargs: Keyword arguments for `create_thd_dataloader`.

    Returns:
        A tuple of (dataloader, dataset_or_sampler). The dataloader is always a
        `ContextParallelDataLoaderWrapper`; the second element is None when the local rank in
        `cp_mesh` is not 0.
    """
    pad_key = "pad_sequences_to_be_divisible_by"
    if kwargs.get(pad_key) is None:
        logger.info("pad_sequences_to_be_divisible_by is not provided, using cp_mesh.size() * 2")
        kwargs[pad_key] = cp_mesh.size() * 2

    # Ranks other than local CP rank 0 build no dataloader and pass None to the wrapper.
    train_dataloader, tokenized_dataset = None, None
    if cp_mesh.get_local_rank() == 0:
        train_dataloader, tokenized_dataset = create_thd_dataloader(*args, **kwargs)
        # Re-wrap the existing collate function with the context-parallel collator.
        train_dataloader.collate_fn = DataCollatorForContextParallel(
            collator=train_dataloader.collate_fn,
            cp_world_size=cp_mesh.size(),
        )

    return ContextParallelDataLoaderWrapper(train_dataloader, cp_mesh), tokenized_dataset
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ defaults:

cp_size: 1

use_sequence_packing: false

config_kwargs:
attn_input_format: "thd"
self_attn_mask_type: "padding_causal"
attn_input_format: "bshd" # Alternatively "thd" on datacenter hardware.
self_attn_mask_type: "causal" # Alternatively "padding_causal" for THD inputs.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ dataset:
stride: 200 # Overlap for windowing
buffer_size: 500_000 # Shuffle buffer size
use_stateful_dataloader: false # Until https://github.com/pytorch/pytorch/pull/163102 is resolved with torchdata.
pad_sequences_to_be_divisible_by: null
load_dataset_kwargs:
path: ???
split: "train"
Expand Down
79 changes: 51 additions & 28 deletions bionemo-recipes/recipes/llama3_native_te/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
from hydra import compose, initialize_config_dir
from torch.distributed.device_mesh import init_device_mesh

from dataset import create_bshd_dataloader, create_cp_dataloader, create_thd_dataloader, create_tokenized_dataset
from collator import ContextParallelDataLoaderWrapper, DataCollatorForContextParallel
from dataset import create_bshd_dataloader, create_thd_dataloader, create_tokenized_dataset
from distributed_config import DistributedConfig


Expand Down Expand Up @@ -703,15 +704,28 @@ def test_cp_dataloader(tokenizer_path):
torch.cuda.set_device(dist_config.local_rank)
device_mesh = init_device_mesh("cuda", mesh_shape=(1, 1), mesh_dim_names=("dp", "cp"))

dataloader, _ = create_cp_dataloader(
distributed_config=dist_config,
cp_mesh=device_mesh["cp"],
tokenizer_name_or_path=tokenizer_path,
load_dataset_kwargs=load_dataset_kwargs,
text_column="text",
micro_batch_size=1,
max_seq_length=1024,
)
cp_mesh = device_mesh["cp"]

# Create the context-parallel dataloader directly following the pattern in train_fsdp2_cp.py
if cp_mesh.get_local_rank() == 0:
train_dataloader, _ = create_thd_dataloader(
distributed_config=dist_config,
tokenizer_name_or_path=tokenizer_path,
load_dataset_kwargs=load_dataset_kwargs,
text_column="text",
micro_batch_size=1,
max_seq_length=1024,
pad_sequences_to_be_divisible_by=cp_mesh.size() * 2,
)

train_dataloader.collate_fn = DataCollatorForContextParallel(
collator=train_dataloader.collate_fn,
cp_world_size=cp_mesh.size(),
)
else:
train_dataloader = None

dataloader = ContextParallelDataLoaderWrapper(train_dataloader, cp_mesh)

batches = list(dataloader)
assert len(batches) > 1
Expand Down Expand Up @@ -775,30 +789,39 @@ def test_cp_dataloader_multi_gpu(recipe_path, dataset_path):
parser.add_argument("--dataset_path", type=str, default="dlcm_sanity_dataset.parquet")
args = parser.parse_args()

from torch.distributed.device_mesh import init_device_mesh

from dataset import create_cp_dataloader

dist_config = DistributedConfig()
device = torch.device(f"cuda:{dist_config.local_rank}")
torch.distributed.init_process_group(backend="nccl", device_id=device)
torch.cuda.set_device(dist_config.local_rank)
device_mesh = init_device_mesh("cuda", mesh_shape=(1, 2), mesh_dim_names=("dp", "cp"))

dataloader, _ = create_cp_dataloader(
distributed_config=dist_config,
cp_mesh=device_mesh["cp"],
tokenizer_name_or_path="nvidia/Llama-3.1-8B-Instruct-FP8",
micro_batch_size=1,
text_column="text" if args.dataset_path == "dlcm_sanity_dataset.parquet" else "sequence",
load_dataset_kwargs={
"path": "parquet",
"split": "train",
"data_files": args.dataset_path,
"streaming": True,
},
num_workers=1,
)
cp_mesh = device_mesh["cp"]

# Create the context-parallel dataloader directly following the pattern in train_fsdp2_cp.py
if cp_mesh.get_local_rank() == 0:
train_dataloader, _ = create_thd_dataloader(
distributed_config=dist_config,
tokenizer_name_or_path="nvidia/Llama-3.1-8B-Instruct-FP8",
micro_batch_size=1,
text_column="text" if args.dataset_path == "dlcm_sanity_dataset.parquet" else "sequence",
load_dataset_kwargs={
"path": "parquet",
"split": "train",
"data_files": args.dataset_path,
"streaming": True,
},
num_workers=1,
pad_sequences_to_be_divisible_by=cp_mesh.size() * 2,
)

train_dataloader.collate_fn = DataCollatorForContextParallel(
collator=train_dataloader.collate_fn,
cp_world_size=cp_mesh.size(),
)
else:
train_dataloader = None

dataloader = ContextParallelDataLoaderWrapper(train_dataloader, cp_mesh)

batches = list(itertools.islice(dataloader, 10))

Expand Down
Loading
Loading