diff --git a/bionemo-recipes/models/esm2/src/esm/collator.py b/bionemo-recipes/models/esm2/src/esm/collator.py
index ea614988e..a32e1e3fe 100644
--- a/bionemo-recipes/models/esm2/src/esm/collator.py
+++ b/bionemo-recipes/models/esm2/src/esm/collator.py
@@ -415,6 +415,40 @@ def _send_data_to_cp_ranks(self):
         return batch_on_this_rank
 
+    def state_dict(self):
+        """Get the state dict by delegating to the dataloader."""
+        if self.cp_rank != 0:
+            return {}
+        elif hasattr(self.dataloader, "state_dict"):
+            return {"dataloader": self.dataloader.state_dict()}
+        else:
+            logger.warning(
+                "Attempting to get the state dict of the dataloader, but the dataloader does not support state_dict; "
+                "returning an empty dataloader state."
+            )
+            return {"dataloader": {}}
+
+    def load_state_dict(self, state_dict):
+        """Load the state dict by delegating to the dataloader."""
+        if self.cp_rank != 0:
+            return
+        elif hasattr(self.dataloader, "load_state_dict"):
+            self.dataloader.load_state_dict(state_dict["dataloader"])
+        else:
+            logger.warning(
+                "Attempting to load the state dict of the dataloader, but the dataloader does not support "
+                "load_state_dict; returning without loading the state dict."
+            )
+            return
+
+    @property
+    def num_workers(self):
+        """Get the number of workers of the dataloader."""
+        if self.cp_rank != 0:
+            return 0
+        else:
+            return self.dataloader.num_workers
+
 
 def _split_sample_by_num_tokens(sample: dict[str, Any], num_tokens: int) -> tuple[dict[str, Any], dict[str, Any]]:
     """Split a sample dictionary at a specified number of tokens.
 
diff --git a/bionemo-recipes/models/llama3/collator.py b/bionemo-recipes/models/llama3/collator.py
index ea614988e..a32e1e3fe 100644
--- a/bionemo-recipes/models/llama3/collator.py
+++ b/bionemo-recipes/models/llama3/collator.py
@@ -415,6 +415,40 @@ def _send_data_to_cp_ranks(self):
         return batch_on_this_rank
 
+    def state_dict(self):
+        """Get the state dict by delegating to the dataloader."""
+        if self.cp_rank != 0:
+            return {}
+        elif hasattr(self.dataloader, "state_dict"):
+            return {"dataloader": self.dataloader.state_dict()}
+        else:
+            logger.warning(
+                "Attempting to get the state dict of the dataloader, but the dataloader does not support state_dict; "
+                "returning an empty dataloader state."
+            )
+            return {"dataloader": {}}
+
+    def load_state_dict(self, state_dict):
+        """Load the state dict by delegating to the dataloader."""
+        if self.cp_rank != 0:
+            return
+        elif hasattr(self.dataloader, "load_state_dict"):
+            self.dataloader.load_state_dict(state_dict["dataloader"])
+        else:
+            logger.warning(
+                "Attempting to load the state dict of the dataloader, but the dataloader does not support "
+                "load_state_dict; returning without loading the state dict."
+            )
+            return
+
+    @property
+    def num_workers(self):
+        """Get the number of workers of the dataloader."""
+        if self.cp_rank != 0:
+            return 0
+        else:
+            return self.dataloader.num_workers
+
 
 def _split_sample_by_num_tokens(sample: dict[str, Any], num_tokens: int) -> tuple[dict[str, Any], dict[str, Any]]:
     """Split a sample dictionary at a specified number of tokens.
 
diff --git a/bionemo-recipes/recipes/esm2_native_te/collator.py b/bionemo-recipes/recipes/esm2_native_te/collator.py
index ea614988e..a32e1e3fe 100644
--- a/bionemo-recipes/recipes/esm2_native_te/collator.py
+++ b/bionemo-recipes/recipes/esm2_native_te/collator.py
@@ -415,6 +415,40 @@ def _send_data_to_cp_ranks(self):
         return batch_on_this_rank
 
