
Conversation

@savitha-eng (Collaborator) commented Jan 14, 2026

Description

This PR adds FP8 support enhancements for llama3 training, building on #1416. It includes:

  1. Packed BSHD dataloader - A new SequencePackingIterableDataset and create_bshd_packed_dataloader that pack sequences for the BSHD format by concatenating across document boundaries. Unlike THD packing (which tracks boundaries with cu_seqlens), this yields fixed-length samples with no padding and allows attention to flow across packed sequences. It serves as a baseline against which to compare THD packing in the FP8 experiments (see the sketch after this list).

  2. BSHD packing toggle in training scripts - Updated train_fsdp2.py and train_ddp.py to automatically select the dataloader based on config:

    • use_sequence_packing=true + attn_input_format=bshd → BSHD packed (cross-boundary attention, no cu_seqlens)
    • use_sequence_packing=true + attn_input_format=thd → THD packed (respects boundaries via cu_seqlens)
    • use_sequence_packing=false → BSHD unpacked (standard windowing)
  3. LM head bf16 for FP8 - Wraps the lm_head forward pass with fp8_autocast(enabled=False) to keep it in bf16 for numerical stability during FP8 training. (This is currently also present in #1416, "Add fp8 tests for llama3".)

  4. Configurable first/last layer bf16 - Adds an fp8_first_last_bf16 config option that keeps the first and last transformer layers in bf16 while using FP8 for the middle layers. This can improve numerical stability during FP8 training (a sketch of the idea appears under Usage below).

  5. FP8 tests - Adds training tests for FP8 in BSHD, THD, and BSHD packed modes, plus a test for the first/last bf16 feature.
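
The core packing idea behind SequencePackingIterableDataset, reduced to a minimal sketch (the generator below is an illustration of the cross-boundary concatenation, not the shipped collator.py implementation):

from typing import Iterable, Iterator, List

def pack_bshd(token_streams: Iterable[List[int]], seq_len: int) -> Iterator[List[int]]:
    """Concatenate tokenized documents and emit fixed-length chunks.

    Attention is allowed to flow across document boundaries, so no cu_seqlens
    bookkeeping and no padding are needed.
    """
    buffer: List[int] = []
    for tokens in token_streams:
        buffer.extend(tokens)
        while len(buffer) >= seq_len:
            yield buffer[:seq_len]
            buffer = buffer[seq_len:]
    # In this sketch, any trailing remainder shorter than seq_len is dropped.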

Usage

Enable FP8 training with first/last layers in bf16

with initialize_config_dir(config_dir="hydra_config", version_base="1.2"):
    config = compose(
        config_name="L0_sanity",
        overrides=[
            "fp8_config.enabled=true",
            "+dataset.pad_to_multiple_of=16",  # Required for FP8
            "+config_kwargs.fp8_first_last_bf16=true",  # Keep first/last layers in bf16
        ],
    )
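
Conceptually, fp8_first_last_bf16 and the lm_head handling both come down to disabling FP8 around specific modules. A minimal sketch using Transformer Engine's fp8_autocast (the function, layer loop, and argument names here are illustrative, not the actual train_fsdp2.py code):

import transformer_engine.pytorch as te

def decoder_forward(layers, lm_head, hidden, fp8_recipe):
    num_layers = len(layers)
    for i, layer in enumerate(layers):
        # First and last transformer layers stay in bf16; middle layers run in FP8.
        use_fp8 = 0 < i < num_layers - 1
        with te.fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
            hidden = layer(hidden)
    # The LM head is always kept out of FP8 for numerical stability.
    with te.fp8_autocast(enabled=False):
        return lm_head(hidden)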

Enable FP8 with BSHD packed dataloader via config

with initialize_config_dir(config_dir="hydra_config", version_base="1.2"):
    config = compose(
        config_name="L0_sanity",
        overrides=[
            "fp8_config.enabled=true",
            "use_sequence_packing=true",
            "config_kwargs.attn_input_format=bshd",  # Triggers BSHD packed dataloader
            "+dataset.pad_to_multiple_of=16",
        ],
    )
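
Roughly, the training scripts then select the dataloader from these flags as sketched below (create_thd_dataloader is a placeholder name for the THD-packed factory, and the config attribute paths are assumed from the overrides above):

def select_dataloader(config, **dataloader_kwargs):
    if config.use_sequence_packing:
        if config.config_kwargs.attn_input_format == "bshd":
            # Packed BSHD: fixed-length samples, attention flows across boundaries.
            return create_bshd_packed_dataloader(**dataloader_kwargs)
        # Packed THD: document boundaries preserved via cu_seqlens.
        return create_thd_dataloader(**dataloader_kwargs)
    # Unpacked BSHD: standard windowing with padding.
    return create_bshd_dataloader(**dataloader_kwargs)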

Or use the packed BSHD dataloader directly

from dataset import create_bshd_packed_dataloader

dataloader, dataset = create_bshd_packed_dataloader(
    distributed_config=dist_config,
    tokenizer_name_or_path="nvidia/Llama-3.1-8B-Instruct-FP8",
    load_dataset_kwargs={"path": "parquet", "data_files": "data.parquet", "streaming": True},
    micro_batch_size=4,
    max_seq_length=8192,
    pad_to_multiple_of=16,  # For FP8 compatibility
)

Type of changes

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Refactor
  • Documentation update
  • Other (please describe):

CI Pipeline Configuration

Configure CI behavior by applying the relevant labels. By default, only basic unit tests are run.

  • ciflow:skip - Skip all CI tests for this PR
  • ciflow:notebooks - Run Jupyter notebooks execution tests for bionemo2
  • ciflow:slow - Run slow single GPU integration tests marked as @pytest.mark.slow for bionemo2
  • ciflow:all - Run all tests (unit tests, slow tests, and notebooks) for bionemo2. This label can be used to force running all bionemo2 tests.
  • ciflow:all-recipes - Run tests for all recipes (under bionemo-recipes). This label can be used to force running tests for all recipes.

Unit tests marked as @pytest.mark.multi_gpu or @pytest.mark.distributed are not run in the PR pipeline.

For more details, see CONTRIBUTING

Note

By default, only basic unit tests are run. Add the appropriate labels to enable additional test coverage.

Authorizing CI Runs

We use copy-pr-bot to manage authorization of CI
runs on NVIDIA's compute resources.

  • If a pull request is opened by a trusted user and contains only trusted changes, the pull request's code will
    automatically be copied to a pull-request/ prefixed branch in the source repository (e.g. pull-request/123)
  • If a pull request is opened by an untrusted user or contains untrusted changes, an NVIDIA org member must leave an
    /ok to test comment on the pull request to trigger CI. This will need to be done for each new commit.

Pre-submit Checklist

  • I have tested these changes locally
  • I have updated the documentation accordingly
  • I have added/updated tests as needed
  • All existing tests pass successfully

copy-pr-bot bot commented Jan 14, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

savitha-eng changed the title from "Support for packed bshd seq dataloader and more advanced fp8 integration" to "Support for packed bshd seq dataloader and more advanced fp8 integration for llama3" on Jan 14, 2026
savitha-eng marked this pull request as ready for review on January 14, 2026 08:27
savitha-eng changed the title from "Support for packed bshd seq dataloader and more advanced fp8 integration for llama3" to "BIO-48: Support for packed bshd seq dataloader and more advanced fp8 integration for llama3" on Jan 14, 2026


@dataclass
class SequencePackingIterableDataset(torch.utils.data.IterableDataset):

could we just wrap the existing TokenPackingDataset to do this?

set max_tokens_per_batch to be the desired sequence length, set split_samples=True; then before you return the sample just concatenate along the sequence dimension
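
A rough illustration of this suggestion (the TokenPackingDataset interface is assumed here: the wrapper expects it to yield dicts mapping field names to lists of variable-length 1-D sequences when max_tokens_per_batch equals the target sequence length and split_samples=True):

import torch

class PackedToBSHD(torch.utils.data.IterableDataset):
    """Hypothetical wrapper: concatenate the packed pieces along the sequence
    dimension so each yielded sample is a single fixed-length sequence."""

    def __init__(self, token_packing_dataset):
        self.inner = token_packing_dataset

    def __iter__(self):
        for packed in self.inner:
            yield {
                key: torch.cat([torch.as_tensor(piece) for piece in pieces], dim=0)
                for key, pieces in packed.items()
            }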

return ContextParallelDataLoaderWrapper(train_dataloader, cp_mesh), tokenized_dataset


def create_bshd_packed_dataloader(

can this be an option in the existing create_bshd_dataset function? might be simpler rather than repeating all this.

I'm wondering though if there's a better way to structure this dataset file where it can be more modular

f"+wandb.dir={tmp_path}",
f"checkpoint.ckpt_dir={tmp_path}",
"fp8_config.enabled=true",
"+dataset.pad_to_multiple_of=16",

well, if you have the fully-packed BSHD dataset, could we just use that here?

@trvachov (Collaborator)

Quick question to check my understanding -- "Packed BSHD dataloader" is a feature that is only relevant for llama3/autoregressive models, is that right? I.e. such a dataloader for BERT would not be useful to train an embedding model?

@pstjohn (Collaborator) commented Jan 14, 2026

It mirrors what Megatron-LM does when pre-training large models, which is just concatenating across sequence boundaries without really worrying about document starts/ends. Not sure whether BERT would be fine if you trained it like that, but I think the auto-regressive loss makes it slightly less problematic.

@savitha-eng (Collaborator, Author)

/ok to test f5bc6c3

- Add SequencePackingIterableDataset to collator.py for fixed-length BSHD samples
- Add create_bshd_packed_dataloader function to dataset.py
- Add pad_to_multiple_of parameter to create_bshd_dataloader for FP8 compatibility
- Add fp8_first_last_bf16 config parameter that, when enabled, keeps the first and last transformer layers in bf16 while using FP8 for middle layers; this can improve numerical stability during FP8 training
- Add test_train_fsdp2_fp8_first_last_bf16 to test the fp8_first_last_bf16 config
- Add pad_sequences_to_be_divisible_by=16 to THD FP8 test for proper FP8 padding
- Update train_fsdp2.py and train_ddp.py to toggle between dataloaders:
  - use_sequence_packing=true + attn_input_format=bshd -> BSHD packed
  - use_sequence_packing=true + attn_input_format=thd -> THD packed
  - use_sequence_packing=false -> BSHD unpacked
- Add test_train_fsdp2_fp8_bshd_packed test for FP8 with BSHD packing

Signed-off-by: Savitha Srinivasan <[email protected]>
savitha-eng force-pushed the savitha/fp8-bshd-clean-integration branch from f5bc6c3 to 7e75003 on January 15, 2026 08:24