[checkpoint] Refactor checkpoint utils and expose public API to load model weights #2239
ananthsub wants to merge 2 commits into NVIDIA-NeMo:main
Conversation
Signed-off-by: Ananth Subramaniam <ansubramania@nvidia.com>

/ok to test d9bab4f

Signed-off-by: Ananth Subramaniam <ansubramania@nvidia.com>

/ok to test d96819a
📝 Walkthrough

This PR refactors checkpoint path resolution by introducing centralized utilities to handle iteration-based directory discovery and adds support for the FSDP DTensor checkpoint format. Manual iteration directory scanning is replaced with a single centralized call to `resolve_checkpoint_path()`.
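As a rough illustration of the refactored flow described above, a minimal usage sketch is shown below; the import path for `load_model_weights` and the exact keyword arguments are assumptions based on names mentioned in this PR, not a confirmed public signature.

```python
# Minimal usage sketch (assumed import paths and signatures; see note above).
from megatron.bridge.training.checkpointing import load_model_weights
from megatron.bridge.training.utils.checkpoint_utils import resolve_checkpoint_path


def load_weights_for_eval(model, checkpoint_dir: str) -> None:
    # Accepts either the checkpoint root (containing iter_* subdirectories) or a
    # specific iteration directory, and resolves it to a concrete iteration path.
    resolved_path = resolve_checkpoint_path(checkpoint_dir)
    # Loads only the model weights; the checkpoint format (torch_dist or
    # fsdp_dtensor) is detected and handled transparently.
    load_model_weights(model, resolved_path, strict=True)
```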
Sequence Diagram(s)

```mermaid
sequenceDiagram
participant Client as Caller
participant LoadWeights as load_model_weights()
participant Resolver as resolve_checkpoint_path()
participant Format as _get_checkpoint_format()
participant TorchDist as torch_dist Loader
participant FSDPDTensor as fsdp_dtensor Loader
participant Model as Model Instance
Client->>LoadWeights: load_model_weights(model, checkpoint_path)
LoadWeights->>Resolver: resolve_checkpoint_path(checkpoint_path)
Resolver-->>LoadWeights: resolved_path
LoadWeights->>Format: _get_checkpoint_format(resolved_path)
Format-->>LoadWeights: format_type
alt format == "torch_dist"
LoadWeights->>TorchDist: load_model_weights_from_checkpoint(model, resolved_path, ...)
TorchDist->>Model: load weights
TorchDist-->>LoadWeights: state_dict (if return_state_dict)
else format == "fsdp_dtensor"
LoadWeights->>FSDPDTensor: _load_model_weights_fsdp_dtensor(model, resolved_path, ...)
FSDPDTensor->>Model: load DTensor checkpoint via DCP
FSDPDTensor-->>LoadWeights: state_dict (if return_state_dict)
else unsupported format
LoadWeights-->>Client: NotImplementedError
end
    LoadWeights-->>Client: state_dict or model
```
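In code form, the dispatch shown in the diagram roughly corresponds to the sketch below; the function names are taken from the diagram, while the signatures and defaults are assumptions.

```python
# Sketch of the dispatch shown in the sequence diagram; not the exact code in
# checkpointing.py (signatures and defaults are assumed for illustration).
def load_model_weights(model, checkpoint_path, strict=True, return_state_dict=False):
    resolved_path = resolve_checkpoint_path(checkpoint_path)
    ckpt_format = _get_checkpoint_format(resolved_path)

    if ckpt_format == "torch_dist":
        return load_model_weights_from_checkpoint(
            model, resolved_path, strict=strict, return_state_dict=return_state_dict
        )
    if ckpt_format == "fsdp_dtensor":
        return _load_model_weights_fsdp_dtensor(
            model, resolved_path, strict=strict, return_state_dict=return_state_dict
        )
    raise NotImplementedError(f"Unsupported checkpoint format: {ckpt_format}")
```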
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 3 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 3
🤖 Fix all issues with AI agents
In `@src/megatron/bridge/training/checkpointing.py`:
- Around line 1394-1399: The call to _load_fsdp_dtensor_state_dict from
_load_model_weights_fsdp_dtensor ignores the strict parameter, so partial-load
behavior is always allowed; forward the strict argument when invoking
_load_fsdp_dtensor_state_dict (and any other internal DTensor-loading helpers
called in the same flow) so that the strict flag passed into
_load_model_weights_fsdp_dtensor is honored, and ensure the return_state_dict
branch that converts DTensors to full tensors still receives the state_dict
produced with the strict semantics.
- Around line 1386-1393: The code builds a model state dict using
model_instance.sharded_state_dict(), but DTensor checkpoints are saved with
state_dict_for_save_checkpoint() so loading will mismatch; replace the call to
model_instance.sharded_state_dict() with
model_instance.state_dict_for_save_checkpoint() (or use the same helper path
used by _generate_model_state_dict when ckpt_format == "fsdp_dtensor") and then
pass that result into preprocess_fsdp_dtensor_state_dict(...) so the produced
state_dict matches the saved DTensor checkpoint format; keep references to
unwrap_model and model_instance so the change is localized to where
model_state_dict is constructed.
In `@src/megatron/bridge/training/utils/checkpoint_utils.py`:
- Around line 166-184: The is_iteration_directory function currently only
accepts exactly 7 digits after the _ITER_PREFIX which breaks when
get_checkpoint_name emits longer iteration numbers; update the check in
is_iteration_directory to allow any suffix length >= 7 as long as all characters
are digits (i.e., keep checking basename.startswith(_ITER_PREFIX), then ensure
len(suffix) >= 7 and suffix.isdigit()). Reference the function name
is_iteration_directory and the constants _ITER_PREFIX/_ITER_PREFIX_LEN when
making the change and update the docstring to reflect "7 or more digits" if
present.
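For reference, here is a standalone sketch of the relaxed check described in the prompt above; the prefix constants mirror the ones named there, but this is illustrative rather than the exact module code.

```python
import os

_ITER_PREFIX = "iter_"
_ITER_PREFIX_LEN = len(_ITER_PREFIX)


def is_iteration_directory(path: str) -> bool:
    # Accept "iter_" followed by 7 or more digits, so iterations past 9,999,999
    # (where the iter_{:07d} format grows wider) still match.
    basename = os.path.basename(path.rstrip(os.sep))
    if not basename.startswith(_ITER_PREFIX):
        return False
    suffix = basename[_ITER_PREFIX_LEN:]
    return len(suffix) >= 7 and suffix.isdigit()


assert is_iteration_directory("iter_0000005")        # 7 digits
assert is_iteration_directory("iter_12345678")       # 8 digits, now accepted
assert not is_iteration_directory("iter_123")        # too short
assert not is_iteration_directory("release")         # missing prefix
```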
🧹 Nitpick comments (5)
src/megatron/bridge/training/utils/checkpoint_utils.py (1)
35-36: Align new global constants with the G_ prefix convention. The new globals `_ITER_PREFIX` and `_ITER_PREFIX_LEN` don't follow the required G_ prefix for module-level variables. Consider renaming and updating references to keep naming consistent.

♻️ Suggested refactor

```diff
-_ITER_PREFIX: str = "iter_"
-_ITER_PREFIX_LEN: int = len(_ITER_PREFIX)
+G_ITER_PREFIX: str = "iter_"
+G_ITER_PREFIX_LEN: int = len(G_ITER_PREFIX)
```

As per coding guidelines, Use upper snake_case and prefix 'G' for global variables (e.g., G_MY_GLOBAL).
tests/unit_tests/training/utils/test_checkpoint_utils.py (1)
788-959: Update iteration-name expectations and add unit-test markers. If `is_iteration_directory` is relaxed to accept ≥7-digit iterations, flip the 8-digit case to valid and consider adding an explicit >7-digit valid case. Also, mark these new classes with `@pytest.mark.unit` (or a module-level `pytestmark`) to satisfy test categorization guidance.

🧪 Suggested updates

```diff
+@pytest.mark.unit
 class TestIsIterationDirectory:
@@
-        assert is_iteration_directory("iter_12345678") is False  # 8 digits
+        assert is_iteration_directory("iter_12345678") is True  # 8+ digits valid
```

As per coding guidelines, Use 'pytest.mark' to categorize tests (unit, integration, system).
tests/unit_tests/training/post_training/test_checkpointing.py (1)
47-55: Add unit-test markers for new test classes. Consider adding `@pytest.mark.unit` on the new test classes (or a module-level `pytestmark`, sketched after these nitpicks) to align with test categorization guidance.

🧪 Example

```diff
+@pytest.mark.unit
 class TestGetModeloptCheckpointPath:
```

As per coding guidelines, Use 'pytest.mark' to categorize tests (unit, integration, system).
src/megatron/bridge/models/conversion/auto_bridge.py (1)
655-659: Consider chaining the ImportError for better debugging. The static analysis tool flagged that the re-raised exception should use exception chaining for clearer error context.

🔧 Proposed fix for exception chaining

```diff
 try:
     from megatron.bridge.training.model_load_save import load_megatron_model
     from megatron.bridge.training.utils.checkpoint_utils import resolve_checkpoint_path
 except ImportError:
-    raise ImportError("megatron.bridge.training is not available.")
+    raise ImportError("megatron.bridge.training is not available.") from None
```

tests/unit_tests/training/test_checkpointing.py (1)
1604-1627: Consider asserting `mock_load_fsdp` was called. The test verifies that a warning is printed when `fully_parallel_load=True` is used with the `fsdp_dtensor` format, but it doesn't verify that loading still proceeds. Adding an assertion would make the test more robust.

🔧 Suggested improvement

```diff
         # Should warn about unsupported fully_parallel_load
         mock_print.assert_called()
         warning_calls = [call for call in mock_print.call_args_list if "fully_parallel_load" in str(call)]
         assert len(warning_calls) > 0
+
+        # Verify loading still proceeds despite the warning
+        mock_load_fsdp.assert_called_once_with(
+            mock_model, "/resolved/checkpoint/iter_0000005", strict=True, return_state_dict=False
+        )
```
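Regarding the unit-test marker nitpicks above, a module-level alternative to decorating each test class is sketched below; this is illustrative, and the marker name follows the quoted guideline.

```python
import pytest

# Applies the marker to every test collected from this module, so individual
# classes don't each need an @pytest.mark.unit decorator.
pytestmark = pytest.mark.unit


def test_placeholder():
    # Placeholder only; real tests would exercise the checkpoint utilities.
    assert True
```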
```python
    model_list = unwrap_model(model)
    model_instance = model_list[0]

    # Build model sharded state dict and preprocess for FSDP DTensor format
    model_state_dict = model_instance.sharded_state_dict()
    state_dict = {"model": model_state_dict}
    state_dict = preprocess_fsdp_dtensor_state_dict(None, state_dict, model_instance)
```
🧩 Analysis chain

🏁 Scripts executed:

```shell
# First, check the file exists and get size
wc -l src/megatron/bridge/training/checkpointing.py

# Get the context around lines 1386-1393
sed -n '1370,1410p' src/megatron/bridge/training/checkpointing.py

# Find the function/method containing these lines
sed -n '1350,1395p' src/megatron/bridge/training/checkpointing.py | cat -n

# Search for DTensor-related code in checkpointing.py
grep -n -i "dtensor" src/megatron/bridge/training/checkpointing.py | head -20

# Search for sharded_state_dict usage
rg "sharded_state_dict" -n src/megatron/bridge/training/checkpointing.py
# Search for state_dict_for_save_checkpoint usage
rg "state_dict_for_save_checkpoint" -n src/megatron/bridge/training/checkpointing.py

# Find where these methods are defined
fd -t f "*.py" src/megatron/bridge | xargs grep -l "def sharded_state_dict\|def state_dict_for_save_checkpoint" | head -10

# Get broader context to see the save path logic
sed -n '1080,1120p' src/megatron/bridge/training/checkpointing.py | cat -n

# Find the function containing lines 1080-1120
sed -n '1050,1080p' src/megatron/bridge/training/checkpointing.py | tail -30

# Get more context showing the condition/format check before lines 1100-1106
sed -n '1090,1115p' src/megatron/bridge/training/checkpointing.py | cat -n

# Search for functions that save FSDP DTensor checkpoints
rg -n "def.*save.*fsdp.*dtensor|def.*fsdp.*dtensor.*save" -i src/megatron/bridge/training/checkpointing.py

# Look at the preprocess_fsdp_dtensor_state_dict function definition
sed -n '1178,1230p' src/megatron/bridge/training/checkpointing.py

# Get the function name and full context around lines 1098-1106
sed -n '1050,1065p' src/megatron/bridge/training/checkpointing.py

# Check what function signature and docstring contains this save logic
sed -n '1030,1060p' src/megatron/bridge/training/checkpointing.py

# Search for when _generate_model_state_dict is called
rg -n "_generate_model_state_dict" src/megatron/bridge/training/checkpointing.py

# Check the context around those calls to see if they specify fsdp_dtensor format
sed -n '625,645p' src/megatron/bridge/training/checkpointing.py

# Check the context around line 1272
sed -n '1260,1280p' src/megatron/bridge/training/checkpointing.py

# Look at the save flow for FSDP DTensor to confirm state_dict_for_save_checkpoint is used
sed -n '1140,1160p' src/megatron/bridge/training/checkpointing.py
```
Switch to state_dict_for_save_checkpoint() to match saved DTensor checkpoints.
FSDP DTensor checkpoints are saved using state_dict_for_save_checkpoint() (via _generate_model_state_dict() when ckpt_format == "fsdp_dtensor"), but this load path uses sharded_state_dict() instead. This mismatch will cause key/type errors at load time.
🔧 Fix
```diff
-    model_state_dict = model_instance.sharded_state_dict()
+    model_state_dict = model_instance.state_dict_for_save_checkpoint()
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
    model_list = unwrap_model(model)
    model_instance = model_list[0]

    # Build model sharded state dict and preprocess for FSDP DTensor format
    model_state_dict = model_instance.state_dict_for_save_checkpoint()
    state_dict = {"model": model_state_dict}
    state_dict = preprocess_fsdp_dtensor_state_dict(None, state_dict, model_instance)
```
```python
    # Load using PyTorch DCP
    state_dict = _load_fsdp_dtensor_state_dict(checkpoint_path, state_dict)

    if return_state_dict:
        # Convert DTensors to full tensors for export
        return _convert_dtensor_state_dict_to_full(state_dict)
```
strict is ignored for DTensor loads.
_load_model_weights_fsdp_dtensor accepts strict, but the call to _load_fsdp_dtensor_state_dict doesn’t pass it, so partial-load is always allowed. Forward the flag to honor the API contract.
🐛 Suggested fix
```diff
-    state_dict = _load_fsdp_dtensor_state_dict(checkpoint_path, state_dict)
+    state_dict = _load_fsdp_dtensor_state_dict(checkpoint_path, state_dict, strict=strict)
```
+ state_dict = _load_fsdp_dtensor_state_dict(checkpoint_path, state_dict, strict=strict)🤖 Prompt for AI Agents
In `@src/megatron/bridge/training/checkpointing.py` around lines 1394 - 1399, The
call to _load_fsdp_dtensor_state_dict from _load_model_weights_fsdp_dtensor
ignores the strict parameter, so partial-load behavior is always allowed;
forward the strict argument when invoking _load_fsdp_dtensor_state_dict (and any
other internal DTensor-loading helpers called in the same flow) so that the
strict flag passed into _load_model_weights_fsdp_dtensor is honored, and ensure
the return_state_dict branch that converts DTensors to full tensors still
receives the state_dict produced with the strict semantics.
```python
def is_iteration_directory(path: str) -> bool:
    """Check if a path is a specific iteration checkpoint directory.

    Iteration directories follow the naming convention ``iter_XXXXXXX`` where
    X is a digit (e.g., ``iter_0000005``, ``iter_0001000``). The format matches
    what ``get_checkpoint_name`` produces: ``iter_{:07d}``.

    Args:
        path: Path to check.

    Returns:
        True if the path basename matches the iteration directory pattern.
    """
    basename = os.path.basename(path.rstrip(os.sep))
    # Match format from get_checkpoint_name: "iter_{:07d}" (iter_ + 7 digits)
    if not basename.startswith(_ITER_PREFIX):
        return False
    suffix = basename[_ITER_PREFIX_LEN:]
    return len(suffix) == 7 and suffix.isdigit()
```
Allow iteration directories beyond 7 digits.
get_checkpoint_name can emit iter_ names longer than 7 digits once iterations exceed 9,999,999, but this check rejects them. That will make long-running checkpoints undiscoverable via resolve_checkpoint_path.
🐛 Suggested fix
```diff
-    return len(suffix) == 7 and suffix.isdigit()
+    return len(suffix) >= 7 and suffix.isdigit()
```
+ return len(suffix) >= 7 and suffix.isdigit()🤖 Prompt for AI Agents
In `@src/megatron/bridge/training/utils/checkpoint_utils.py` around lines 166 -
184, The is_iteration_directory function currently only accepts exactly 7 digits
after the _ITER_PREFIX which breaks when get_checkpoint_name emits longer
iteration numbers; update the check in is_iteration_directory to allow any
suffix length >= 7 as long as all characters are digits (i.e., keep checking
basename.startswith(_ITER_PREFIX), then ensure len(suffix) >= 7 and
suffix.isdigit()). Reference the function name is_iteration_directory and the
constants _ITER_PREFIX/_ITER_PREFIX_LEN when making the change and update the
docstring to reflect "7 or more digits" if present.
What does this PR do?

- Adds a `load_model_weights` function to load just the model states from the checkpoint. This covers both `torch_dist` and `fsdp_dtensor` checkpoint formats transparently.

Changelog
GitHub Actions CI
See the CI section in the Contributing doc for how to trigger the CI. An NVIDIA developer will need to approve and trigger the CI for external contributors.
Before your PR is "Ready for review"
Pre checks:
If you haven't finished some of the above items you can still open "Draft" PR.
Additional Information
Summary by CodeRabbit
Release Notes
New Features
Bug Fixes