+    def state_dict(self):
+        """Get the state dict by delegating to the dataloader."""
+        if self.cp_rank != 0:
+            return {}
+        elif hasattr(self.dataloader, "state_dict"):
+            return {"dataloader": self.dataloader.state_dict()}
+        else:
+            logger.warning(
+                "Attempting to get the state dict of the dataloader, but the dataloader does not support state_dict; "
+                "returning an empty dataloader state."
+            )
+            return {"dataloader": {}}
+
+    def load_state_dict(self, state_dict):
+        """Load the state dict by delegating to the dataloader."""
+        if self.cp_rank != 0:
+            return
+        elif hasattr(self.dataloader, "load_state_dict"):
+            self.dataloader.load_state_dict(state_dict["dataloader"])
+        else:
+            logger.warning(
+                "Attempting to load the state dict of the dataloader, but the dataloader does not support "
+                "load_state_dict; returning without loading the state dict."
+            )
+            return
+
+    @property
+    def num_workers(self):
+        """Get the number of workers of the dataloader."""
+        if self.cp_rank != 0:
+            return 0
+        else:
+            return self.dataloader.num_workers
+
 
 def _split_sample_by_num_tokens(sample: dict[str, Any], num_tokens: int) -> tuple[dict[str, Any], dict[str, Any]]:
     """Split a sample dictionary at a specified number of tokens.
 
diff --git a/bionemo-recipes/recipes/llama3_native_te/collator.py b/bionemo-recipes/recipes/llama3_native_te/collator.py
index ea614988e..a32e1e3fe 100644
--- a/bionemo-recipes/recipes/llama3_native_te/collator.py
+++ b/bionemo-recipes/recipes/llama3_native_te/collator.py
@@ -415,6 +415,40 @@ def _send_data_to_cp_ranks(self):
         return batch_on_this_rank
 
+    def state_dict(self):
+        """Get the state dict by delegating to the dataloader."""
+        if self.cp_rank != 0:
+            return {}
+        elif hasattr(self.dataloader, "state_dict"):
+            return {"dataloader": self.dataloader.state_dict()}
+        else:
+            logger.warning(
+                "Attempting to get the state dict of the dataloader, but the dataloader does not support state_dict; "
+                "returning an empty dataloader state."
+            )
+            return {"dataloader": {}}
+
+    def load_state_dict(self, state_dict):
+        """Load the state dict by delegating to the dataloader."""
+        if self.cp_rank != 0:
+            return
+        elif hasattr(self.dataloader, "load_state_dict"):
+            self.dataloader.load_state_dict(state_dict["dataloader"])
+        else:
+            logger.warning(
+                "Attempting to load the state dict of the dataloader, but the dataloader does not support "
+                "load_state_dict; returning without loading the state dict."
+            )
+            return
+
+    @property
+    def num_workers(self):
+        """Get the number of workers of the dataloader."""
+        if self.cp_rank != 0:
+            return 0
+        else:
+            return self.dataloader.num_workers
+
 
 def _split_sample_by_num_tokens(sample: dict[str, Any], num_tokens: int) -> tuple[dict[str, Any], dict[str, Any]]:
     """Split a sample dictionary at a specified number of tokens.
 
diff --git a/bionemo-recipes/recipes/llama3_native_te/dataset.py b/bionemo-recipes/recipes/llama3_native_te/dataset.py
index 94e3c7063..d233a73f2 100644
--- a/bionemo-recipes/recipes/llama3_native_te/dataset.py
+++ b/bionemo-recipes/recipes/llama3_native_te/dataset.py
@@ -17,15 +17,12 @@
 import datasets
 import datasets.distributed
-import torch
 from torch.utils.data import DataLoader, DistributedSampler
 from torchdata.stateful_dataloader import StatefulDataLoader
 from transformers import AutoTokenizer
 from transformers.data.data_collator import DataCollatorForLanguageModeling
 
 from collator import (
-    ContextParallelDataLoaderWrapper,
-    DataCollatorForContextParallel,
     DataCollatorWithFlattening,
     TokenPackingDataset,
 )
 
@@ -102,6 +99,11 @@ def tokenize_with_windowing(examples):
         remove_columns=[text_column],
     )
 
+    # Even in THD mode, we use a base MLM collator that requires a padding token to be set.
+    if tokenizer.pad_token is None:
+        logger.warning(f"Tokenizer does not have a padding token. Setting it to the EOS token: {tokenizer.eos_token}")
+        tokenizer.pad_token = tokenizer.eos_token
+
     return tokenized_dataset, tokenizer
 
 
@@ -120,7 +122,7 @@ def create_bshd_dataloader(
     text_column: str = "text",
     uppercase_labels: bool = False,
     mask_degenerate_bases: bool = False,
-    pad_to_multiple_of: int | None = None,
+    pad_sequences_to_be_divisible_by: int | None = None,
 ):
     """Create a BSHD dataloader for llama3 pre-training.
 
@@ -139,7 +141,8 @@ def create_bshd_dataloader(
         text_column: Name of the column containing text sequences (default: "text").
         uppercase_labels: Whether to uppercase labels (genomic masking). Default: False.
         mask_degenerate_bases: Whether to mask non-ACGT bases (genomic masking). Default: False.
-        pad_to_multiple_of: The number to pad sequences to be divisible by, required for FP8 training. Default: 16.
+        pad_sequences_to_be_divisible_by: Pad sequences so that their lengths are divisible by this number; required
+            for FP8 training. Default: None.
 
     Returns:
         A tuple of (dataloader, dataset_or_sampler).
@@ -169,7 +172,7 @@ def create_bshd_dataloader(
     base_collator = DataCollatorForLanguageModeling(
         tokenizer=tokenizer,
         mlm=False,  # Causal language modeling
-        pad_to_multiple_of=pad_to_multiple_of,
+        pad_to_multiple_of=pad_sequences_to_be_divisible_by,
     )
 
     # Wrap with genomic collator if masking options are enabled
@@ -300,40 +303,3 @@ def create_thd_dataloader(
     )
 
     return train_dataloader, tokenized_dataset
-
-
-def create_cp_dataloader(
-    *args,
-    cp_mesh: torch.distributed.device_mesh.DeviceMesh,
-    **kwargs,
-):
-    """Create a Context-parallel aware dataloader that automatically handles sharding between ranks.
-
-    Wraps the output of `create_thd_dataloader` to make it context parallel aware.
-
-    Args:
-        *args: Arguments to pass to `create_thd_dataloader`.
-        cp_mesh: The context parallel mesh.
-        **kwargs: Keyword arguments to pass to `create_thd_dataloader`.
-
-    Returns:
-        A tuple of (dataloader, dataset_or_sampler).
-    """
-    # Ensure pad_sequences_to_be_divisible_by is passed to create_thd_dataloader
-    if kwargs.get("pad_sequences_to_be_divisible_by", None) is None:
-        logger.info("pad_sequences_to_be_divisible_by is not provided, using cp_mesh.size() * 2")
-        kwargs["pad_sequences_to_be_divisible_by"] = cp_mesh.size() * 2
-
-    if cp_mesh.get_local_rank() == 0:
-        train_dataloader, tokenized_dataset = create_thd_dataloader(*args, **kwargs)
-
-        train_dataloader.collate_fn = DataCollatorForContextParallel(
-            collator=train_dataloader.collate_fn,
-            cp_world_size=cp_mesh.size(),
-        )
-
-    else:
-        train_dataloader = None
-        tokenized_dataset = None
-
-    return ContextParallelDataLoaderWrapper(train_dataloader, cp_mesh), tokenized_dataset
diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L0_sanity_cp.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L0_sanity_cp.yaml
index bfe2bb2ea..669d86364 100644
--- a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L0_sanity_cp.yaml
+++ b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L0_sanity_cp.yaml
@@ -4,6 +4,8 @@ defaults:
 
 cp_size: 1
 
+use_sequence_packing: false
+
 config_kwargs:
-  attn_input_format: "thd"
-  self_attn_mask_type: "padding_causal"
+  attn_input_format: "bshd" # Alternatively "thd" on datacenter hardware.
+  self_attn_mask_type: "causal" # Alternatively "padding_causal" for THD inputs.
diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml
index 070474324..c36a10e82 100644
--- a/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml
+++ b/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml
@@ -22,6 +22,7 @@ dataset:
   stride: 200 # Overlap for windowing
   buffer_size: 500_000 # Shuffle buffer size
   use_stateful_dataloader: false # Until https://github.com/pytorch/pytorch/pull/163102 is resolved with torchdata.
+  pad_sequences_to_be_divisible_by: null
   load_dataset_kwargs:
     path: ???
split: "train" diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_dataset.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_dataset.py index e3d480b5d..84388ce56 100644 --- a/bionemo-recipes/recipes/llama3_native_te/tests/test_dataset.py +++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_dataset.py @@ -24,7 +24,8 @@ from hydra import compose, initialize_config_dir from torch.distributed.device_mesh import init_device_mesh -from dataset import create_bshd_dataloader, create_cp_dataloader, create_thd_dataloader, create_tokenized_dataset +from collator import ContextParallelDataLoaderWrapper, DataCollatorForContextParallel +from dataset import create_bshd_dataloader, create_thd_dataloader, create_tokenized_dataset from distributed_config import DistributedConfig @@ -703,15 +704,28 @@ def test_cp_dataloader(tokenizer_path): torch.cuda.set_device(dist_config.local_rank) device_mesh = init_device_mesh("cuda", mesh_shape=(1, 1), mesh_dim_names=("dp", "cp")) - dataloader, _ = create_cp_dataloader( - distributed_config=dist_config, - cp_mesh=device_mesh["cp"], - tokenizer_name_or_path=tokenizer_path, - load_dataset_kwargs=load_dataset_kwargs, - text_column="text", - micro_batch_size=1, - max_seq_length=1024, - ) + cp_mesh = device_mesh["cp"] + + # Create the context-parallel dataloader directly following the pattern in train_fsdp2_cp.py + if cp_mesh.get_local_rank() == 0: + train_dataloader, _ = create_thd_dataloader( + distributed_config=dist_config, + tokenizer_name_or_path=tokenizer_path, + load_dataset_kwargs=load_dataset_kwargs, + text_column="text", + micro_batch_size=1, + max_seq_length=1024, + pad_sequences_to_be_divisible_by=cp_mesh.size() * 2, + ) + + train_dataloader.collate_fn = DataCollatorForContextParallel( + collator=train_dataloader.collate_fn, + cp_world_size=cp_mesh.size(), + ) + else: + train_dataloader = None + + dataloader = ContextParallelDataLoaderWrapper(train_dataloader, cp_mesh) batches = list(dataloader) assert len(batches) > 1 @@ -775,30 +789,39 @@ def test_cp_dataloader_multi_gpu(recipe_path, dataset_path): parser.add_argument("--dataset_path", type=str, default="dlcm_sanity_dataset.parquet") args = parser.parse_args() - from torch.distributed.device_mesh import init_device_mesh - - from dataset import create_cp_dataloader - dist_config = DistributedConfig() device = torch.device(f"cuda:{dist_config.local_rank}") torch.distributed.init_process_group(backend="nccl", device_id=device) torch.cuda.set_device(dist_config.local_rank) device_mesh = init_device_mesh("cuda", mesh_shape=(1, 2), mesh_dim_names=("dp", "cp")) - dataloader, _ = create_cp_dataloader( - distributed_config=dist_config, - cp_mesh=device_mesh["cp"], - tokenizer_name_or_path="nvidia/Llama-3.1-8B-Instruct-FP8", - micro_batch_size=1, - text_column="text" if args.dataset_path == "dlcm_sanity_dataset.parquet" else "sequence", - load_dataset_kwargs={ - "path": "parquet", - "split": "train", - "data_files": args.dataset_path, - "streaming": True, - }, - num_workers=1, - ) + cp_mesh = device_mesh["cp"] + + # Create the context-parallel dataloader directly following the pattern in train_fsdp2_cp.py + if cp_mesh.get_local_rank() == 0: + train_dataloader, _ = create_thd_dataloader( + distributed_config=dist_config, + tokenizer_name_or_path="nvidia/Llama-3.1-8B-Instruct-FP8", + micro_batch_size=1, + text_column="text" if args.dataset_path == "dlcm_sanity_dataset.parquet" else "sequence", + load_dataset_kwargs={ + "path": "parquet", + "split": "train", + "data_files": args.dataset_path, + 
"streaming": True, + }, + num_workers=1, + pad_sequences_to_be_divisible_by=cp_mesh.size() * 2, + ) + + train_dataloader.collate_fn = DataCollatorForContextParallel( + collator=train_dataloader.collate_fn, + cp_world_size=cp_mesh.size(), + ) + else: + train_dataloader = None + + dataloader = ContextParallelDataLoaderWrapper(train_dataloader, cp_mesh) batches = list(itertools.islice(dataloader, 10)) diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_distributed_checkpointing.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_distributed_checkpointing.py index 03eb633f4..1a48c4134 100644 --- a/bionemo-recipes/recipes/llama3_native_te/tests/test_distributed_checkpointing.py +++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_distributed_checkpointing.py @@ -38,6 +38,7 @@ from train_ddp import main as main_ddp from train_fsdp2 import main as main_fsdp2 +from train_fsdp2_cp import main as main_fsdp2_cp os.environ["WANDB_DISABLED"] = "true" @@ -543,6 +544,228 @@ def test_checkpoint_save_and_load_two_processes_fsdp2(recipe_path, tmp_path): ) +def test_checkpoint_save_and_load_single_process_fsdp2_with_context_parallelism(recipe_path, tmp_path): + """Test checkpoint save/resume functionality for FSDP2 with single process and context parallelism. + + This test validates: + - FSDP2 creates distributed checkpoints (step_X directories by default) + - Each rank saves its shard (even with single process) + - Dataloader state is saved alongside model checkpoint + - Training can resume from latest checkpoint and continue + - Resume starts from correct step count + + Process: + 1. Train 10 steps (0-9), save checkpoint at step 5 + 2. Resume training from step 5, continue to step 15 + 3. Verify checkpoints exist at steps 5 and 10 + """ + temp_dir = str(tmp_path / "test_ckpt_fsdp2_cp") + + # Phase 1: Train for 10 steps (using distributed checkpoint by default) + with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"): + phase1_config = compose( + config_name="L0_sanity_cp", + overrides=[ + f"checkpoint.ckpt_dir={temp_dir}", + f"+wandb.dir={tmp_path}", + "num_train_steps=10", + "checkpoint.save_every_n_steps=5", + "checkpoint.resume_from_checkpoint=false", # Start fresh + "dataset.use_stateful_dataloader=true", # Enable for checkpoint testing + "checkpoint.async_save=false", + ], + ) + + main_fsdp2_cp(phase1_config) + gc.collect() + torch.cuda.empty_cache() + + # Checkpoints are saved in a subdirectory named after the script + ckpt_subdir = os.path.join(temp_dir, "train_fsdp2") + assert os.path.exists(ckpt_subdir), f"Checkpoint subdirectory {ckpt_subdir} not created" + + # Verify checkpoint was created (FSDP2 creates directories by default) + checkpoint_dirs = [ + d for d in os.listdir(ckpt_subdir) if d.startswith("step_") and os.path.isdir(os.path.join(ckpt_subdir, d)) + ] + assert len(checkpoint_dirs) > 0, "No checkpoint directories created in phase 1" + + # Check that checkpoint at step 5 exists + expected_checkpoint = "step_5" + assert expected_checkpoint in checkpoint_dirs, f"Expected {expected_checkpoint} not found" + + # Check dataloader file exists in step_5 directory + step_5_dir = os.path.join(ckpt_subdir, "step_5") + assert os.path.isdir(step_5_dir), f"Step 5 directory not found: {step_5_dir}" + step_5_files = os.listdir(step_5_dir) + + # With single process, we expect dataloader file for rank 0 + dataloader_files_5 = [f for f in step_5_files if "dataloader" in f] + assert len(dataloader_files_5) >= 1, ( + f"Expected at least 1 
dataloader file, found {len(dataloader_files_5)}: {dataloader_files_5}" + ) + assert any("rank_0" in f for f in dataloader_files_5), ( + f"No dataloader file for rank 0 found in step_5. Files: {dataloader_files_5}" + ) + + # Phase 2: Resume training + with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"): + phase2_config = compose( + config_name="L0_sanity_cp", + overrides=[ + f"checkpoint.ckpt_dir={temp_dir}", + f"+wandb.dir={tmp_path}", + "num_train_steps=15", + "checkpoint.save_every_n_steps=5", + "checkpoint.resume_from_checkpoint=true", # Resume from checkpoint + "dataset.use_stateful_dataloader=true", # Enable for checkpoint testing + # Sometimes the checkpoint hasn't finished saving by the time we resume training, so we disable async + # save for this test. + "checkpoint.async_save=false", + ], + ) + + main_fsdp2_cp(phase2_config) + gc.collect() + torch.cuda.empty_cache() + + # Verify phase 2 completed and created additional checkpoints + final_checkpoint_dirs = [ + d for d in os.listdir(ckpt_subdir) if d.startswith("step_") and os.path.isdir(os.path.join(ckpt_subdir, d)) + ] + expected_checkpoints = ["step_5", "step_10"] + for expected in expected_checkpoints: + assert expected in final_checkpoint_dirs, f"Missing checkpoint: {expected}" + + # Check dataloader file exists in step_10 directory + step_10_dir = os.path.join(ckpt_subdir, "step_10") + assert os.path.isdir(step_10_dir), f"Step 10 directory not found: {step_10_dir}" + step_10_files = os.listdir(step_10_dir) + + # With single process, we expect dataloader file for rank 0 + dataloader_files_10 = [f for f in step_10_files if "dataloader" in f] + assert len(dataloader_files_10) >= 1, ( + f"Expected at least 1 dataloader file in step_10, found {len(dataloader_files_10)}: {dataloader_files_10}" + ) + assert any("rank_0" in f for f in dataloader_files_10), ( + f"No dataloader file for rank 0 found in step_10. Files: {dataloader_files_10}" + ) + + +@requires_multi_gpu +def test_checkpoint_save_and_load_two_processes_fsdp2_with_context_parallelism(recipe_path, tmp_path): + """Test checkpoint save/resume functionality for FSDP2 with two processes. + + This test validates: + - Multi-process FSDP2 distributed checkpointing (each rank saves its shard) + - Dataloader state is saved for each rank alongside model checkpoint + - All ranks participate in saving and loading + - Training resumes correctly with proper process synchronization + + Process: + 1. Train 10 steps (0-9) across 2 processes with context parallelism, save checkpoint at step 5 + 2. Resume training with 2 processes from step 5, continue to step 15 + 3. 
Verify checkpoints exist at steps 5 and 10 with dataloader files for both ranks + """ + temp_dir = str(tmp_path / "test_ckpt_fsdp2_cp_2p") + + # Set environment for subprocess + env = os.environ.copy() + env["WANDB_MODE"] = "disabled" + + # Get the full path to train_fsdp2.py + train_script = recipe_path / "train_fsdp2_cp.py" + + # Phase 1: Train for 10 steps with 2 processes + cmd_phase1 = [ + "torchrun", + "--nproc_per_node=2", + str(train_script), + f"checkpoint.ckpt_dir={temp_dir}", + "num_train_steps=10", + "checkpoint.save_every_n_steps=5", + "checkpoint.async_save=false", + "dataset.use_stateful_dataloader=true", # Enable for checkpoint testing + "cp_size=2", + ] + + result1 = subprocess.run(cmd_phase1, check=False, capture_output=True, text=True, env=env) + assert result1.returncode == 0, f"Phase 1 failed: {result1.stderr}" + + # Checkpoints are saved in a subdirectory named after the script + ckpt_subdir = os.path.join(temp_dir, "train_fsdp2") + assert os.path.exists(ckpt_subdir), f"Checkpoint subdirectory {ckpt_subdir} not created" + + # Verify checkpoint was created (FSDP2 creates directories by default) + checkpoint_dirs = [ + d for d in os.listdir(ckpt_subdir) if d.startswith("step_") and os.path.isdir(os.path.join(ckpt_subdir, d)) + ] + assert len(checkpoint_dirs) > 0, "No checkpoint directories created in phase 1" + + # Check that checkpoint at step 5 exists + expected_checkpoint = "step_5" + assert expected_checkpoint in checkpoint_dirs, f"Expected {expected_checkpoint} not found" + + # Check dataloader files exist in step_5 directory for both ranks + step_5_dir = os.path.join(ckpt_subdir, "step_5") + assert os.path.isdir(step_5_dir), f"Step 5 directory not found: {step_5_dir}" + step_5_files = os.listdir(step_5_dir) + + # With 2 processes, we expect dataloader files for rank 0 and rank 1 + dataloader_files_5 = [f for f in step_5_files if "dataloader" in f] + assert len(dataloader_files_5) == 2, ( + f"Expected 2 dataloader files (rank 0 and 1), found {len(dataloader_files_5)}: {dataloader_files_5}" + ) + assert any("rank_0" in f for f in dataloader_files_5), ( + f"No dataloader file for rank 0 found in step_5. Files: {dataloader_files_5}" + ) + assert any("rank_1" in f for f in dataloader_files_5), ( + f"No dataloader file for rank 1 found in step_5. 
Files: {dataloader_files_5}" + ) + + # Phase 2: Resume training with 2 processes + cmd_phase2 = [ + "torchrun", + "--nproc_per_node=2", + str(train_script), + f"checkpoint.ckpt_dir={temp_dir}", + "num_train_steps=15", + "checkpoint.save_every_n_steps=5", + "checkpoint.resume_from_checkpoint=true", # Resume from checkpoint + "checkpoint.async_save=false", + "dataset.use_stateful_dataloader=true", # Enable for checkpoint testing + "cp_size=2", + ] + + result2 = subprocess.run(cmd_phase2, check=False, capture_output=True, text=True, env=env) + assert result2.returncode == 0, f"Phase 2 failed: {result2.stderr}" + + # Verify phase 2 completed and created additional checkpoints + final_checkpoint_dirs = [ + d for d in os.listdir(ckpt_subdir) if d.startswith("step_") and os.path.isdir(os.path.join(ckpt_subdir, d)) + ] + expected_checkpoints = ["step_5", "step_10"] + for expected in expected_checkpoints: + assert expected in final_checkpoint_dirs, f"Missing checkpoint: {expected}" + + # Check dataloader files exist in step_10 directory for both ranks + step_10_dir = os.path.join(ckpt_subdir, "step_10") + assert os.path.isdir(step_10_dir), f"Step 10 directory not found: {step_10_dir}" + step_10_files = os.listdir(step_10_dir) + + # With 2 processes, we expect dataloader files for rank 0 and rank 1 + dataloader_files_10 = [f for f in step_10_files if "dataloader" in f] + assert len(dataloader_files_10) == 2, ( + f"Expected 2 dataloader files (rank 0 and 1) in step_10, found {len(dataloader_files_10)}: {dataloader_files_10}" + ) + assert any("rank_0" in f for f in dataloader_files_10), ( + f"No dataloader file for rank 0 found in step_10. Files: {dataloader_files_10}" + ) + assert any("rank_1" in f for f in dataloader_files_10), ( + f"No dataloader file for rank 1 found in step_10. Files: {dataloader_files_10}" + ) + + def test_scheduler_resume_single_gpu(recipe_path, tmp_path): """Test that learning rate scheduler resumes from correct state after checkpoint load. 
diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py
index bc16d5f5a..a1573d8a1 100644
--- a/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py
+++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py
@@ -378,7 +378,7 @@ def test_train_fsdp2_fp8_bshd(tmp_path, recipe_path):
                 f"+wandb.dir={tmp_path}",
                 f"checkpoint.ckpt_dir={tmp_path}",
                 "fp8_config.enabled=true",
-                "+dataset.pad_to_multiple_of=16",
+                "+dataset.pad_sequences_to_be_divisible_by=16",
             ],
         )
 
diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_train_two_gpu.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_train_two_gpu.py
index d4d792567..f87915af0 100644
--- a/bionemo-recipes/recipes/llama3_native_te/tests/test_train_two_gpu.py
+++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_train_two_gpu.py
@@ -191,9 +191,31 @@ def test_multi_gpu_train_fsdp2_with_checkpointing(tmp_path, recipe_path):
     assert (ckpt_dir / "step_5").exists(), "Checkpoint at step 5 not found"
 
 
+@requires_multi_gpu
+def test_multi_gpu_train_te_fsdp2_cp_bshd(tmp_path, recipe_path):
+    run_train_cmd(
+        [
+            "torchrun",
+            "--nproc_per_node=2",
+            "--standalone",
+            "train_fsdp2_cp.py",
+            "--config-name",
+            "L0_sanity_cp",
+            "num_train_steps=10",
+            f"checkpoint.ckpt_dir={tmp_path}",
+            "checkpoint.save_every_n_steps=5",
+            "cp_size=2",
+            "use_sequence_packing=false",
+            "config_kwargs.attn_input_format=bshd",
+            "config_kwargs.self_attn_mask_type=causal",
+        ],
+        recipe_path,
+    )
+
+
 @requires_multi_gpu
 @requires_datacenter_hardware
