
[checkpoint] Expose thread count for checkpoint saving to control parallel writes per worker #2572

Open
ananthsub wants to merge 4 commits into NVIDIA-NeMo:main from ananthsub:generalize-ckpt-assertion

Conversation

@ananthsub
Contributor

@ananthsub ananthsub commented Feb 26, 2026

What does this PR do ?

Currently, Megatron Bridge relies on the default thread count from mcore's dist_checkpointing for checkpoint saving: https://github.com/NVIDIA/Megatron-LM/blob/a7c207f84f0a5237d9a94f8b602a5f7d2c6e3389/megatron/core/dist_checkpointing/strategies/torch.py#L602

This PR makes the setting explicit, so it can be controlled separately from the mcore default.
It also updates the default value to 1 to prepare for upcoming daemon changes; related: NVIDIA/Megatron-LM#3424 and NVIDIA/Megatron-LM#3630.
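A minimal sketch of what the new setting could look like. The real `CheckpointConfig` lives in `src/megatron/bridge/training/config.py`; the surrounding field shown here is illustrative, only the `thread_count` field and its default of 1 come from this PR:

```python
from dataclasses import dataclass


@dataclass
class CheckpointConfig:
    """Illustrative subset of the checkpoint configuration (not the full class)."""

    # Checkpoint format; this PR's new strategy selection applies to "torch_dist".
    ckpt_format: str = "torch_dist"
    # Number of parallel write threads per rank for torch_dist checkpoint saves.
    # This PR makes the value explicit and defaults it to 1 instead of relying
    # on mcore's implicit default.
    thread_count: int = 1


cfg = CheckpointConfig()
print(cfg.thread_count)  # -> 1
```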

Changelog

  • Expose thread count for checkpoint saving to control parallel writes per rank and update default to 1

GitHub Actions CI

See the CI section in the Contributing doc for how to trigger the CI. An NVIDIA developer will need to approve and trigger the CI for external contributors.

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (e.g., Numba, Pynini, Apex)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

If you haven't finished some of the above items, you can still open a "Draft" PR.

Additional Information

  • Related to # (issue)

Summary by CodeRabbit

  • New Features

    • Added configurable thread_count parameter for checkpoint operations in torch_dist format.
  • Updates

    • Updated checkpoint verification logic to calculate expected file counts using world_size × thread_count configuration.
    • Updated Megatron-LM dependency to latest version.
  • Tests

    • Updated functional tests to validate new thread_count checkpoint parameter.

@copy-pr-bot

copy-pr-bot bot commented Feb 26, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@ananthsub
Contributor Author

/ok to test dbfbe69

@ananthsub
Contributor Author

/ok to test ab3c72b

@ananthsub
Contributor Author

/ok to test 21d90cd

@ananthsub ananthsub force-pushed the generalize-ckpt-assertion branch 2 times, most recently from b64b09c to 675f522 Compare February 26, 2026 19:38
@ananthsub
Contributor Author

/ok to test 675f522

@yaoyu-33
Contributor

I think mlm needs another upgrade? it includes another issue that Oliver has reverted for encoder_and_decoder...

@coderabbitai
Contributor

coderabbitai bot commented Feb 26, 2026

📝 Walkthrough

Walkthrough

This PR introduces thread-based distributed checkpoint saving support. It adds a thread_count configuration parameter to CheckpointConfig, updates the checkpoint save strategy selection in checkpointing.py to use TorchDistSaveShardedStrategy for the torch_dist format, and adjusts checkpoint file verification logic across tests to account for thread count when calculating expected file counts.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| Submodule Update<br>`3rdparty/Megatron-LM` | Updated the Megatron-LM submodule pointer to a new commit; no local code changes. |
| Checkpoint Configuration & Strategy<br>`src/megatron/bridge/training/checkpointing.py`, `src/megatron/bridge/training/config.py` | Added a `thread_count` field to `CheckpointConfig` (default: 1). Modified `save_checkpoint` to instantiate `TorchDistSaveShardedStrategy` for the `torch_dist` format, passing `thread_count` as a constructor parameter to control sharded checkpoint writes. |
| Checkpoint Verification Infrastructure<br>`tests/functional_tests/utils.py` | Updated the `verify_checkpoint_files` signature to accept `ckpt_format` and `thread_count` parameters. Changed the expected file count for the `torch_dist` format from `2 * world_size` to `thread_count * world_size`. Updated error messages to reference the actual checkpoint format. |
| Test Updates (Checkpoint Verification)<br>`tests/functional_tests/recipes/test_gpt_oss_recipes_finetune.py`, `tests/functional_tests/recipes/test_llama_recipes_distill_3b-1b.py`, `tests/functional_tests/recipes/test_nemotronh_recipes_finetune.py`, `tests/functional_tests/recipes/utils.py`, `tests/functional_tests/training/test_finetune_lora.py`, `tests/functional_tests/training/test_megatron_fsdp.py`, `tests/functional_tests/training/test_pretrain_resume.py`, `tests/functional_tests/training/test_seqpacking_cp_example.py`, `tests/functional_tests/training/test_sft.py` | Updated all calls to `verify_checkpoint_files` to pass `ckpt_format` and `thread_count` extracted from the respective configuration objects, enabling format- and thread-aware verification. |
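The verification rule described above (expected `.distcp` count goes from a hard-coded `2 * world_size` to `thread_count * world_size` for `torch_dist`) can be sketched as a small helper. The function name is a hypothetical stand-in for illustration, not the exact test-utility code:

```python
def expected_distcp_files(ckpt_format: str, world_size: int, thread_count: int = 1) -> int:
    """Expected number of .distcp shard files for a saved checkpoint.

    For torch_dist, each rank writes `thread_count` shard files. The old
    verification logic hard-coded 2 * world_size, which only held while
    mcore's default was two write threads per rank.
    """
    if ckpt_format != "torch_dist":
        raise ValueError(f"unsupported checkpoint format: {ckpt_format}")
    return thread_count * world_size


print(expected_distcp_files("torch_dist", world_size=8, thread_count=1))  # -> 8
print(expected_distcp_files("torch_dist", world_size=8, thread_count=2))  # -> 16
```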

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Possibly related PRs

Suggested labels

Run CICD

Suggested reviewers

  • cuichenx
  • maanug-nv
🚥 Pre-merge checks | ✅ 3 passed | ❌ 1 failed

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Test Results For Major Changes | ⚠️ Warning | PR contains major changes to checkpoint configuration and saving logic affecting 9 test files, but lacks documentation of test results, performance validation, or regression-testing confirmation. | Include test results confirming checkpoint functionality with the new thread_count parameter, no convergence regression, and resolution of the thread_count default mismatch (1 vs 2) between components. |

✅ Passed checks (3 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit's high-level summary is enabled. |
| Title check | ✅ Passed | The title directly summarizes the main change: exposing thread count configuration for checkpoint saving to control parallel writes per worker, which aligns with the PR's primary objective and the changes across multiple files. |
| Docstring Coverage | ✅ Passed | Docstring coverage is 94.12%, which meets the required threshold of 80.00%. |


Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (1)
src/megatron/bridge/training/checkpointing.py (1)

651-658: Add a fail-fast guard for invalid thread_count.

At Line 655, ckpt_cfg.thread_count is passed through directly. Adding a local check (>= 1) will produce a clearer error path for misconfiguration.

Suggested patch

```diff
                 validate_sharding_integrity = True
                 if ckpt_cfg.ckpt_format == "torch_dist":
+                    if ckpt_cfg.thread_count < 1:
+                        raise ValueError(
+                            f"checkpoint.thread_count must be >= 1 for torch_dist, got {ckpt_cfg.thread_count}"
+                        )
                     save_strategy = TorchDistSaveShardedStrategy(
                         "torch_dist",
                         1,
                         thread_count=ckpt_cfg.thread_count,
                     )
                 else:
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/megatron/bridge/training/checkpointing.py` around lines 651 - 658, Add a
fail-fast check for ckpt_cfg.thread_count before constructing the
TorchDistSaveShardedStrategy: validate that ckpt_cfg.thread_count is an int >= 1
and raise a clear ValueError (or similar) if not, so misconfiguration surfaces
immediately; keep this validation local to the branch that uses
TorchDistSaveShardedStrategy (the code paths that assign save_strategy) and do
not change behavior for get_default_save_sharded_strategy.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/functional_tests/utils.py`:
- Around line 105-107: The default for thread_count in verify_checkpoint_files
is outdated (2) and must match the new checkpoint saving default of 1; update
the function signature of verify_checkpoint_files to set thread_count: int = 1
and also update the other identical occurrence of verify_checkpoint_files in
this file so any callers that omit thread_count will assert the correct .distcp
count.
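The requested fix amounts to keeping the helper's default in sync with the new saving default. A hedged sketch of the shape the reviewer describes; the exact signature and file-globbing logic of `verify_checkpoint_files` in `tests/functional_tests/utils.py` may differ:

```python
from pathlib import Path


def verify_checkpoint_files(ckpt_dir: str, world_size: int,
                            ckpt_format: str = "torch_dist",
                            thread_count: int = 1) -> None:
    """Assert the checkpoint directory holds the expected number of .distcp shards.

    The thread_count default of 1 matches the new CheckpointConfig default,
    so callers that omit it assert the correct count.
    """
    expected = thread_count * world_size
    actual = len(list(Path(ckpt_dir).glob("*.distcp")))
    assert actual == expected, (
        f"expected {expected} .distcp files for {ckpt_format} checkpoint, found {actual}"
    )
```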


ℹ️ Review info

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a3e44b4 and 675f522.

📒 Files selected for processing (13)
  • 3rdparty/Megatron-LM
  • src/megatron/bridge/training/checkpointing.py
  • src/megatron/bridge/training/config.py
  • tests/functional_tests/recipes/test_gpt_oss_recipes_finetune.py
  • tests/functional_tests/recipes/test_llama_recipes_distill_3b-1b.py
  • tests/functional_tests/recipes/test_nemotronh_recipes_finetune.py
  • tests/functional_tests/recipes/utils.py
  • tests/functional_tests/training/test_finetune_lora.py
  • tests/functional_tests/training/test_megatron_fsdp.py
  • tests/functional_tests/training/test_pretrain_resume.py
  • tests/functional_tests/training/test_seqpacking_cp_example.py
  • tests/functional_tests/training/test_sft.py
  • tests/functional_tests/utils.py

@ananthsub
Contributor Author

> I think mlm needs another upgrade? it includes another issue that Oliver has reverted for encoder_and_decoder...

Do you know the recommended commit to test against? This change should be future-proof regardless of the worker changes on the mcore dist-checkpointing side. As long as the bump PR either includes this commit or is merged after it, the tests will pass.

@ananthsub ananthsub force-pushed the generalize-ckpt-assertion branch from 0148f2a to 1c36b03 Compare February 27, 2026 00:10
@ananthsub
Contributor Author

/ok to test 4501b2a

@ananthsub
Contributor Author

/ok to test bee037c

Signed-off-by: Ananth Subramaniam <ansubramania@nvidia.com>
Signed-off-by: Ananth Subramaniam <ansubramania@nvidia.com>
Signed-off-by: Ananth Subramaniam <ansubramania@nvidia.com>
Signed-off-by: Ananth Subramaniam <ansubramania@nvidia.com>
@ananthsub ananthsub force-pushed the generalize-ckpt-assertion branch from bee037c to c912f27 Compare February 27, 2026 04:09

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants