[checkpoint] Expose thread count for checkpoint saving to control parallel writes per worker #2572
ananthsub wants to merge 4 commits into NVIDIA-NeMo:main
Conversation
/ok to test dbfbe69

/ok to test ab3c72b

/ok to test 21d90cd
b64b09c to 675f522

/ok to test 675f522
I think mlm needs another upgrade? It includes another issue that Oliver has reverted for encoder_and_decoder...
📝 Walkthrough

This PR introduces thread-based distributed checkpoint saving support. It adds a …

Changes
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
🚥 Pre-merge checks: ✅ 3 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 1
🧹 Nitpick comments (1)
src/megatron/bridge/training/checkpointing.py (1)
651-658: Add a fail-fast guard for invalid `thread_count`.

At line 655, `ckpt_cfg.thread_count` is passed through directly. Adding a local check (>= 1) will produce a clearer error path for misconfiguration.

Suggested patch:

```diff
 validate_sharding_integrity = True
 if ckpt_cfg.ckpt_format == "torch_dist":
+    if ckpt_cfg.thread_count < 1:
+        raise ValueError(
+            f"checkpoint.thread_count must be >= 1 for torch_dist, got {ckpt_cfg.thread_count}"
+        )
     save_strategy = TorchDistSaveShardedStrategy(
         "torch_dist",
         1,
         thread_count=ckpt_cfg.thread_count,
     )
 else:
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/megatron/bridge/training/checkpointing.py` around lines 651 - 658, Add a fail-fast check for ckpt_cfg.thread_count before constructing the TorchDistSaveShardedStrategy: validate that ckpt_cfg.thread_count is an int >= 1 and raise a clear ValueError (or similar) if not, so misconfiguration surfaces immediately; keep this validation local to the branch that uses TorchDistSaveShardedStrategy (the code paths that assign save_strategy) and do not change behavior for get_default_save_sharded_strategy.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/functional_tests/utils.py`:
- Around line 105-107: The default for thread_count in verify_checkpoint_files
is outdated (2) and must match the new checkpoint saving default of 1; update
the function signature of verify_checkpoint_files to set thread_count: int = 1
and also update the other identical occurrence of verify_checkpoint_files in
this file so any callers that omit thread_count will assert the correct .distcp
count.
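The file-count relationship behind that assertion can be sketched as follows. `expected_distcp_count` is a hypothetical helper, not the actual code in tests/functional_tests/utils.py, and it assumes the torch_dist format writes one .distcp file per save thread per writing rank:

```python
# Hypothetical sketch of why the verify_checkpoint_files default matters:
# assuming each writing rank emits one .distcp file per save thread, the
# expected file count scales with thread_count.

def expected_distcp_count(num_writers: int, thread_count: int = 1) -> int:
    """Return the number of .distcp files a save is expected to produce."""
    if thread_count < 1:
        raise ValueError(f"thread_count must be >= 1, got {thread_count}")
    return num_writers * thread_count

# Under the old default of 2, a 4-writer save would be checked against 8
# files; with the new default of 1, the same save expects 4.
```

This is why the test helper's default has to move in lockstep with the saving default: a stale default of 2 would assert twice the file count the checkpoint actually contains.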
---
Nitpick comments:
In `@src/megatron/bridge/training/checkpointing.py`:
- Around line 651-658: Add a fail-fast check for ckpt_cfg.thread_count before
constructing the TorchDistSaveShardedStrategy: validate that
ckpt_cfg.thread_count is an int >= 1 and raise a clear ValueError (or similar)
if not, so misconfiguration surfaces immediately; keep this validation local to
the branch that uses TorchDistSaveShardedStrategy (the code paths that assign
save_strategy) and do not change behavior for get_default_save_sharded_strategy.
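As a standalone sketch, the suggested guard could look like the helper below. `validate_thread_count` is a hypothetical name and the exact error wording is illustrative; the review only asks for an int >= 1 check local to the torch_dist branch:

```python
def validate_thread_count(thread_count: object) -> int:
    """Fail fast on a misconfigured thread_count before building the save strategy."""
    # bool is excluded explicitly because it is a subclass of int in Python.
    if (
        isinstance(thread_count, bool)
        or not isinstance(thread_count, int)
        or thread_count < 1
    ):
        raise ValueError(
            f"checkpoint.thread_count must be an int >= 1 for torch_dist, "
            f"got {thread_count!r}"
        )
    return thread_count
```

Raising here, rather than letting the value flow into the strategy constructor, surfaces the misconfiguration with a message that names the config field.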
ℹ️ Review info
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (13)

- 3rdparty/Megatron-LM
- src/megatron/bridge/training/checkpointing.py
- src/megatron/bridge/training/config.py
- tests/functional_tests/recipes/test_gpt_oss_recipes_finetune.py
- tests/functional_tests/recipes/test_llama_recipes_distill_3b-1b.py
- tests/functional_tests/recipes/test_nemotronh_recipes_finetune.py
- tests/functional_tests/recipes/utils.py
- tests/functional_tests/training/test_finetune_lora.py
- tests/functional_tests/training/test_megatron_fsdp.py
- tests/functional_tests/training/test_pretrain_resume.py
- tests/functional_tests/training/test_seqpacking_cp_example.py
- tests/functional_tests/training/test_sft.py
- tests/functional_tests/utils.py
Do you know the recommended commit to test against? This change should be future-proof regardless of the worker changes on the mcore dist checkpoint side. As long as the bump PR either includes this commit or is merged after it, the tests will pass.
0148f2a to 1c36b03

/ok to test 4501b2a

/ok to test bee037c
Signed-off-by: Ananth Subramaniam <ansubramania@nvidia.com>
bee037c to c912f27
What does this PR do ?
Currently megatron bridge relies on the default value for the thread count from mcore's dist_checkpointing for checkpoint saving: https://github.com/NVIDIA/Megatron-LM/blob/a7c207f84f0a5237d9a94f8b602a5f7d2c6e3389/megatron/core/dist_checkpointing/strategies/torch.py#L602
This PR makes the setting explicit, so it can be controlled separately from the mcore default.
This PR also updates the default value to 1 to prepare for future daemon changes; related: NVIDIA/Megatron-LM#3424, NVIDIA/Megatron-LM#3630.
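Assuming a dataclass-style checkpoint config (the `thread_count` field name matches this PR; the class shape here is an illustrative stand-in, not the bridge's actual config class), the new explicit setting would be used roughly like this:

```python
from dataclasses import dataclass


@dataclass
class CheckpointConfig:
    # Illustrative stand-in for the bridge's checkpoint config section.
    ckpt_format: str = "torch_dist"
    # Explicit default of 1, decoupled from mcore's internal default.
    thread_count: int = 1


cfg = CheckpointConfig()                   # single write thread per worker
tuned = CheckpointConfig(thread_count=4)   # opt back into more parallel writes
```

Making the value explicit in the bridge config means a later change to mcore's internal default no longer silently alters how many parallel writes each worker issues.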
Changelog
GitHub Actions CI
See the CI section in the Contributing doc for how to trigger the CI. An Nvidia developer will need to approve and trigger the CI for external contributors.
Before your PR is "Ready for review"
Pre checks:
If you haven't finished some of the above items you can still open "Draft" PR.
Additional Information
Summary by CodeRabbit

- New Features
- Updates
- Tests