-def test_multi_gpu_train_te_fsdp2_cp(tmp_path, recipe_path):
+def test_multi_gpu_train_te_fsdp2_cp_thd(tmp_path, recipe_path):
     run_train_cmd(
         [
             "torchrun",
@@ -206,6 +228,9 @@
             f"checkpoint.ckpt_dir={tmp_path}",
             "checkpoint.save_every_n_steps=5",
             "cp_size=2",
+            "use_sequence_packing=true",
+            "config_kwargs.attn_input_format=thd",
+            "config_kwargs.self_attn_mask_type=padding_causal",
         ],
         recipe_path,
     )
diff --git a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py
index 8e4254161..0870b47f7 100644
--- a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py
+++ b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py
@@ -28,7 +28,8 @@
 from transformer_engine.common.recipe import Format
 
 from checkpoint import load_checkpoint_fsdp2, save_checkpoint_fsdp2, save_final_model_fsdp2, should_save_checkpoint
-from dataset import create_cp_dataloader
+from collator import ContextParallelDataLoaderWrapper, DataCollatorForContextParallel
+from dataset import create_bshd_dataloader, create_thd_dataloader
 from distributed_config import DistributedConfig
 from modeling_llama_te import NVLlamaConfig, NVLlamaForCausalLM
 from perf_logger import PerfLogger
