[checkpoint] Refactor checkpoint utils and expose public API to load model weights#2239

Open
ananthsub wants to merge 2 commits into NVIDIA-NeMo:main from ananthsub:ckpt-utils-consolidate

Conversation

@ananthsub
Contributor

@ananthsub ananthsub commented Feb 5, 2026

What does this PR do ?

  • Consolidate the duplicated code for handling a checkpoint path, whether the top-level checkpoint directory or a checkpoint for a specific iteration is passed in, into a single utility function
  • Expose a public load_model_weights function that loads just the model state from a checkpoint; it handles both torch_dist and fsdp_dtensor checkpoint formats transparently (see the usage sketch below)
  • Models trained with Megatron FSDP can now be exported to Hugging Face through the auto bridge
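
A minimal usage sketch of the new public API. This is illustrative only: the module paths and keyword arguments below are inferred from this PR's summary and review comments and may not match the final signatures.

# Illustrative sketch based on this PR's description; `model` stands in for an
# already-constructed Megatron model, and the exact signatures may differ.
from megatron.bridge.training.checkpointing import load_model_weights
from megatron.bridge.training.utils.checkpoint_utils import resolve_checkpoint_path

# Either the run's top-level checkpoint directory or a specific iter_XXXXXXX
# directory can be resolved to a concrete iteration directory.
ckpt_dir = resolve_checkpoint_path("/checkpoints/my_run")

# Load only the model weights; the checkpoint format (torch_dist vs.
# fsdp_dtensor) is detected and dispatched to the matching loader internally.
load_model_weights(model, ckpt_dir, strict=True)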

Changelog

  • Add specific line-by-line info of high-level changes in this PR.

GitHub Actions CI

See the CI section in the Contributing doc for how to trigger the CI. An NVIDIA developer will need to approve and trigger the CI for external contributors.

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (e.g., Numba, Pynini, Apex, etc.)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

If you haven't finished some of the above items, you can still open a "Draft" PR.

Additional Information

  • Related to # (issue)

Summary by CodeRabbit

Release Notes

  • New Features

    • Added support for the FSDP DTensor checkpoint loading format with enhanced state dict conversion capabilities.
    • Improved checkpoint path resolution to intelligently detect and select the correct iteration directory when multiple versions exist.
  • Bug Fixes

    • Enhanced checkpoint discovery logic to reliably locate configuration files across various checkpoint directory structures.
    • Improved error handling for missing checkpoint paths and malformed iteration directories.

Signed-off-by: Ananth Subramaniam <ansubramania@nvidia.com>
@copy-pr-bot

copy-pr-bot bot commented Feb 5, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@ananthsub
Contributor Author

/ok to test d9bab4f

Signed-off-by: Ananth Subramaniam <ansubramania@nvidia.com>
@ananthsub
Contributor Author

/ok to test d96819a

@ananthsub ananthsub requested a review from yaoyu-33 February 5, 2026 17:38
@coderabbitai
Contributor

coderabbitai bot commented Feb 5, 2026

📝 Walkthrough


This PR refactors checkpoint path resolution by introducing centralized utilities to handle iteration-based directory discovery and adds support for FSDP DTensor checkpoint format. Manual iteration directory scanning is replaced with a single resolve_checkpoint_path() function throughout the codebase.

Changes

  • Checkpoint Path Resolution Utilities (src/megatron/bridge/training/utils/checkpoint_utils.py): Introduces resolve_checkpoint_path(), is_iteration_directory(), and internal helpers (_list_iteration_directories, _get_latest_iteration_path) for standardized iteration directory discovery. Refactors get_hf_model_id_from_checkpoint() to use these new utilities. Adds _ITER_PREFIX constants for the iteration naming convention. (See the path-resolution sketch after this list.)
  • FSDP DTensor Checkpoint Loading (src/megatron/bridge/training/checkpointing.py): Adds three new internal functions: _load_fsdp_dtensor_state_dict(), _convert_dtensor_state_dict_to_full(), and _load_model_weights_fsdp_dtensor() to support the FSDP DTensor checkpoint format. Updates load_model_weights() to detect the checkpoint format and dispatch to the appropriate loader (torch_dist or fsdp_dtensor). Refactors existing FSDP loading code to use centralized utilities.
  • Path Resolution Integration (src/megatron/bridge/models/conversion/auto_bridge.py, examples/conversion/compare_text_generation.py, examples/conversion/convert_checkpoints.py): Replaces manual iter_* directory discovery with resolve_checkpoint_path() calls. Centralizes checkpoint path resolution logic instead of inline folder scanning.
  • Post-training Checkpoint Path Handling (src/megatron/bridge/training/post_training/checkpointing.py): Refactors _get_modelopt_checkpoint_path() to use the new iteration directory helpers. Replaces manual iter_* detection and on-disk inspection with a delegated helper-based approach.
  • Model Load/Save API Updates (src/megatron/bridge/training/model_load_save.py): Updates to use the new load_model_weights() API. Simplifies control flow by collapsing branching logic for state dict handling into a unified return pattern.
  • Core Test Updates (tests/unit_tests/training/test_checkpointing.py, tests/unit_tests/training/test_model_load_save.py, tests/unit_tests/models/test_auto_bridge.py): Adds comprehensive test coverage for new checkpoint format detection and FSDP DTensor loading paths. Updates mocks to use resolve_checkpoint_path() and _get_checkpoint_format() instead of path manipulation. Adds tests for return_state_dict behavior and error conditions.
  • Checkpoint Utilities Tests (tests/unit_tests/training/utils/test_checkpoint_utils.py, tests/unit_tests/training/post_training/test_checkpointing.py): Introduces extensive test coverage for the iteration directory utilities (is_iteration_directory, _get_latest_iteration_path, resolve_checkpoint_path). Refactors post-training checkpoint tests to focus on filesystem-driven validation rather than internal mocking. Tests iteration selection, malformed directory handling, and edge cases.
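
As a rough illustration of the path resolution these utilities centralize, here is a sketch based only on the summaries above; the real resolve_checkpoint_path may differ in naming, error handling, and multi-storage-client support.

import os

_ITER_PREFIX = "iter_"  # naming convention from get_checkpoint_name: iter_{:07d}

def resolve_checkpoint_path_sketch(path: str) -> str:
    """Return `path` if it is already an iter_* directory, else the latest one under it."""
    base = os.path.basename(path.rstrip(os.sep))
    if base.startswith(_ITER_PREFIX) and base[len(_ITER_PREFIX):].isdigit():
        return path  # already a specific iteration directory
    iter_dirs = [
        d for d in os.listdir(path)
        if d.startswith(_ITER_PREFIX) and d[len(_ITER_PREFIX):].isdigit()
    ]
    if not iter_dirs:
        raise FileNotFoundError(f"No {_ITER_PREFIX}* directories found under {path}")
    latest = max(iter_dirs, key=lambda d: int(d[len(_ITER_PREFIX):]))
    return os.path.join(path, latest)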

Sequence Diagram(s)

sequenceDiagram
    participant Client as Caller
    participant LoadWeights as load_model_weights()
    participant Resolver as resolve_checkpoint_path()
    participant Format as _get_checkpoint_format()
    participant TorchDist as torch_dist Loader
    participant FSDPDTensor as fsdp_dtensor Loader
    participant Model as Model Instance

    Client->>LoadWeights: load_model_weights(model, checkpoint_path)
    LoadWeights->>Resolver: resolve_checkpoint_path(checkpoint_path)
    Resolver-->>LoadWeights: resolved_path
    LoadWeights->>Format: _get_checkpoint_format(resolved_path)
    Format-->>LoadWeights: format_type
    
    alt format == "torch_dist"
        LoadWeights->>TorchDist: load_model_weights_from_checkpoint(model, resolved_path, ...)
        TorchDist->>Model: load weights
        TorchDist-->>LoadWeights: state_dict (if return_state_dict)
    else format == "fsdp_dtensor"
        LoadWeights->>FSDPDTensor: _load_model_weights_fsdp_dtensor(model, resolved_path, ...)
        FSDPDTensor->>Model: load DTensor checkpoint via DCP
        FSDPDTensor-->>LoadWeights: state_dict (if return_state_dict)
    else unsupported format
        LoadWeights-->>Client: NotImplementedError
    end
    
    LoadWeights-->>Client: state_dict or model
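In code, the dispatch shown above might look roughly like the following. This is a sketch mirroring the diagram; the helper names come from this PR's walkthrough, but the argument lists and internals are assumptions, not the actual implementation.

# Sketch only. The helpers referenced here (resolve_checkpoint_path,
# _get_checkpoint_format, load_model_weights_from_checkpoint,
# _load_model_weights_fsdp_dtensor) are the ones introduced or updated by this
# PR; their exact signatures may differ.
def load_model_weights_sketch(model, checkpoint_path, strict=True, return_state_dict=False):
    resolved_path = resolve_checkpoint_path(checkpoint_path)
    ckpt_format = _get_checkpoint_format(resolved_path)

    if ckpt_format == "torch_dist":
        return load_model_weights_from_checkpoint(
            model, resolved_path, strict=strict, return_state_dict=return_state_dict
        )
    elif ckpt_format == "fsdp_dtensor":
        return _load_model_weights_fsdp_dtensor(
            model, resolved_path, strict=strict, return_state_dict=return_state_dict
        )
    raise NotImplementedError(f"Unsupported checkpoint format: {ckpt_format!r}")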

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested labels

Run CICD

Suggested reviewers

  • yaoyu-33
  • maanug-nv
🚥 Pre-merge checks | ✅ 3 | ❌ 1

❌ Failed checks (1 warning)
  • Test Results For Major Changes ⚠️ Warning
    Explanation: The PR contains major checkpoint loading changes with 6,196 test lines but no documented CI execution results, test pass/fail counts, or evidence that the three identified implementation bugs were fixed.
    Resolution: Add confirmation from an NVIDIA developer that CI passed, include a test summary with pass/fail counts, provide the CI workflow URL, verify that the three review bugs are fixed with test coverage, and confirm regression testing shows identical checkpoint loading results.

✅ Passed checks (3 passed)
  • Description Check ✅ Passed: Check skipped - CodeRabbit’s high-level summary is enabled.
  • Title Check ✅ Passed: The title accurately summarizes the main changes: refactoring checkpoint utilities and exposing a public API for loading model weights. It directly corresponds to the core objectives of the PR.
  • Docstring Coverage ✅ Passed: Docstring coverage is 92.77%, which is sufficient. The required threshold is 80.00%.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🤖 Fix all issues with AI agents
In `@src/megatron/bridge/training/checkpointing.py`:
- Around line 1394-1399: The call to _load_fsdp_dtensor_state_dict from
_load_model_weights_fsdp_dtensor ignores the strict parameter, so partial-load
behavior is always allowed; forward the strict argument when invoking
_load_fsdp_dtensor_state_dict (and any other internal DTensor-loading helpers
called in the same flow) so that the strict flag passed into
_load_model_weights_fsdp_dtensor is honored, and ensure the return_state_dict
branch that converts DTensors to full tensors still receives the state_dict
produced with the strict semantics.
- Around line 1386-1393: The code builds a model state dict using
model_instance.sharded_state_dict(), but DTensor checkpoints are saved with
state_dict_for_save_checkpoint() so loading will mismatch; replace the call to
model_instance.sharded_state_dict() with
model_instance.state_dict_for_save_checkpoint() (or use the same helper path
used by _generate_model_state_dict when ckpt_format == "fsdp_dtensor") and then
pass that result into preprocess_fsdp_dtensor_state_dict(...) so the produced
state_dict matches the saved DTensor checkpoint format; keep references to
unwrap_model and model_instance so the change is localized to where
model_state_dict is constructed.

In `@src/megatron/bridge/training/utils/checkpoint_utils.py`:
- Around line 166-184: The is_iteration_directory function currently only
accepts exactly 7 digits after the _ITER_PREFIX which breaks when
get_checkpoint_name emits longer iteration numbers; update the check in
is_iteration_directory to allow any suffix length >= 7 as long as all characters
are digits (i.e., keep checking basename.startswith(_ITER_PREFIX), then ensure
len(suffix) >= 7 and suffix.isdigit()). Reference the function name
is_iteration_directory and the constants _ITER_PREFIX/_ITER_PREFIX_LEN when
making the change and update the docstring to reflect "7 or more digits" if
present.
🧹 Nitpick comments (5)
src/megatron/bridge/training/utils/checkpoint_utils.py (1)

35-36: Align new global constants with the G_ prefix convention.

The new globals _ITER_PREFIX and _ITER_PREFIX_LEN don’t follow the required G_ prefix for module-level variables. Consider renaming and updating references to keep naming consistent.

♻️ Suggested refactor
-_ITER_PREFIX: str = "iter_"
-_ITER_PREFIX_LEN: int = len(_ITER_PREFIX)
+G_ITER_PREFIX: str = "iter_"
+G_ITER_PREFIX_LEN: int = len(G_ITER_PREFIX)

As per coding guidelines, Use upper snake_case and prefix 'G' for global variables (e.g., G_MY_GLOBAL).

tests/unit_tests/training/utils/test_checkpoint_utils.py (1)

788-959: Update iteration-name expectations and add unit-test markers.

If is_iteration_directory is relaxed to accept ≥7-digit iterations, flip the 8‑digit case to valid and consider adding an explicit >7‑digit valid case. Also, mark these new classes with @pytest.mark.unit (or a module‑level pytestmark) to satisfy test categorization guidance.

🧪 Suggested updates
+@pytest.mark.unit
 class TestIsIterationDirectory:
@@
-        assert is_iteration_directory("iter_12345678") is False  # 8 digits
+        assert is_iteration_directory("iter_12345678") is True  # 8+ digits valid

As per coding guidelines, Use 'pytest.mark' to categorize tests (unit, integration, system).

tests/unit_tests/training/post_training/test_checkpointing.py (1)

47-55: Add unit-test markers for new test classes.

Consider adding @pytest.mark.unit on the new test classes (or a module‑level pytestmark) to align with test categorization guidance.

🧪 Example
+@pytest.mark.unit
 class TestGetModeloptCheckpointPath:

As per coding guidelines, Use 'pytest.mark' to categorize tests (unit, integration, system).

src/megatron/bridge/models/conversion/auto_bridge.py (1)

655-659: Consider chaining the ImportError for better debugging.

The static analysis tool flagged that the re-raised exception should use exception chaining for clearer error context.

🔧 Proposed fix for exception chaining
         try:
             from megatron.bridge.training.model_load_save import load_megatron_model
             from megatron.bridge.training.utils.checkpoint_utils import resolve_checkpoint_path
         except ImportError:
-            raise ImportError("megatron.bridge.training is not available.")
+            raise ImportError("megatron.bridge.training is not available.") from None
tests/unit_tests/training/test_checkpointing.py (1)

1604-1627: Consider asserting mock_load_fsdp was called.

The test verifies that a warning is printed when fully_parallel_load=True is used with fsdp_dtensor format, but it doesn't verify that loading still proceeds. Adding an assertion would make the test more robust.

🔧 Suggested improvement
         # Should warn about unsupported fully_parallel_load
         mock_print.assert_called()
         warning_calls = [call for call in mock_print.call_args_list if "fully_parallel_load" in str(call)]
         assert len(warning_calls) > 0
+
+        # Verify loading still proceeds despite the warning
+        mock_load_fsdp.assert_called_once_with(
+            mock_model, "/resolved/checkpoint/iter_0000005", strict=True, return_state_dict=False
+        )

Comment on lines +1386 to +1393
    model_list = unwrap_model(model)
    model_instance = model_list[0]

    # Build model sharded state dict and preprocess for FSDP DTensor format
    model_state_dict = model_instance.sharded_state_dict()
    state_dict = {"model": model_state_dict}
    state_dict = preprocess_fsdp_dtensor_state_dict(None, state_dict, model_instance)


⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, check the file exists and get size
wc -l src/megatron/bridge/training/checkpointing.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 121


🏁 Script executed:

# Get the context around lines 1386-1393
sed -n '1370,1410p' src/megatron/bridge/training/checkpointing.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 1592


🏁 Script executed:

# Find the function/method containing these lines
sed -n '1350,1395p' src/megatron/bridge/training/checkpointing.py | cat -n

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 2157


🏁 Script executed:

# Search for DTensor-related code in checkpointing.py
grep -n -i "dtensor" src/megatron/bridge/training/checkpointing.py | head -20

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 1391


🏁 Script executed:

# Search for sharded_state_dict usage
rg "sharded_state_dict" -n src/megatron/bridge/training/checkpointing.py

# Search for state_dict_for_save_checkpoint usage
rg "state_dict_for_save_checkpoint" -n src/megatron/bridge/training/checkpointing.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 2868


🏁 Script executed:

# Find where these methods are defined
fd -t f "*.py" src/megatron/bridge | xargs grep -l "def sharded_state_dict\|def state_dict_for_save_checkpoint" | head -10

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 354


🏁 Script executed:

# Get broader context to see the save path logic
sed -n '1080,1120p' src/megatron/bridge/training/checkpointing.py | cat -n

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 1923


🏁 Script executed:

# Find the function containing lines 1080-1120
sed -n '1050,1080p' src/megatron/bridge/training/checkpointing.py | tail -30

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 1473


🏁 Script executed:

# Get more context showing the condition/format check before lines 1100-1106
sed -n '1090,1115p' src/megatron/bridge/training/checkpointing.py | cat -n

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 1159


🏁 Script executed:

# Search for functions that save FSDP DTensor checkpoints
rg -n "def.*save.*fsdp.*dtensor|def.*fsdp.*dtensor.*save" -i src/megatron/bridge/training/checkpointing.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 53


🏁 Script executed:

# Look at the preprocess_fsdp_dtensor_state_dict function definition
sed -n '1178,1230p' src/megatron/bridge/training/checkpointing.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 1995


🏁 Script executed:

# Get the function name and full context around lines 1098-1106
sed -n '1050,1065p' src/megatron/bridge/training/checkpointing.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 1072


🏁 Script executed:

# Check what function signature and docstring contains this save logic
sed -n '1030,1060p' src/megatron/bridge/training/checkpointing.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 1769


🏁 Script executed:

# Search for when _generate_model_state_dict is called
rg -n "_generate_model_state_dict" src/megatron/bridge/training/checkpointing.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 332


🏁 Script executed:

# Check the context around those calls to see if they specify fsdp_dtensor format
sed -n '625,645p' src/megatron/bridge/training/checkpointing.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 1246


🏁 Script executed:

# Check the context around line 1272
sed -n '1260,1280p' src/megatron/bridge/training/checkpointing.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 1077


🏁 Script executed:

# Look at the save flow for FSDP DTensor to confirm state_dict_for_save_checkpoint is used
sed -n '1140,1160p' src/megatron/bridge/training/checkpointing.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 1033


Switch to state_dict_for_save_checkpoint() to match saved DTensor checkpoints.

FSDP DTensor checkpoints are saved using state_dict_for_save_checkpoint() (via _generate_model_state_dict() when ckpt_format == "fsdp_dtensor"), but this load path uses sharded_state_dict() instead. This mismatch will cause key/type errors at load time.

🔧 Fix
-    model_state_dict = model_instance.sharded_state_dict()
+    model_state_dict = model_instance.state_dict_for_save_checkpoint()
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
     model_list = unwrap_model(model)
     model_instance = model_list[0]

     # Build model sharded state dict and preprocess for FSDP DTensor format
-    model_state_dict = model_instance.sharded_state_dict()
+    model_state_dict = model_instance.state_dict_for_save_checkpoint()
     state_dict = {"model": model_state_dict}
     state_dict = preprocess_fsdp_dtensor_state_dict(None, state_dict, model_instance)

Comment on lines +1394 to +1399
    # Load using PyTorch DCP
    state_dict = _load_fsdp_dtensor_state_dict(checkpoint_path, state_dict)

    if return_state_dict:
        # Convert DTensors to full tensors for export
        return _convert_dtensor_state_dict_to_full(state_dict)

⚠️ Potential issue | 🟠 Major

strict is ignored for DTensor loads.

_load_model_weights_fsdp_dtensor accepts strict, but the call to _load_fsdp_dtensor_state_dict doesn’t pass it, so partial-load is always allowed. Forward the flag to honor the API contract.

🐛 Suggested fix
-    state_dict = _load_fsdp_dtensor_state_dict(checkpoint_path, state_dict)
+    state_dict = _load_fsdp_dtensor_state_dict(checkpoint_path, state_dict, strict=strict)

Comment on lines +166 to +184
def is_iteration_directory(path: str) -> bool:
    """Check if a path is a specific iteration checkpoint directory.

    Iteration directories follow the naming convention ``iter_XXXXXXX`` where
    X is a digit (e.g., ``iter_0000005``, ``iter_0001000``). The format matches
    what ``get_checkpoint_name`` produces: ``iter_{:07d}``.

    Args:
        path: Path to check.

    Returns:
        True if the path basename matches the iteration directory pattern.
    """
    basename = os.path.basename(path.rstrip(os.sep))
    # Match format from get_checkpoint_name: "iter_{:07d}" (iter_ + 7 digits)
    if not basename.startswith(_ITER_PREFIX):
        return False
    suffix = basename[_ITER_PREFIX_LEN:]
    return len(suffix) == 7 and suffix.isdigit()

⚠️ Potential issue | 🟠 Major

Allow iteration directories beyond 7 digits.

get_checkpoint_name can emit iter_ names longer than 7 digits once iterations exceed 9,999,999, but this check rejects them. That will make long-running checkpoints undiscoverable via resolve_checkpoint_path.

🐛 Suggested fix
-    return len(suffix) == 7 and suffix.isdigit()
+    return len(suffix) >= 7 and suffix.isdigit()

