
Conversation

@pstjohn
Collaborator

@pstjohn pstjohn commented Jan 15, 2026

Makes it easier to run BSHD context-parallel training in the llama3 recipe for local testing, and adds checkpoint save/resume checks to that recipe.

BIO-8

Summary by CodeRabbit

Release Notes

  • New Features

    • Added checkpoint save/restore capabilities and worker configuration querying for context-parallel data loading
    • Introduced new configuration parameter for sequence padding control
  • Improvements

    • Enhanced distributed training support with improved checkpoint integration for context parallelism
  • Tests

    • Added comprehensive integration tests for distributed checkpointing with context parallelism
    • Added multi-GPU training tests with different attention format configurations

✏️ Tip: You can customize this high-level summary in your review settings.

@coderabbitai
Contributor

coderabbitai bot commented Jan 15, 2026

Important

Review skipped

Auto reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

You can disable this status message by setting `reviews.review_status` to `false` in the CodeRabbit configuration file.

📝 Walkthrough

Walkthrough

This pull request introduces state management and worker configuration access to context-parallel dataloader wrappers across four recipe/model implementations, updates dataset parameter naming for clarity, and modifies training scripts to support rank-aware dataloader initialization with revised checkpoint group references.

Changes

  • State management in ContextParallelDataLoaderWrapper
    Files: bionemo-recipes/models/esm2/src/esm/collator.py, bionemo-recipes/models/llama3/collator.py, bionemo-recipes/recipes/esm2_native_te/collator.py, bionemo-recipes/recipes/llama3_native_te/collator.py
    Added state_dict(), load_state_dict(), and a num_workers property to enable dataloader state persistence and worker configuration queries, delegating to the underlying dataloader on cp_rank 0 with graceful fallbacks on other ranks.

  • Dataset restructuring
    Files: bionemo-recipes/recipes/llama3_native_te/dataset.py
    Renamed pad_to_multiple_of to pad_sequences_to_be_divisible_by across dataloader creation functions; removed the create_cp_dataloader function and related imports; added the parameter to the create_thd_dataloader signature.

  • Hydra configuration updates
    Files: bionemo-recipes/recipes/llama3_native_te/hydra_config/L0_sanity_cp.yaml, bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml
    Added use_sequence_packing: false; changed the attention input format from "thd" to "bshd" and the mask type from "padding_causal" to "causal"; introduced the pad_sequences_to_be_divisible_by configuration parameter.

  • Test updates and extensions
    Files: bionemo-recipes/recipes/llama3_native_te/tests/test_dataset.py, bionemo-recipes/recipes/llama3_native_te/tests/test_distributed_checkpointing.py, bionemo-recipes/recipes/llama3_native_te/tests/test_train_two_gpu.py
    Updated dataloader initialization to use the THD variant with ContextParallelDataLoaderWrapper and DataCollatorForContextParallel; added distributed checkpointing tests for FSDP2 with context parallelism; introduced multi-GPU variants for BSHD and THD configurations.

  • Training script modernization
    Files: bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py
    Made dataloader creation rank-aware (rank 0 only); integrated the new collator wrapper pattern; updated checkpoint save/load calls to use the cp_dp_mesh group instead of the dp-only group.

Sequence Diagram(s)

sequenceDiagram
    participant Rank0 as Rank 0 Process
    participant RankN as Non-Zero Rank<br/>Process
    participant Wrapper as ContextParallel<br/>DataLoaderWrapper
    participant THD as THD DataLoader
    participant Collator as DataCollatorFor<br/>ContextParallel
    participant Ckpt as Checkpoint<br/>Manager

    Note over Rank0,RankN: Dataloader Initialization
    Rank0->>Rank0: Create dataset<br/>with pad_sequences...
    Rank0->>THD: Initialize THD<br/>dataloader
    Rank0->>Collator: Attach collator
    RankN->>RankN: Set dataloader to None
    
    Note over Rank0,RankN: Wrapper Creation
    Rank0->>Wrapper: Wrap THD +<br/>collator
    RankN->>Wrapper: Wrap None with<br/>CP mesh
    
    Note over Rank0,RankN: State Management
    Rank0->>Wrapper: state_dict()
    Wrapper->>THD: Delegate to<br/>underlying loader
    RankN->>Wrapper: state_dict()
    Wrapper->>RankN: Return {}
    
    Note over Rank0,RankN: Checkpoint Operations
    Rank0->>Ckpt: Save with<br/>cp_dp_mesh group
    Rank0->>Ckpt: Load with<br/>cp_dp_mesh group
    RankN->>Ckpt: Participate in<br/>group operations

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Poem

🐰 Hops and whiskers twitch with glee,
State dicts saved for all to see,
Ranks dance round the wrapper true,
Context parallel, fresh and new!
Checkpoints leap, data flows just right,
This refactor's bunny-approved delight!

🚥 Pre-merge checks | ✅ 2 passed | ❌ 1 failed

❌ Failed checks (1 warning)

  • Description check (⚠️ Warning): The PR description is minimal and lacks the required template structure, including the description, usage, type-of-changes, and CI configuration sections. Resolution: fill out the full PR description template with a detailed description of the changes, usage examples, type-of-changes checkboxes, and CI pipeline configuration labels.

✅ Passed checks (2 passed)

  • Title check (✅ Passed): The title accurately reflects the main change, demonstrating DCP checkpoint save/resume with context parallel, which is the core objective of the PR.
  • Docstring coverage (✅ Passed): Docstring coverage is 90.32%, above the required threshold of 80.00%.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.

@pstjohn changed the title from "Pstjohn/bio 8 demonstrate dcp checkpoint save resume with context parallel" to "Demonstrate dcp checkpoint save resume with context parallel" on Jan 15, 2026
@coderabbitai
Contributor

coderabbitai bot commented Jan 15, 2026

✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

Signed-off-by: Peter St. John <[email protected]>
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🤖 Fix all issues with AI agents
In `@bionemo-recipes/models/llama3/collator.py`:
- Around line 431-437: The load_state_dict method on the collator currently
accesses state_dict["dataloader"] directly which raises KeyError if the
checkpoint omitted dataloader state; update load_state_dict to first check
cp_rank==0, then confirm hasattr(self.dataloader, "load_state_dict") and that
"dataloader" is present and not None in the passed state_dict (e.g., use
"dataloader" in state_dict and state_dict.get("dataloader") is not None) before
calling self.dataloader.load_state_dict(...), and otherwise log or warn that
dataloader state was absent and skip loading to keep resume robust.
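
A minimal sketch of the guarded load path described above, assuming the wrapper exposes the cp_rank and dataloader attributes named in the comment:

```python
import logging

logger = logging.getLogger(__name__)


class ContextParallelDataLoaderWrapper:
    # Only the guarded load path is sketched here; cp_rank and dataloader are set in __init__.
    def load_state_dict(self, state_dict: dict) -> None:
        """Restore dataloader state on CP rank 0, skipping gracefully when it is absent."""
        if self.cp_rank != 0:
            return
        inner_state = state_dict.get("dataloader")
        if inner_state is None or not hasattr(self.dataloader, "load_state_dict"):
            logger.warning("Checkpoint contains no dataloader state; skipping dataloader resume.")
            return
        self.dataloader.load_state_dict(inner_state)
```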

In `@bionemo-recipes/recipes/llama3_native_te/dataset.py`:
- Around line 139-141: The docstring for the parameter
pad_sequences_to_be_divisible_by incorrectly states "Default: 16" while the
function signature (pad_sequences_to_be_divisible_by: int | None = None)
defaults to None; update the docstring in the dataset function to reflect the
actual default (e.g., "Default: None") or remove the default mention entirely so
the parameter description matches the signature.

In
`@bionemo-recipes/recipes/llama3_native_te/tests/test_distributed_checkpointing.py`:
- Around line 612-626: Phase 2 is composing the wrong config (uses
config_name="L0_sanity") so the checkpoint-specific settings are not applied;
update the compose call that creates phase2_config (inside the
initialize_config_dir block) to use config_name="L0_sanity_cp" instead, leaving
the same overrides (checkpoint.* and dataset.*) intact so the resume run picks
up the checkpoint-specific configuration.
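
A sketch of the corrected Phase 2 composition using Hydra's public compose API; recipe_path follows the conftest fixture, and the override names shown are placeholders, not the test's actual keys:

```python
from pathlib import Path

from hydra import compose, initialize_config_dir


def compose_phase2_config(recipe_path: Path, checkpoint_dir: Path):
    """Compose the CP-specific config so the resume run sees the checkpoint settings."""
    with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"):
        return compose(
            config_name="L0_sanity_cp",  # previously "L0_sanity", which dropped the CP settings
            overrides=[
                f"checkpoint.ckpt_dir={checkpoint_dir}",  # placeholder override names
                "checkpoint.resume_from_checkpoint=true",
            ],
        )
```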
🧹 Nitpick comments (2)
bionemo-recipes/recipes/llama3_native_te/collator.py (1)

418-451: LGTM! Implementation matches the established pattern across collator modules.

The state management methods are correctly implemented and consistent with the other collator files (esm2_native_te, models/esm2, models/llama3).

Consider consolidating the ContextParallelDataLoaderWrapper class into a shared module to reduce code duplication across the four collator files. The module docstring already notes this code "should eventually get moved to a separate package" - this PR would be a good opportunity to track that as a follow-up.

bionemo-recipes/recipes/llama3_native_te/tests/test_dataset.py (1)

798-825: Consider extracting a small helper for CP THD setup to avoid drift.
This block duplicates the test_cp_dataloader construction; a tiny helper would keep the CP padding + collator wiring consistent.
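
One possible shape for such a helper; build_cp_thd_dataloader is hypothetical, and the argument names and order of the recipe's classes are assumptions rather than their real signatures:

```python
from collator import ContextParallelDataLoaderWrapper, DataCollatorForContextParallel
from dataset import create_thd_dataloader


def build_cp_thd_dataloader(tokenizer, cp_mesh, pad_sequences_to_be_divisible_by=16, **dataset_kwargs):
    """Create the THD dataloader on CP rank 0 only, then wrap it for context parallelism."""
    dataloader = None
    if cp_mesh.get_local_rank() == 0:
        dataloader = create_thd_dataloader(
            tokenizer=tokenizer,
            pad_sequences_to_be_divisible_by=pad_sequences_to_be_divisible_by,
            **dataset_kwargs,  # illustrative pass-through; the real factory may differ
        )
    collator = DataCollatorForContextParallel(tokenizer=tokenizer, cp_mesh=cp_mesh)
    return ContextParallelDataLoaderWrapper(dataloader, collator, cp_mesh)
```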

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e53b726 and 90270ba.

📒 Files selected for processing (11)
  • bionemo-recipes/models/esm2/src/esm/collator.py
  • bionemo-recipes/models/llama3/collator.py
  • bionemo-recipes/recipes/esm2_native_te/collator.py
  • bionemo-recipes/recipes/llama3_native_te/collator.py
  • bionemo-recipes/recipes/llama3_native_te/dataset.py
  • bionemo-recipes/recipes/llama3_native_te/hydra_config/L0_sanity_cp.yaml
  • bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml
  • bionemo-recipes/recipes/llama3_native_te/tests/test_dataset.py
  • bionemo-recipes/recipes/llama3_native_te/tests/test_distributed_checkpointing.py
  • bionemo-recipes/recipes/llama3_native_te/tests/test_train_two_gpu.py
  • bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit inference engine (.cursorrules)

**/*.py: Fix Python linter errors immediately using Ruff for linting and formatting (configured with line-length: 119 in pyproject.toml), and verify all auto-fixes are appropriate
Ensure all Python files follow Google-style docstrings (pydocstyle convention)
Follow import sorting configuration as per isort with 2 lines after imports
Use Pyright for type checking as configured in pyproject.toml

Files:

  • bionemo-recipes/recipes/llama3_native_te/collator.py
  • bionemo-recipes/models/llama3/collator.py
  • bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py
  • bionemo-recipes/recipes/llama3_native_te/tests/test_train_two_gpu.py
  • bionemo-recipes/recipes/llama3_native_te/dataset.py
  • bionemo-recipes/recipes/esm2_native_te/collator.py
  • bionemo-recipes/models/esm2/src/esm/collator.py
  • bionemo-recipes/recipes/llama3_native_te/tests/test_dataset.py
  • bionemo-recipes/recipes/llama3_native_te/tests/test_distributed_checkpointing.py
{**/*test*.py,**/__init__.py}

📄 CodeRabbit inference engine (.cursorrules)

Ensure test files and __init__.py files respect relaxed linting rules as configured in pyproject.toml

Files:

  • bionemo-recipes/recipes/llama3_native_te/tests/test_train_two_gpu.py
  • bionemo-recipes/recipes/llama3_native_te/tests/test_dataset.py
  • bionemo-recipes/recipes/llama3_native_te/tests/test_distributed_checkpointing.py
🧠 Learnings (1)
📚 Learning: 2025-08-28T16:40:04.315Z
Learnt from: pstjohn
Repo: NVIDIA/bionemo-framework PR: 1078
File: recipes/esm2_native_te_mfsdp/train_ddp.py:103-108
Timestamp: 2025-08-28T16:40:04.315Z
Learning: PyTorch DistributedDataParallel constructor accepts a device_mesh parameter in recent versions, which supports advanced distributed training scenarios and nvFSDP configurations.

Applied to files:

  • bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py
  • bionemo-recipes/recipes/llama3_native_te/tests/test_dataset.py
🧬 Code graph analysis (4)
bionemo-recipes/recipes/llama3_native_te/collator.py (2)
bionemo-recipes/models/llama3/collator.py (3)
  • state_dict (418-429)
  • load_state_dict (431-442)
  • num_workers (445-450)
bionemo-recipes/recipes/esm2_native_te/collator.py (3)
  • state_dict (418-429)
  • load_state_dict (431-442)
  • num_workers (445-450)
bionemo-recipes/models/esm2/src/esm/collator.py (3)
bionemo-recipes/models/llama3/collator.py (3)
  • state_dict (418-429)
  • load_state_dict (431-442)
  • num_workers (445-450)
bionemo-recipes/recipes/esm2_native_te/collator.py (3)
  • state_dict (418-429)
  • load_state_dict (431-442)
  • num_workers (445-450)
bionemo-recipes/recipes/llama3_native_te/collator.py (3)
  • state_dict (418-429)
  • load_state_dict (431-442)
  • num_workers (445-450)
bionemo-recipes/recipes/llama3_native_te/tests/test_dataset.py (1)
bionemo-recipes/recipes/llama3_native_te/collator.py (3)
  • ContextParallelDataLoaderWrapper (335-450)
  • DataCollatorForContextParallel (279-332)
  • num_workers (445-450)
bionemo-recipes/recipes/llama3_native_te/tests/test_distributed_checkpointing.py (2)
bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py (1)
  • main (44-233)
bionemo-recipes/recipes/llama3_native_te/tests/conftest.py (1)
  • recipe_path (27-29)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: unit-tests (models/esm2)
  • GitHub Check: unit-tests (models/llama3)
  • GitHub Check: unit-tests (recipes/llama3_native_te)
  • GitHub Check: unit-tests (recipes/esm2_native_te)
  • GitHub Check: pre-commit
🔇 Additional comments (16)
bionemo-recipes/recipes/llama3_native_te/hydra_config/L0_sanity_cp.yaml (1)

7-11: LGTM! Configuration changes align well with BSHD context-parallel setup.

The changes correctly configure BSHD format with causal masking, and the inline comments helpfully document the THD alternatives. The use_sequence_packing: false setting is appropriate since BSHD format doesn't require sequence packing.

bionemo-recipes/recipes/esm2_native_te/collator.py (1)

418-451: LGTM! State management and worker introspection properly implemented for CP-wrapped dataloader.

The implementation correctly:

  • Restricts state operations to CP rank 0 (where the dataloader exists)
  • Gracefully handles dataloaders that don't support state_dict/load_state_dict
  • Returns sensible defaults (empty dict, 0 workers) for non-rank-0 processes

This enables checkpoint save/resume functionality for context-parallel training as intended by the PR.
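
The rank-0 delegation pattern described above, sketched with the assumed cp_rank and dataloader attributes:

```python
class ContextParallelDataLoaderWrapper:
    # Sketch only; cp_rank and dataloader are assumed to be set in __init__.
    def state_dict(self) -> dict:
        """Return the wrapped dataloader's state on CP rank 0; other ranks hold no dataloader."""
        if self.cp_rank == 0 and hasattr(self.dataloader, "state_dict"):
            return {"dataloader": self.dataloader.state_dict()}
        return {}

    @property
    def num_workers(self) -> int:
        """Expose the underlying loader's worker count, or 0 where no loader exists."""
        if self.cp_rank == 0 and self.dataloader is not None:
            return self.dataloader.num_workers
        return 0
```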

bionemo-recipes/models/esm2/src/esm/collator.py (1)

418-451: LGTM! Consistent implementation with other collator modules.

State management and worker introspection methods are correctly implemented, maintaining consistency with the other ContextParallelDataLoaderWrapper implementations across the codebase.

bionemo-recipes/recipes/llama3_native_te/dataset.py (2)

120-121: LGTM! Parameter rename improves semantic clarity.

The renamed parameter pad_sequences_to_be_divisible_by better conveys its purpose for context parallelism alignment. The mapping to HuggingFace's pad_to_multiple_of parameter is correctly maintained internally.

Also applies to: 168-171
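
For illustration, assuming the BSHD path uses a standard HuggingFace padding collator, the renamed parameter maps onto the existing pad_to_multiple_of keyword roughly like this (build_bshd_collator is a hypothetical helper, not the recipe's actual function):

```python
from transformers import AutoTokenizer, DataCollatorWithPadding


def build_bshd_collator(tokenizer_name: str, pad_sequences_to_be_divisible_by: int | None = None):
    """Forward the recipe-level parameter to the HF collator's pad_to_multiple_of kwarg."""
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    return DataCollatorWithPadding(tokenizer=tokenizer, pad_to_multiple_of=pad_sequences_to_be_divisible_by)
```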


220-220: LGTM! THD dataloader correctly wires the new parameter.

The pad_sequences_to_be_divisible_by parameter is properly passed to DataCollatorWithFlattening, enabling context-parallel compatible padding in the THD flow.

Also applies to: 268-271

bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py (4)

31-32: Import updates look good.


137-146: Process group wiring for CP/DP mesh looks correct.


197-209: Checkpoint save now aligned to CP/DP mesh group.
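
A hedged sketch of saving over the combined CP x DP group with torch.distributed.checkpoint; the function and argument names (save_checkpoint, model, optimizer, checkpoint_dir, cp_dp_mesh) follow the review's description rather than the actual script, and cp_dp_mesh is assumed to be a flattened one-dimensional submesh:

```python
import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.device_mesh import DeviceMesh


def save_checkpoint(model: torch.nn.Module, optimizer: torch.optim.Optimizer,
                    checkpoint_dir: str, cp_dp_mesh: DeviceMesh) -> None:
    """Save sharded state so every rank in the CP x DP group participates in the write."""
    dcp.save(
        {"model": model.state_dict(), "optimizer": optimizer.state_dict()},
        checkpoint_id=checkpoint_dir,
        process_group=cp_dp_mesh.get_group(),
    )
```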


113-133: No action required. Both create_bshd_dataloader and create_thd_dataloader explicitly accept the pad_sequences_to_be_divisible_by parameter (declared as int | None = None in both function signatures). The setdefault call is safe for both dataloader paths.

Likely an incorrect or invalid review comment.
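
A sketch of the pattern this comment reasons about; build_dataloader is hypothetical, and gating on use_sequence_packing is an assumption based on the config described earlier:

```python
from dataset import create_bshd_dataloader, create_thd_dataloader


def build_dataloader(dataset_cfg: dict, use_sequence_packing: bool):
    """Default the padding kwarg once; both creator functions accept it, so either path is safe."""
    kwargs = dict(dataset_cfg)
    kwargs.setdefault("pad_sequences_to_be_divisible_by", None)
    create = create_thd_dataloader if use_sequence_packing else create_bshd_dataloader
    return create(**kwargs)  # the real return value in the recipe may differ
```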

bionemo-recipes/recipes/llama3_native_te/tests/test_train_two_gpu.py (2)

194-213: Nice coverage for the BSHD CP path.


218-234: THD CP test config updates look consistent.

bionemo-recipes/recipes/llama3_native_te/tests/test_distributed_checkpointing.py (2)

41-41: Import looks good.


714-725: The concern about rank 1 dataloader files may be unfounded.

save_dataloader() calls torch.save() unconditionally on all ranks (line 413), even when ContextParallelDataLoaderWrapper.state_dict() returns {} on non-zero CP ranks. The function always adds num_workers and num_ranks metadata before saving, ensuring dataloader files are created for all ranks including rank 1. The test assertions should succeed.

Likely an incorrect or invalid review comment.
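
A sketch of the per-rank save behavior the comment describes; the function signature and file naming are assumptions, not the recipe's actual save_dataloader implementation:

```python
import os

import torch


def save_dataloader(dataloader, checkpoint_dir: str, rank: int, world_size: int) -> None:
    """Persist dataloader state plus metadata for every rank, even when the state dict is empty."""
    state = dataloader.state_dict() if hasattr(dataloader, "state_dict") else {}
    state["num_workers"] = getattr(dataloader, "num_workers", 0)
    state["num_ranks"] = world_size
    torch.save(state, os.path.join(checkpoint_dir, f"dataloader_rank_{rank}.pt"))
```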

bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml (1)

25-25: Config addition looks good.

bionemo-recipes/recipes/llama3_native_te/tests/test_dataset.py (2)

27-28: Imports align with the new CP wrapper flow.


707-728: Rank‑0‑only THD construction + CP wrapper looks correct.

✏️ Tip: You can disable this entire section by setting review_details to false in your review settings.

@pstjohn force-pushed the pstjohn/bio-8-demonstrate-dcp-checkpoint-save-resume-with-context-parallel branch from d214d6e to 80e29e5 on January 16, 2026, 16:01
Collaborator

@jomitchellnv jomitchellnv left a comment


LGTM

@pstjohn pstjohn added this pull request to the merge queue Jan 16, 2026
github-merge-queue bot pushed a commit that referenced this pull request Jan 16, 2026
Makes it easier to run BSHD context parallel runs in the llama3 recipe
for local testing, and adds checkpoint save/resume checks to the llama3
recipe

BIO-8

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

## Release Notes

* **New Features**
* Added checkpoint save/restore capabilities and worker configuration
querying for context-parallel data loading
  * Introduced new configuration parameter for sequence padding control

* **Improvements**
* Enhanced distributed training support with improved checkpoint
integration for context parallelism

* **Tests**
* Added comprehensive integration tests for distributed checkpointing
with context parallelism
* Added multi-GPU training tests with different attention format
configurations

<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Peter St. John <[email protected]>
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks Jan 16, 2026