@@ -81,7 +82,7 @@ def main(args: DictConfig) -> float | None:
     logger.info("Initialized Model:\n%s", model)
 
     # Create a flattened mesh for FSDP2 sharding. This will shard the model across both the DP and CP ranks.
-    cp_dp_mesh = device_mesh["dp", "cp"]._flatten(mesh_dim_name="dp_shard_cp") if args.cp_size > 1 else device_mesh
+    cp_dp_mesh = device_mesh["dp", "cp"]._flatten(mesh_dim_name="dp_shard_cp")
 
     # Shard the transformer layers with FSDP. For Llama3, the transformer stack is in model.model.layers.
     # Each decoder layer should be individually sharded before sharding the full model.
@@ -109,11 +110,26 @@
         # If we're using torch.compile, we need to do this before loading the checkpoint to ensure key consistency.
         model = torch.compile(model)
 
-    train_dataloader, dataset_or_sampler = create_cp_dataloader(
-        dist_config,
-        cp_mesh=device_mesh["cp"],
-        **args.dataset,
-    )
+    # Create the context-parallel-aware dataloader. We only create the dataloader on CP rank 0 and wrap it in a
+    # ContextParallelDataLoaderWrapper that shards and distributes the data across the context parallelism group.
+    args.dataset.setdefault("pad_sequences_to_be_divisible_by", device_mesh["cp"].size() * 2)
+    if device_mesh["cp"].get_local_rank() == 0:
+        if args.use_sequence_packing:
+            train_dataloader, dataset_or_sampler = create_thd_dataloader(dist_config, **args.dataset)
+        else:
+            train_dataloader, dataset_or_sampler = create_bshd_dataloader(dist_config, **args.dataset)
+
+        train_dataloader.collate_fn = DataCollatorForContextParallel(
+            collator=train_dataloader.collate_fn,
+            cp_world_size=device_mesh["cp"].size(),
+            qkv_format=args.config_kwargs.attn_input_format,
+        )
+
+    else:
+        train_dataloader = None
+        dataset_or_sampler = None
+
+    train_dataloader = ContextParallelDataLoaderWrapper(train_dataloader, device_mesh["cp"])
 
     # If we're resuming from a checkpoint, load it and set the start step. Otherwise, start from step 0.
     ckpt_path = Path(args.checkpoint.ckpt_dir) / "train_fsdp2" if args.checkpoint.ckpt_dir else None
@@ -126,7 +142,7 @@
             ckpt_path=ckpt_path,
            dist_config=dist_config,
            dataloader=train_dataloader,
-            process_group=device_mesh.get_group("dp"),
+            process_group=cp_dp_mesh.get_group(),
         )
         logger.info(f"Checkpoint loaded, resuming from step {start_step}, epoch {epoch}")
     else:
@@ -188,7 +204,7 @@
                 epoch=epoch,
                 dist_config=dist_config,
                 dataloader=train_dataloader if args.dataset.use_stateful_dataloader else None,
-                process_group=device_mesh.get_group("dp"),
+                process_group=cp_dp_mesh.get_group(),
                 max_checkpoints=args.checkpoint.max_checkpoints,
                 async_save=args.checkpoint.async_save,
             )