diff --git a/bionemo-recipes/models/esm2/src/esm/collator.py b/bionemo-recipes/models/esm2/src/esm/collator.py index ea614988e..48f0e6712 100644 --- a/bionemo-recipes/models/esm2/src/esm/collator.py +++ b/bionemo-recipes/models/esm2/src/esm/collator.py @@ -276,6 +276,73 @@ def set_epoch(self, epoch: int): @dataclass +class SequencePackingIterableDataset(torch.utils.data.IterableDataset): + """Dataset that packs sequences by concatenating across boundaries for BSHD format. + + This dataset creates fixed-length samples by arbitrarily splitting across sequence boundaries. + + Example: + Input sequences: AAA, BBBBB, CCC, DDDDDDD + Output samples: AAAB, BBBB, CCCD, DDDD, DD... + + Unlike TokenPackingDataset (which yields variable-count lists for THD format with cu_seqlens), + this yields fixed-length samples for BSHD format with pure causal masking. + + Args: + dataset: The input IterableDataset (can be windowed or not). + max_seq_length: Fixed length for each output sample. + pad_token_id: Token ID for padding (only used if drop_last=False). + drop_last: Whether to drop the last incomplete sample. Default: True. + """ + + dataset: datasets.IterableDataset + max_seq_length: int + pad_token_id: int + drop_last: bool = True + + def __iter__(self): + """Yield fixed-length samples by streaming tokens across sequence boundaries.""" + buffer_input_ids = [] + buffer_labels = [] + has_labels = None + + for sample in iter(self.dataset): + # Determine if dataset has labels on first sample + if has_labels is None: + has_labels = "labels" in sample + + # Add tokens to buffer + buffer_input_ids.extend(sample["input_ids"]) + if has_labels: + buffer_labels.extend(sample.get("labels", sample["input_ids"])) + + # Yield full chunks of max_seq_length + while len(buffer_input_ids) >= self.max_seq_length: + output = {"input_ids": buffer_input_ids[: self.max_seq_length]} + if has_labels: + output["labels"] = buffer_labels[: self.max_seq_length] + + yield output + + # Keep remainder in buffer + buffer_input_ids = buffer_input_ids[self.max_seq_length :] + if has_labels: + buffer_labels = buffer_labels[self.max_seq_length :] + + # Handle remaining tokens (only if not dropping last) + if buffer_input_ids and not self.drop_last: + padding_length = self.max_seq_length - len(buffer_input_ids) + output = {"input_ids": buffer_input_ids + [self.pad_token_id] * padding_length} + if has_labels: + output["labels"] = buffer_labels + [-100] * padding_length + yield output + + def set_epoch(self, epoch: int): + """Set the epoch for the underlying dataset.""" + if hasattr(self.dataset, "set_epoch"): + self.dataset.set_epoch(epoch) + + class DataCollatorForContextParallel: """A collator that is aware of context parallelism. diff --git a/bionemo-recipes/models/llama3/collator.py b/bionemo-recipes/models/llama3/collator.py index ea614988e..48f0e6712 100644 --- a/bionemo-recipes/models/llama3/collator.py +++ b/bionemo-recipes/models/llama3/collator.py @@ -276,6 +276,73 @@ def set_epoch(self, epoch: int): @dataclass +class SequencePackingIterableDataset(torch.utils.data.IterableDataset): + """Dataset that packs sequences by concatenating across boundaries for BSHD format. + + This dataset creates fixed-length samples by arbitrarily splitting across sequence boundaries. + + Example: + Input sequences: AAA, BBBBB, CCC, DDDDDDD + Output samples: AAAB, BBBB, CCCD, DDDD, DD... 
+ + Unlike TokenPackingDataset (which yields variable-count lists for THD format with cu_seqlens), + this yields fixed-length samples for BSHD format with pure causal masking. + + Args: + dataset: The input IterableDataset (can be windowed or not). + max_seq_length: Fixed length for each output sample. + pad_token_id: Token ID for padding (only used if drop_last=False). + drop_last: Whether to drop the last incomplete sample. Default: True. + """ + + dataset: datasets.IterableDataset + max_seq_length: int + pad_token_id: int + drop_last: bool = True + + def __iter__(self): + """Yield fixed-length samples by streaming tokens across sequence boundaries.""" + buffer_input_ids = [] + buffer_labels = [] + has_labels = None + + for sample in iter(self.dataset): + # Determine if dataset has labels on first sample + if has_labels is None: + has_labels = "labels" in sample + + # Add tokens to buffer + buffer_input_ids.extend(sample["input_ids"]) + if has_labels: + buffer_labels.extend(sample.get("labels", sample["input_ids"])) + + # Yield full chunks of max_seq_length + while len(buffer_input_ids) >= self.max_seq_length: + output = {"input_ids": buffer_input_ids[: self.max_seq_length]} + if has_labels: + output["labels"] = buffer_labels[: self.max_seq_length] + + yield output + + # Keep remainder in buffer + buffer_input_ids = buffer_input_ids[self.max_seq_length :] + if has_labels: + buffer_labels = buffer_labels[self.max_seq_length :] + + # Handle remaining tokens (only if not dropping last) + if buffer_input_ids and not self.drop_last: + padding_length = self.max_seq_length - len(buffer_input_ids) + output = {"input_ids": buffer_input_ids + [self.pad_token_id] * padding_length} + if has_labels: + output["labels"] = buffer_labels + [-100] * padding_length + yield output + + def set_epoch(self, epoch: int): + """Set the epoch for the underlying dataset.""" + if hasattr(self.dataset, "set_epoch"): + self.dataset.set_epoch(epoch) + + class DataCollatorForContextParallel: """A collator that is aware of context parallelism. diff --git a/bionemo-recipes/models/llama3/modeling_llama_te.py b/bionemo-recipes/models/llama3/modeling_llama_te.py index 358f023e9..73b712e06 100644 --- a/bionemo-recipes/models/llama3/modeling_llama_te.py +++ b/bionemo-recipes/models/llama3/modeling_llama_te.py @@ -14,6 +14,7 @@ # limitations under the License. 
from collections import OrderedDict +from contextlib import nullcontext from typing import Unpack import torch @@ -44,6 +45,8 @@ class NVLlamaConfig(LlamaConfig): attn_input_format: str = "thd" self_attn_mask_type: str = "padding_causal" + fp8_first_last_bf16: bool = False + """When True, keeps first and last transformer layers in bf16 for FP8 numerical stability.""" class NVLlamaPreTrainedModel(PreTrainedModel): @@ -221,24 +224,33 @@ def forward( with torch.autocast(device_type="cuda", enabled=False): te_rope_emb = self.rotary_emb(max_seq_len=self.config.max_position_embeddings) - for decoder_layer in self.layers[: self.config.num_hidden_layers]: + num_layers = self.config.num_hidden_layers + for layer_idx, decoder_layer in enumerate(self.layers[:num_layers]): if output_hidden_states: all_hidden_states = (*all_hidden_states, hidden_states) - hidden_states = decoder_layer( - hidden_states, - attention_mask=None if self.config.attn_input_format == "thd" else attention_mask, - rotary_pos_emb=te_rope_emb, - inference_params=past_key_values, - cu_seqlens_q=kwargs.get("cu_seq_lens_q", None), - cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None), - cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None), - cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None), - max_seqlen_q=kwargs.get("max_length_q", None), - max_seqlen_kv=kwargs.get("max_length_k", None), - pad_between_seqs=kwargs.get("pad_between_seqs", None), + # Optionally keep first and last layers in bf16 for FP8 numerical stability + use_bf16_for_layer = getattr(self.config, "fp8_first_last_bf16", False) and ( + layer_idx == 0 or layer_idx == num_layers - 1 ) + # If fp8_first_last_bf16 is enabled, disable FP8 for first/last layers + # This nested fp8_autocast will override the outer one from training script + with transformer_engine.pytorch.fp8_autocast(enabled=False) if use_bf16_for_layer else nullcontext(): + hidden_states = decoder_layer( + hidden_states, + attention_mask=None if self.config.attn_input_format == "thd" else attention_mask, + rotary_pos_emb=te_rope_emb, + inference_params=past_key_values, + cu_seqlens_q=kwargs.get("cu_seq_lens_q", None), + cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None), + cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None), + cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None), + max_seqlen_q=kwargs.get("max_length_q", None), + max_seqlen_kv=kwargs.get("max_length_k", None), + pad_between_seqs=kwargs.get("pad_between_seqs", None), + ) + hidden_states = self.norm(hidden_states) # add hidden states from the last decoder layer. Note that these will be in THD format; we could possibly pad diff --git a/bionemo-recipes/recipes/esm2_native_te/collator.py b/bionemo-recipes/recipes/esm2_native_te/collator.py index ea614988e..48f0e6712 100644 --- a/bionemo-recipes/recipes/esm2_native_te/collator.py +++ b/bionemo-recipes/recipes/esm2_native_te/collator.py @@ -276,6 +276,73 @@ def set_epoch(self, epoch: int): @dataclass +class SequencePackingIterableDataset(torch.utils.data.IterableDataset): + """Dataset that packs sequences by concatenating across boundaries for BSHD format. + + This dataset creates fixed-length samples by arbitrarily splitting across sequence boundaries. + + Example: + Input sequences: AAA, BBBBB, CCC, DDDDDDD + Output samples: AAAB, BBBB, CCCD, DDDD, DD... + + Unlike TokenPackingDataset (which yields variable-count lists for THD format with cu_seqlens), + this yields fixed-length samples for BSHD format with pure causal masking. 
+ + Args: + dataset: The input IterableDataset (can be windowed or not). + max_seq_length: Fixed length for each output sample. + pad_token_id: Token ID for padding (only used if drop_last=False). + drop_last: Whether to drop the last incomplete sample. Default: True. + """ + + dataset: datasets.IterableDataset + max_seq_length: int + pad_token_id: int + drop_last: bool = True + + def __iter__(self): + """Yield fixed-length samples by streaming tokens across sequence boundaries.""" + buffer_input_ids = [] + buffer_labels = [] + has_labels = None + + for sample in iter(self.dataset): + # Determine if dataset has labels on first sample + if has_labels is None: + has_labels = "labels" in sample + + # Add tokens to buffer + buffer_input_ids.extend(sample["input_ids"]) + if has_labels: + buffer_labels.extend(sample.get("labels", sample["input_ids"])) + + # Yield full chunks of max_seq_length + while len(buffer_input_ids) >= self.max_seq_length: + output = {"input_ids": buffer_input_ids[: self.max_seq_length]} + if has_labels: + output["labels"] = buffer_labels[: self.max_seq_length] + + yield output + + # Keep remainder in buffer + buffer_input_ids = buffer_input_ids[self.max_seq_length :] + if has_labels: + buffer_labels = buffer_labels[self.max_seq_length :] + + # Handle remaining tokens (only if not dropping last) + if buffer_input_ids and not self.drop_last: + padding_length = self.max_seq_length - len(buffer_input_ids) + output = {"input_ids": buffer_input_ids + [self.pad_token_id] * padding_length} + if has_labels: + output["labels"] = buffer_labels + [-100] * padding_length + yield output + + def set_epoch(self, epoch: int): + """Set the epoch for the underlying dataset.""" + if hasattr(self.dataset, "set_epoch"): + self.dataset.set_epoch(epoch) + + class DataCollatorForContextParallel: """A collator that is aware of context parallelism. diff --git a/bionemo-recipes/recipes/llama3_native_te/collator.py b/bionemo-recipes/recipes/llama3_native_te/collator.py index ea614988e..48f0e6712 100644 --- a/bionemo-recipes/recipes/llama3_native_te/collator.py +++ b/bionemo-recipes/recipes/llama3_native_te/collator.py @@ -276,6 +276,73 @@ def set_epoch(self, epoch: int): @dataclass +class SequencePackingIterableDataset(torch.utils.data.IterableDataset): + """Dataset that packs sequences by concatenating across boundaries for BSHD format. + + This dataset creates fixed-length samples by arbitrarily splitting across sequence boundaries. + + Example: + Input sequences: AAA, BBBBB, CCC, DDDDDDD + Output samples: AAAB, BBBB, CCCD, DDDD, DD... + + Unlike TokenPackingDataset (which yields variable-count lists for THD format with cu_seqlens), + this yields fixed-length samples for BSHD format with pure causal masking. + + Args: + dataset: The input IterableDataset (can be windowed or not). + max_seq_length: Fixed length for each output sample. + pad_token_id: Token ID for padding (only used if drop_last=False). + drop_last: Whether to drop the last incomplete sample. Default: True. 
+ """ + + dataset: datasets.IterableDataset + max_seq_length: int + pad_token_id: int + drop_last: bool = True + + def __iter__(self): + """Yield fixed-length samples by streaming tokens across sequence boundaries.""" + buffer_input_ids = [] + buffer_labels = [] + has_labels = None + + for sample in iter(self.dataset): + # Determine if dataset has labels on first sample + if has_labels is None: + has_labels = "labels" in sample + + # Add tokens to buffer + buffer_input_ids.extend(sample["input_ids"]) + if has_labels: + buffer_labels.extend(sample.get("labels", sample["input_ids"])) + + # Yield full chunks of max_seq_length + while len(buffer_input_ids) >= self.max_seq_length: + output = {"input_ids": buffer_input_ids[: self.max_seq_length]} + if has_labels: + output["labels"] = buffer_labels[: self.max_seq_length] + + yield output + + # Keep remainder in buffer + buffer_input_ids = buffer_input_ids[self.max_seq_length :] + if has_labels: + buffer_labels = buffer_labels[self.max_seq_length :] + + # Handle remaining tokens (only if not dropping last) + if buffer_input_ids and not self.drop_last: + padding_length = self.max_seq_length - len(buffer_input_ids) + output = {"input_ids": buffer_input_ids + [self.pad_token_id] * padding_length} + if has_labels: + output["labels"] = buffer_labels + [-100] * padding_length + yield output + + def set_epoch(self, epoch: int): + """Set the epoch for the underlying dataset.""" + if hasattr(self.dataset, "set_epoch"): + self.dataset.set_epoch(epoch) + + class DataCollatorForContextParallel: """A collator that is aware of context parallelism. diff --git a/bionemo-recipes/recipes/llama3_native_te/dataset.py b/bionemo-recipes/recipes/llama3_native_te/dataset.py index 94e3c7063..cde8547b9 100644 --- a/bionemo-recipes/recipes/llama3_native_te/dataset.py +++ b/bionemo-recipes/recipes/llama3_native_te/dataset.py @@ -27,6 +27,7 @@ ContextParallelDataLoaderWrapper, DataCollatorForContextParallel, DataCollatorWithFlattening, + SequencePackingIterableDataset, TokenPackingDataset, ) from distributed_config import DistributedConfig @@ -139,7 +140,7 @@ def create_bshd_dataloader( text_column: Name of the column containing text sequences (default: "text"). uppercase_labels: Whether to uppercase labels (genomic masking). Default: False. mask_degenerate_bases: Whether to mask non-ACGT bases (genomic masking). Default: False. - pad_to_multiple_of: The number to pad sequences to be divisible by, required for FP8 training. Default: 16. + pad_to_multiple_of: The number to pad sequences to be divisible by, required for FP8 training. Default: None. Returns: A tuple of (dataloader, dataset_or_sampler). @@ -337,3 +338,123 @@ def create_cp_dataloader( tokenized_dataset = None return ContextParallelDataLoaderWrapper(train_dataloader, cp_mesh), tokenized_dataset + + +def create_bshd_packed_dataloader( + distributed_config: DistributedConfig, + tokenizer_name_or_path: str, + load_dataset_kwargs: dict, + micro_batch_size: int, + max_seq_length: int = 8192, + stride: int = 200, + num_workers: int = 1, + prefetch_factor: int = 4, + seed: int = 42, + buffer_size: int = 500_000, + text_column: str = "text", + use_stateful_dataloader: bool = False, + uppercase_labels: bool = False, + mask_degenerate_bases: bool = False, + pad_to_multiple_of: int | None = None, +): + """Create a BSHD dataloader with full sequence packing. + + This creates fixed-length samples by concatenating sequences and arbitrarily splitting + across sequence boundaries. 
Unlike THD packing, this does not track sequence boundaries + with cu_seqlens, allowing attention to flow across packed sequences. + + Key features: + - Uses windowing (via create_tokenized_dataset) to handle long sequences + - Packs windows into fixed-length chunks, crossing boundaries + - No cu_seqlens (no boundary tracking) + - No attention masks (pure causal masking) + - drop_last=True (no padding) + + Args: + distributed_config: The distributed configuration. + tokenizer_name_or_path: Name or path to the tokenizer. + load_dataset_kwargs: Keyword arguments to pass to `load_dataset`. + micro_batch_size: The batch size per device. + max_seq_length: The fixed length for each sample. + stride: The stride for windowing (used by create_tokenized_dataset). + num_workers: The number of workers to use for the dataloader. + prefetch_factor: The prefetch factor to use for the dataloader. + seed: The seed for shuffling. + buffer_size: The buffer size for shuffle. + text_column: Name of the column containing text sequences. + use_stateful_dataloader: Whether to use StatefulDataLoader for checkpointing. + uppercase_labels: Whether to uppercase labels (genomic masking). Default: False. + mask_degenerate_bases: Whether to mask non-ACGT bases (genomic masking). Default: False. + pad_to_multiple_of: If set, pads sequences to ensure total tokens is divisible by this number. + Required for FP8 (should be 16). Default: None. + + Returns: + A tuple of (dataloader, dataset). + """ + # Use existing tokenization with windowing + tokenized_dataset, tokenizer = create_tokenized_dataset( + distributed_config=distributed_config, + tokenizer_name_or_path=tokenizer_name_or_path, + load_dataset_kwargs=load_dataset_kwargs, + max_seq_length=max_seq_length, + stride=stride, + buffer_size=buffer_size, + text_column=text_column, + ) + + # Set pad_token if not present (required for BSHD format with padding) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + logger.info(f"Set tokenizer.pad_token to eos_token: {tokenizer.eos_token}") + + # Wrap with packing dataset - drop_last=True to avoid padding + packed_dataset = SequencePackingIterableDataset( + dataset=tokenized_dataset, + max_seq_length=max_seq_length, + pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id, + drop_last=True, # Drop last incomplete sample (no padding) + ) + + # Create base collator + base_collator = DataCollatorForLanguageModeling( + tokenizer=tokenizer, + mlm=False, # Causal language modeling + pad_to_multiple_of=pad_to_multiple_of, # For FP8 compatibility (must be divisible by 16) + ) + + # Wrap with genomic collator if masking options are enabled + if uppercase_labels or mask_degenerate_bases: + data_collator = GenomicDataCollator( + base_collator=base_collator, + uppercase_labels=uppercase_labels, + mask_degenerate_bases=mask_degenerate_bases, + ) + logger.info( + f"Using GenomicDataCollator (uppercase={uppercase_labels}, mask_degenerate={mask_degenerate_bases})" + ) + else: + # Use base collator directly for backward compatibility + data_collator = base_collator + logger.info("Using standard DataCollatorForLanguageModeling") + + if pad_to_multiple_of is not None: + logger.info(f"Padding to multiple of {pad_to_multiple_of} for FP8 compatibility") + + # Create dataloader + dataloader_class = StatefulDataLoader if use_stateful_dataloader else DataLoader + train_dataloader = dataloader_class( + packed_dataset, + batch_size=micro_batch_size, + collate_fn=data_collator, + 
num_workers=num_workers, + pin_memory=True if not use_stateful_dataloader else False, + persistent_workers=num_workers > 0, + prefetch_factor=prefetch_factor if num_workers > 0 else None, + ) + + logger.info( + f"Created BSHD packed dataloader: max_seq_length={max_seq_length}, " + f"stride={stride}, micro_batch_size={micro_batch_size}, drop_last=True (no padding)" + ) + + return train_dataloader, packed_dataset diff --git a/bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py b/bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py index 358f023e9..73b712e06 100644 --- a/bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py +++ b/bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py @@ -14,6 +14,7 @@ # limitations under the License. from collections import OrderedDict +from contextlib import nullcontext from typing import Unpack import torch @@ -44,6 +45,8 @@ class NVLlamaConfig(LlamaConfig): attn_input_format: str = "thd" self_attn_mask_type: str = "padding_causal" + fp8_first_last_bf16: bool = False + """When True, keeps first and last transformer layers in bf16 for FP8 numerical stability.""" class NVLlamaPreTrainedModel(PreTrainedModel): @@ -221,24 +224,33 @@ def forward( with torch.autocast(device_type="cuda", enabled=False): te_rope_emb = self.rotary_emb(max_seq_len=self.config.max_position_embeddings) - for decoder_layer in self.layers[: self.config.num_hidden_layers]: + num_layers = self.config.num_hidden_layers + for layer_idx, decoder_layer in enumerate(self.layers[:num_layers]): if output_hidden_states: all_hidden_states = (*all_hidden_states, hidden_states) - hidden_states = decoder_layer( - hidden_states, - attention_mask=None if self.config.attn_input_format == "thd" else attention_mask, - rotary_pos_emb=te_rope_emb, - inference_params=past_key_values, - cu_seqlens_q=kwargs.get("cu_seq_lens_q", None), - cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None), - cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None), - cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None), - max_seqlen_q=kwargs.get("max_length_q", None), - max_seqlen_kv=kwargs.get("max_length_k", None), - pad_between_seqs=kwargs.get("pad_between_seqs", None), + # Optionally keep first and last layers in bf16 for FP8 numerical stability + use_bf16_for_layer = getattr(self.config, "fp8_first_last_bf16", False) and ( + layer_idx == 0 or layer_idx == num_layers - 1 ) + # If fp8_first_last_bf16 is enabled, disable FP8 for first/last layers + # This nested fp8_autocast will override the outer one from training script + with transformer_engine.pytorch.fp8_autocast(enabled=False) if use_bf16_for_layer else nullcontext(): + hidden_states = decoder_layer( + hidden_states, + attention_mask=None if self.config.attn_input_format == "thd" else attention_mask, + rotary_pos_emb=te_rope_emb, + inference_params=past_key_values, + cu_seqlens_q=kwargs.get("cu_seq_lens_q", None), + cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None), + cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None), + cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None), + max_seqlen_q=kwargs.get("max_length_q", None), + max_seqlen_kv=kwargs.get("max_length_k", None), + pad_between_seqs=kwargs.get("pad_between_seqs", None), + ) + hidden_states = self.norm(hidden_states) # add hidden states from the last decoder layer. 
Note that these will be in THD format; we could possibly pad diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py index bc16d5f5a..81c32eaa6 100644 --- a/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py +++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py @@ -410,6 +410,49 @@ def test_train_fsdp2_fp8_thd(tmp_path, recipe_path): assert final_loss < 8.0, f"Final loss {final_loss} is too high, expected < 8.0" +def test_train_fsdp2_fp8_first_last_bf16(tmp_path, recipe_path): + """Test that FSDP2 training works with FP8 and first/last layers in bf16.""" + with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"): + sanity_config = compose( + config_name="L0_sanity", + overrides=[ + f"+wandb.dir={tmp_path}", + f"checkpoint.ckpt_dir={tmp_path}", + "fp8_config.enabled=true", + "+dataset.pad_to_multiple_of=16", + "+config_kwargs.fp8_first_last_bf16=true", + ], + ) + + final_loss = main_fsdp2(sanity_config) + gc.collect() + torch.cuda.empty_cache() + + assert final_loss < 8.0, f"Final loss {final_loss} is too high, expected < 8.0" + + +def test_train_fsdp2_fp8_bshd_packed(tmp_path, recipe_path): + """Test that FSDP2 training works with FP8 enabled and BSHD packed dataloader.""" + with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"): + sanity_config = compose( + config_name="L0_sanity", + overrides=[ + f"+wandb.dir={tmp_path}", + f"checkpoint.ckpt_dir={tmp_path}", + "fp8_config.enabled=true", + "use_sequence_packing=true", + "config_kwargs.attn_input_format=bshd", + "+dataset.pad_to_multiple_of=16", + ], + ) + + final_loss = main_fsdp2(sanity_config) + gc.collect() + torch.cuda.empty_cache() + + assert final_loss < 8.0, f"Final loss {final_loss} is too high, expected < 8.0" + + @requires_datacenter_hardware def test_sanity_fsdp2_cp(tmp_path, recipe_path): # Run the training script with Hydra configuration overrides diff --git a/bionemo-recipes/recipes/llama3_native_te/train_ddp.py b/bionemo-recipes/recipes/llama3_native_te/train_ddp.py index 4ff237881..f7951afd5 100644 --- a/bionemo-recipes/recipes/llama3_native_te/train_ddp.py +++ b/bionemo-recipes/recipes/llama3_native_te/train_ddp.py @@ -28,7 +28,7 @@ from transformers.models.llama.modeling_llama import LlamaForCausalLM from checkpoint import load_checkpoint_ddp, save_checkpoint_ddp, save_final_model_ddp, should_save_checkpoint -from dataset import create_bshd_dataloader, create_thd_dataloader +from dataset import create_bshd_dataloader, create_bshd_packed_dataloader, create_thd_dataloader from distributed_config import DistributedConfig from modeling_llama_te import NVLlamaConfig, NVLlamaForCausalLM from perf_logger import PerfLogger @@ -93,8 +93,14 @@ def main(args: DictConfig) -> float | None: ) if args.use_sequence_packing: - train_dataloader, dataset_or_sampler = create_thd_dataloader(dist_config, **args.dataset) + if args.config_kwargs.attn_input_format == "bshd": + # BSHD with full packing (cross-boundary attention, no cu_seqlens) + train_dataloader, dataset_or_sampler = create_bshd_packed_dataloader(dist_config, **args.dataset) + else: + # THD with packing (respects boundaries via cu_seqlens) + train_dataloader, dataset_or_sampler = create_thd_dataloader(dist_config, **args.dataset) else: + # Standard BSHD with windowing (no packing) train_dataloader, dataset_or_sampler = create_bshd_dataloader(dist_config, **args.dataset) if args.use_torch_compile: diff 
--git a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py index 592c11d8c..f4fda8846 100644 --- a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py @@ -36,7 +36,7 @@ save_final_model_fsdp2, should_save_checkpoint, ) -from dataset import create_bshd_dataloader, create_thd_dataloader +from dataset import create_bshd_dataloader, create_bshd_packed_dataloader, create_thd_dataloader from distributed_config import DistributedConfig from modeling_llama_te import NVLlamaConfig, NVLlamaForCausalLM from perf_logger import PerfLogger @@ -110,8 +110,14 @@ def main(args: DictConfig) -> float | None: scheduler = get_cosine_annealing_schedule_with_warmup(optimizer, **args.lr_scheduler_kwargs) if args.use_sequence_packing: - train_dataloader, dataset_or_sampler = create_thd_dataloader(dist_config, **args.dataset) + if args.config_kwargs.attn_input_format == "bshd": + # BSHD with full packing (cross-boundary attention, no cu_seqlens) + train_dataloader, dataset_or_sampler = create_bshd_packed_dataloader(dist_config, **args.dataset) + else: + # THD with packing (respects boundaries via cu_seqlens) + train_dataloader, dataset_or_sampler = create_thd_dataloader(dist_config, **args.dataset) else: + # Standard BSHD with windowing (no packing) train_dataloader, dataset_or_sampler = create_bshd_dataloader(dist_config, **args.dataset) if args.use_torch_compile:
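
--- Example usage (editorial sketch, not part of the patch) ---

The snippet below illustrates the packing behavior documented in SequencePackingIterableDataset's docstring (the AAA, BBBBB, CCC, DDDDDDD example with max_seq_length=4). It is a minimal sketch under two assumptions: that the class is importable as `from collator import SequencePackingIterableDataset` when running inside one of the recipe directories, and that a plain Python list of tokenized samples can stand in for the tokenized `datasets.IterableDataset` that `create_bshd_packed_dataloader` actually passes in.

    from collator import SequencePackingIterableDataset  # assumed import path within a recipe directory

    # Four toy "sequences" of lengths 3, 5, 3, and 7, mirroring the AAA/BBBBB/CCC/DDDDDDD
    # docstring example. A plain list stands in for a tokenized datasets.IterableDataset,
    # purely for illustration; the class only calls iter() on it.
    toy_stream = [
        {"input_ids": [1, 1, 1]},
        {"input_ids": [2, 2, 2, 2, 2]},
        {"input_ids": [3, 3, 3]},
        {"input_ids": [4, 4, 4, 4, 4, 4, 4]},
    ]

    packed = SequencePackingIterableDataset(
        dataset=toy_stream,
        max_seq_length=4,
        pad_token_id=0,
        drop_last=True,  # the two leftover tokens [4, 4] are dropped rather than padded
    )

    for sample in packed:
        print(sample["input_ids"])

    # Expected output: fixed-length chunks that cross sequence boundaries, with no cu_seqlens:
    #   [1, 1, 1, 2]
    #   [2, 2, 2, 2]
    #   [3, 3, 3, 4]
    #   [4, 4, 4, 4]

Because drop_last=True, every yielded sample is exactly max_seq_length tokens, so batches from create_bshd_packed_dataloader contain no pad tokens and hold micro_batch_size * max_seq_length tokens each; if max_seq_length is itself a multiple of 16, this also satisfies the FP8 divisibility requirement that pad_to_multiple_of=16 otherwise enforces.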