[checkpoint] Refactor checkpoint utils and expose public API to load model weights #2239
Open

ananthsub wants to merge 2 commits into NVIDIA-NeMo:main from ananthsub:ckpt-utils-consolidate
src/megatron/bridge/training/checkpointing.py:

```diff
@@ -73,6 +73,7 @@
     get_checkpoint_train_state_filename,
     read_run_config,
     read_train_state,
+    resolve_checkpoint_path,
 )
 from megatron.bridge.training.utils.log_utils import append_to_progress_log
 from megatron.bridge.training.utils.pg_utils import get_pg_collection
```
@@ -1298,6 +1299,184 @@ def _load_model_weights_from_checkpoint(

The hunk's leading context is the final `torch.distributed.barrier()` of `_load_model_weights_from_checkpoint`; everything below it is newly added.

```python
def _load_fsdp_dtensor_state_dict(
    checkpoint_path: str,
    state_dict: dict[str, Any],
    strict: bool = False,
) -> dict[str, Any]:
    """Load a state dict from an FSDP DTensor checkpoint using PyTorch DCP.

    Args:
        checkpoint_path: Full path to the checkpoint directory.
        state_dict: Pre-built state dict with DTensor placeholders to load into.
        strict: If True, require exact key match. If False, allow partial loading.

    Returns:
        The loaded state dict (modified in-place).
    """
    fs_storage_reader = torch.distributed.checkpoint.FileSystemReader(checkpoint_path)

    if not strict:
        state_dict_metadata = fs_storage_reader.read_metadata().state_dict_metadata
        print_diff_in_state_dicts(state_dict_metadata, state_dict)

    planner = torch.distributed.checkpoint.default_planner.DefaultLoadPlanner(allow_partial_load=not strict)
    torch.distributed.checkpoint.load_state_dict(
        state_dict=state_dict,
        storage_reader=fs_storage_reader,
        planner=planner,
    )
    return state_dict
```
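As context for the `strict` flag above, here is a minimal standalone sketch (not part of this PR) of the DCP behavior it toggles. The checkpoint path and tensor names are made up, and it assumes a recent PyTorch run under a single-process gloo group:

```python
# Sketch only: DefaultLoadPlanner(allow_partial_load=True) skips keys that the requesting
# state dict has but the checkpoint lacks, instead of raising. This is what strict=False
# maps to in the helper above.
import os
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29510")
dist.init_process_group("gloo", rank=0, world_size=1)

# Save a tiny checkpoint containing only "w".
dcp.save({"w": torch.ones(4)}, checkpoint_id="/tmp/toy_dcp_ckpt")

# Ask for "w" and "extra"; with allow_partial_load=True the missing "extra" stays untouched.
requested = {"w": torch.empty(4), "extra": torch.zeros(2)}
dcp.load(
    requested,
    storage_reader=dcp.FileSystemReader("/tmp/toy_dcp_ckpt"),
    planner=DefaultLoadPlanner(allow_partial_load=True),
)

dist.destroy_process_group()
```

With `allow_partial_load=False` the same load raises because `"extra"` is absent from the checkpoint, which is the behavior `strict=True` requests in the helper.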
```python
def _convert_dtensor_state_dict_to_full(state_dict: StateDict) -> StateDict:
    """Convert a state dict containing DTensors to one with full (unsharded) tensors.

    Args:
        state_dict: State dict potentially containing DTensor values.

    Returns:
        State dict with all DTensors converted to regular tensors via full_tensor().
    """
    try:
        from torch.distributed.tensor import DTensor
    except ImportError:
        # No DTensor available, return as-is
        return state_dict

    def _convert_value(value: Any) -> Any:
        if isinstance(value, DTensor):
            return value.full_tensor()
        elif isinstance(value, dict):
            return {k: _convert_value(v) for k, v in value.items()}
        elif isinstance(value, list):
            return [_convert_value(v) for v in value]
        return value

    return _convert_value(state_dict)
```
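For reference, a tiny standalone sketch (again not from the PR) of what `full_tensor()` does, using a 1-rank CPU mesh; the names and sizes are illustrative:

```python
# Sketch only: full_tensor() gathers a DTensor's shards back into a plain torch.Tensor,
# which is what the conversion helper above relies on. Assumes a recent PyTorch.
import os
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29511")
dist.init_process_group("gloo", rank=0, world_size=1)

mesh = init_device_mesh("cpu", (1,))
dt = distribute_tensor(torch.arange(8.0), mesh, [Shard(0)])  # DTensor sharded on dim 0
full = dt.full_tensor()                                       # regular torch.Tensor again
assert torch.equal(full, torch.arange(8.0))

dist.destroy_process_group()
```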
```python
def _load_model_weights_fsdp_dtensor(
    model: list[MegatronModule],
    checkpoint_path: str,
    strict: bool = True,
    return_state_dict: bool = False,
) -> Optional[StateDict]:
    """Load model weights from an FSDP DTensor checkpoint.

    This is a simplified loading path for FSDP DTensor checkpoints that only loads
    model weights without optimizer state, RNG state, or iteration tracking.

    Args:
        model: The model(s) to load weights into.
        checkpoint_path: Full path to the iteration checkpoint directory.
        strict: Whether to enforce strict state dict loading.
        return_state_dict: If True, return the state dict instead of loading into model.

    Returns:
        If return_state_dict is True, returns the model state dict with full tensors.
        Otherwise returns None.
    """
    if not HAVE_MEGATRON_FSDP:
        raise RuntimeError(
            "Megatron FSDP is required but not available for loading FSDP DTensor checkpoints. "
            "Please install megatron-core with FSDP support."
        )

    model_list = unwrap_model(model)
    model_instance = model_list[0]

    # Build model sharded state dict and preprocess for FSDP DTensor format
    model_state_dict = model_instance.sharded_state_dict()
    state_dict = {"model": model_state_dict}
    state_dict = preprocess_fsdp_dtensor_state_dict(None, state_dict, model_instance)

    # Load using PyTorch DCP
    state_dict = _load_fsdp_dtensor_state_dict(checkpoint_path, state_dict)

    if return_state_dict:
        # Convert DTensors to full tensors for export
        return _convert_dtensor_state_dict_to_full(state_dict)
```
Contributor comment on lines +1394 to +1399:

The `strict` parameter of `_load_model_weights_fsdp_dtensor` is accepted but never forwarded to `_load_fsdp_dtensor_state_dict`, so non-strict loading cannot be requested through this path.

🐛 Suggested fix

```diff
-    state_dict = _load_fsdp_dtensor_state_dict(checkpoint_path, state_dict)
+    state_dict = _load_fsdp_dtensor_state_dict(checkpoint_path, state_dict, strict=strict)
```
The function body continues after the comment:

```python
    # Load the state dict into the model
    _load_model_state_dict(model_instance, state_dict["model"], strict)

    if torch.distributed.is_initialized():
        torch.distributed.barrier()
```
```python
def load_model_weights(
    model: list[MegatronModule],
    checkpoint_path: str,
    *,
    fully_parallel_load: bool = False,
    strict: bool = True,
    return_state_dict: bool = False,
) -> Optional[StateDict]:
    """Load only model weights from a checkpoint.

    Simple API for loading pretrained model weights without optimizer state,
    RNG state, or iteration tracking. Supports both ``torch_dist`` and ``fsdp_dtensor``
    checkpoint formats.

    This function automatically:
    - Resolves top-level checkpoint directories to the latest iteration
    - Detects the checkpoint format
    - Loads only model weights (no optimizer, RNG, or training state)

    Args:
        model: The model(s) to load weights into.
        checkpoint_path: Path to checkpoint. Can be either:
            - A top-level checkpoint directory (loads latest iteration)
            - A specific iteration directory (e.g., ``/path/to/checkpoint/iter_0000005``)
        fully_parallel_load: Apply full load parallelization across data parallel ranks.
            Only supported for ``torch_dist`` format; ignored for other formats.
        strict: Whether to enforce strict state dict loading.
        return_state_dict: If True, return the state dict instead of loading into model.

    Returns:
        If return_state_dict is True, returns the model state dict.
        Otherwise returns None.

    Raises:
        FileNotFoundError: If no valid checkpoint is found at the path.
        NotImplementedError: If the checkpoint format is not supported.

    Example:
        >>> # Load from a specific iteration
        >>> load_model_weights(model, "/checkpoints/iter_0000005")

        >>> # Load latest checkpoint from a directory
        >>> load_model_weights(model, "/checkpoints")

        >>> # Get state dict for export
        >>> state_dict = load_model_weights(model, "/checkpoints", return_state_dict=True)
    """
    # Resolve to specific iteration directory
    resolved_path = resolve_checkpoint_path(checkpoint_path)
    ckpt_format = _get_checkpoint_format(resolved_path)

    if ckpt_format == "torch_dist":
        return _load_model_weights_from_checkpoint(
            resolved_path,
            model,
            fully_parallel_load=fully_parallel_load,
            strict=strict,
            return_state_dict=return_state_dict,
        )
    elif ckpt_format == "fsdp_dtensor":
        if fully_parallel_load:
            print_rank_0("Warning: fully_parallel_load is not supported for fsdp_dtensor format, ignoring")
        return _load_model_weights_fsdp_dtensor(
            model, resolved_path, strict=strict, return_state_dict=return_state_dict
        )
    else:
        raise NotImplementedError(
            f"Checkpoint format '{ckpt_format}' is not supported for load_model_weights. "
            f"Supported formats: 'torch_dist', 'fsdp_dtensor'"
        )
```

The hunk's trailing context lines are the start of the existing `def load_checkpoint(state: GlobalState, model: list[MegatronModule], ...)` definition.
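Putting the new public API together, a hedged usage sketch; the import path simply follows the file this diff modifies, and `model` stands for the usual `list[MegatronModule]` produced by the existing setup code:

```python
# Usage sketch; only load_model_weights and its keyword arguments come from this PR.
# `model` is assumed to be the list[MegatronModule] built by the normal training setup.
from megatron.bridge.training.checkpointing import load_model_weights

# Load the latest iteration found under a top-level checkpoint directory.
load_model_weights(model, "/checkpoints/my_run")

# Load a specific iteration; fully parallel load applies to torch_dist checkpoints only.
load_model_weights(model, "/checkpoints/my_run/iter_0000005", fully_parallel_load=True)

# Fetch a consolidated state dict (DTensors gathered to full tensors) for export.
state_dict = load_model_weights(model, "/checkpoints/my_run", return_state_dict=True)
```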
The second hunk replaces the inline DCP load in `_load_fsdp_dtensor_base_checkpoint` with the new shared helper:

```diff
@@ -2243,24 +2422,8 @@ def _load_fsdp_dtensor_base_checkpoint(
     state_dict = preprocess_fsdp_dtensor_state_dict(cfg, state_dict, model[0])
 
     checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
-    fs_storage_reader = torch.distributed.checkpoint.FileSystemReader(checkpoint_name)
-
-    # Configure partial loading based on strict_fsdp_dtensor_load setting
-    allow_partial_load = not getattr(ckpt_cfg, "strict_fsdp_dtensor_load", False)
-    if allow_partial_load:
-        state_dict_metadata = fs_storage_reader.read_metadata().state_dict_metadata
-        rank = torch.distributed.get_rank()
-        import time as _time
-
-        _time.sleep(rank * 0.001)  # Prevent log overlap across ranks
-        print_diff_in_state_dicts(state_dict_metadata, state_dict)
-
-    planner = torch.distributed.checkpoint.default_planner.DefaultLoadPlanner(allow_partial_load=allow_partial_load)
-    torch.distributed.checkpoint.load_state_dict(
-        state_dict=state_dict,
-        storage_reader=fs_storage_reader,
-        planner=planner,
-    )
+    strict = getattr(ckpt_cfg, "strict_fsdp_dtensor_load", False)
+    state_dict = _load_fsdp_dtensor_state_dict(checkpoint_name, state_dict, strict=strict)
 
     # Restore raw state dicts to maintain original structure for the rest of the load process
     if raw_optimizer_state_dict is not None:
```
Review comment:

Switch to `state_dict_for_save_checkpoint()` to match saved DTensor checkpoints. FSDP DTensor checkpoints are saved using `state_dict_for_save_checkpoint()` (via `_generate_model_state_dict()` when `ckpt_format == "fsdp_dtensor"`), but this load path uses `sharded_state_dict()` instead. This mismatch will cause key/type errors at load time.
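A hedged sketch of the change the reviewer describes, assuming the wrapped module exposes `state_dict_for_save_checkpoint()` as the comment states; this is not the reviewer's committable suggestion:

```diff
     # Build model sharded state dict and preprocess for FSDP DTensor format
-    model_state_dict = model_instance.sharded_state_dict()
+    model_state_dict = model_instance.state_dict_for_save_checkpoint()
     state_dict = {"model": model_state_dict}
     state_dict = preprocess_fsdp_dtensor_state_dict(None, state_dict, model_instance)
```

The intent is that the load path builds its placeholder state dict with the same method used at save time, so keys and value types line up with what `preprocess_fsdp_dtensor_state_dict` saw when the checkpoint was written.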