17 changes: 3 additions & 14 deletions examples/conversion/compare_text_generation.py
@@ -398,21 +398,10 @@ def megatron_generate_from_checkpoint(
         etp: Expert tensor parallelism size override. (default: 1)
     """
     from megatron.bridge.training.model_load_save import build_and_load_model, load_model_config, load_tokenizer
+    from megatron.bridge.training.utils.checkpoint_utils import resolve_checkpoint_path

-    checkpoint_path = Path(megatron_path)
-    # Check for iter_* folders
-    iter_folders = [f for f in checkpoint_path.iterdir() if f.is_dir() and f.name.startswith("iter_")]
-    if iter_folders:
-        # Find the folder with the largest iteration number
-        def get_iter_number(folder_name):
-            try:
-                return int(folder_name.replace("iter_", ""))
-            except ValueError:
-                return -1  # Invalid format, put at the end
-
-        latest_iter = max(iter_folders, key=lambda f: get_iter_number(f.name))
-        checkpoint_path = checkpoint_path / latest_iter.name
-    # else: checkpoint_path remains as the input path (no iter folders found), and we assume it is a dist ckpt
+    # Resolve to specific iteration directory (handles both top-level and iter_* paths)
+    checkpoint_path = resolve_checkpoint_path(megatron_path)

     print_rank_0(f"Loading Megatron model from checkpoint at {checkpoint_path}")

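For reviewers skimming the diff: resolve_checkpoint_path replaces the inline logic removed above. The following is a rough sketch of its presumed behavior, reconstructed from that removed code; the actual helper in megatron.bridge.training.utils.checkpoint_utils may differ (for example by also consulting the checkpoint tracker file), so treat this as an assumption, not the shipped implementation.

# Hedged sketch only; approximates the observable behavior of the helper based on
# the code removed above, not its real implementation.
from pathlib import Path


def resolve_checkpoint_path_sketch(path) -> Path:
    """Return the latest iter_* subdirectory for a top-level checkpoint dir;
    otherwise assume `path` already points at a distributed checkpoint."""
    checkpoint_path = Path(path)
    iter_dirs = [d for d in checkpoint_path.iterdir() if d.is_dir() and d.name.startswith("iter_")]
    if not iter_dirs:
        return checkpoint_path

    def iter_number(name: str) -> int:
        try:
            return int(name.removeprefix("iter_"))
        except ValueError:
            return -1  # malformed folder names sort last

    return max(iter_dirs, key=lambda d: iter_number(d.name))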
14 changes: 7 additions & 7 deletions examples/conversion/convert_checkpoints.py
@@ -153,19 +153,19 @@ def export_megatron_to_hf(
"""
print(f"🔄 Starting export: {megatron_path} -> {hf_path}")

from megatron.bridge.training.utils.checkpoint_utils import resolve_checkpoint_path

# Validate megatron checkpoint exists
checkpoint_path = validate_path(megatron_path, must_exist=True)
print(f"📂 Found Megatron checkpoint: {checkpoint_path}")

# Resolve to specific iteration directory
resolved_path = Path(resolve_checkpoint_path(str(checkpoint_path)))

# Look for configuration files to determine the model type
config_files = list(checkpoint_path.glob("**/run_config.yaml"))
config_files = list(resolved_path.glob("run_config.yaml"))
if not config_files:
# Look in iter_ subdirectories
iter_dirs = [d for d in checkpoint_path.iterdir() if d.is_dir() and d.name.startswith("iter_")]
if iter_dirs:
# Use the latest iteration
latest_iter = max(iter_dirs, key=lambda d: int(d.name.replace("iter_", "")))
config_files = list(latest_iter.glob("run_config.yaml"))
config_files = list(checkpoint_path.glob("**/run_config.yaml"))

if not config_files:
raise FileNotFoundError(
21 changes: 4 additions & 17 deletions src/megatron/bridge/models/conversion/auto_bridge.py
@@ -654,30 +654,17 @@ def load_megatron_model(
"""
try:
from megatron.bridge.training.model_load_save import load_megatron_model
from megatron.bridge.training.utils.checkpoint_utils import resolve_checkpoint_path
except ImportError:
raise ImportError("megatron.bridge.training is not available.")

checkpoint_path = Path(path)

# Check for iter_* folders
iter_folders = [f for f in checkpoint_path.iterdir() if f.is_dir() and f.name.startswith("iter_")]

if iter_folders:
# Find the folder with the largest iteration number
def get_iter_number(folder_name):
try:
return int(folder_name.replace("iter_", ""))
except ValueError:
return -1 # Invalid format, put at the end

latest_iter = max(iter_folders, key=lambda f: get_iter_number(f.name))
checkpoint_path = checkpoint_path / latest_iter.name
# else: checkpoint_path remains as the input path (no iter folders found)
# Resolve to specific iteration (handles both top-level and iter_* paths)
resolved_path = resolve_checkpoint_path(str(path))

skip_temp_dist_context = dist.is_available() and dist.is_initialized()
# Load the state dict
model = load_megatron_model(
str(checkpoint_path),
resolved_path,
use_cpu_init=(skip_temp_dist_context and dist.get_backend() == "gloo"),
skip_temp_dist_context=skip_temp_dist_context,
mp_overrides=mp_overrides,
199 changes: 181 additions & 18 deletions src/megatron/bridge/training/checkpointing.py
@@ -73,6 +73,7 @@
     get_checkpoint_train_state_filename,
     read_run_config,
     read_train_state,
+    resolve_checkpoint_path,
 )
 from megatron.bridge.training.utils.log_utils import append_to_progress_log
 from megatron.bridge.training.utils.pg_utils import get_pg_collection
@@ -1298,6 +1299,184 @@ def _load_model_weights_from_checkpoint(
         torch.distributed.barrier()


+def _load_fsdp_dtensor_state_dict(
+    checkpoint_path: str,
+    state_dict: dict[str, Any],
+    strict: bool = False,
+) -> dict[str, Any]:
+    """Load a state dict from an FSDP DTensor checkpoint using PyTorch DCP.
+
+    Args:
+        checkpoint_path: Full path to the checkpoint directory.
+        state_dict: Pre-built state dict with DTensor placeholders to load into.
+        strict: If True, require exact key match. If False, allow partial loading.
+
+    Returns:
+        The loaded state dict (modified in-place).
+    """
+    fs_storage_reader = torch.distributed.checkpoint.FileSystemReader(checkpoint_path)
+
+    if not strict:
+        state_dict_metadata = fs_storage_reader.read_metadata().state_dict_metadata
+        print_diff_in_state_dicts(state_dict_metadata, state_dict)
+
+    planner = torch.distributed.checkpoint.default_planner.DefaultLoadPlanner(allow_partial_load=not strict)
+    torch.distributed.checkpoint.load_state_dict(
+        state_dict=state_dict,
+        storage_reader=fs_storage_reader,
+        planner=planner,
+    )
+    return state_dict
+
+
+def _convert_dtensor_state_dict_to_full(state_dict: StateDict) -> StateDict:
+    """Convert a state dict containing DTensors to one with full (unsharded) tensors.
+
+    Args:
+        state_dict: State dict potentially containing DTensor values.
+
+    Returns:
+        State dict with all DTensors converted to regular tensors via full_tensor().
+    """
+    try:
+        from torch.distributed.tensor import DTensor
+    except ImportError:
+        # No DTensor available, return as-is
+        return state_dict
+
+    def _convert_value(value: Any) -> Any:
+        if isinstance(value, DTensor):
+            return value.full_tensor()
+        elif isinstance(value, dict):
+            return {k: _convert_value(v) for k, v in value.items()}
+        elif isinstance(value, list):
+            return [_convert_value(v) for v in value]
+        return value
+
+    return _convert_value(state_dict)
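Reviewer-style illustration, not part of this PR: a single-process smoke test of the recursive conversion above. It assumes a recent PyTorch that exposes torch.distributed.tensor and that _convert_dtensor_state_dict_to_full from this diff has been copied into scope; the port number and tensor shapes are arbitrary.

# Minimal world_size=1 check of the DTensor -> full-tensor conversion (assumptions noted above).
import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard, distribute_tensor

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)
mesh = init_device_mesh("cpu", (1,))

# Nested structure mixing DTensors, plain tensors, and a list, mirroring a model state dict.
sharded = distribute_tensor(torch.arange(8, dtype=torch.float32), mesh, [Shard(0)])
state = {"model": {"weight": sharded, "bias": torch.zeros(2)}, "extra": [sharded]}

full = _convert_dtensor_state_dict_to_full(state)
assert not isinstance(full["model"]["weight"], DTensor)         # DTensor gathered
assert torch.equal(full["model"]["weight"], torch.arange(8.0))  # values preserved
assert torch.equal(full["model"]["bias"], torch.zeros(2))       # plain tensors untouched

dist.destroy_process_group()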


+def _load_model_weights_fsdp_dtensor(
+    model: list[MegatronModule],
+    checkpoint_path: str,
+    strict: bool = True,
+    return_state_dict: bool = False,
+) -> Optional[StateDict]:
+    """Load model weights from an FSDP DTensor checkpoint.
+
+    This is a simplified loading path for FSDP DTensor checkpoints that only loads
+    model weights without optimizer state, RNG state, or iteration tracking.
+
+    Args:
+        model: The model(s) to load weights into.
+        checkpoint_path: Full path to the iteration checkpoint directory.
+        strict: Whether to enforce strict state dict loading.
+        return_state_dict: If True, return the state dict instead of loading into model.
+
+    Returns:
+        If return_state_dict is True, returns the model state dict with full tensors.
+        Otherwise returns None.
+    """
+    if not HAVE_MEGATRON_FSDP:
+        raise RuntimeError(
+            "Megatron FSDP is required but not available for loading FSDP DTensor checkpoints. "
+            "Please install megatron-core with FSDP support."
+        )
+
+    model_list = unwrap_model(model)
+    model_instance = model_list[0]
+
+    # Build model sharded state dict and preprocess for FSDP DTensor format
+    model_state_dict = model_instance.sharded_state_dict()
+    state_dict = {"model": model_state_dict}
+    state_dict = preprocess_fsdp_dtensor_state_dict(None, state_dict, model_instance)

Comment on lines +1386 to +1393

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, check the file exists and get size
wc -l src/megatron/bridge/training/checkpointing.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 121


🏁 Script executed:

# Get the context around lines 1386-1393
sed -n '1370,1410p' src/megatron/bridge/training/checkpointing.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 1592


🏁 Script executed:

# Find the function/method containing these lines
sed -n '1350,1395p' src/megatron/bridge/training/checkpointing.py | cat -n

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 2157


🏁 Script executed:

# Search for DTensor-related code in checkpointing.py
grep -n -i "dtensor" src/megatron/bridge/training/checkpointing.py | head -20

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 1391


🏁 Script executed:

# Search for sharded_state_dict usage
rg "sharded_state_dict" -n src/megatron/bridge/training/checkpointing.py

# Search for state_dict_for_save_checkpoint usage
rg "state_dict_for_save_checkpoint" -n src/megatron/bridge/training/checkpointing.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 2868


🏁 Script executed:

# Find where these methods are defined
fd -t f "*.py" src/megatron/bridge | xargs grep -l "def sharded_state_dict\|def state_dict_for_save_checkpoint" | head -10

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 354


🏁 Script executed:

# Get broader context to see the save path logic
sed -n '1080,1120p' src/megatron/bridge/training/checkpointing.py | cat -n

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 1923


🏁 Script executed:

# Find the function containing lines 1080-1120
sed -n '1050,1080p' src/megatron/bridge/training/checkpointing.py | tail -30

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 1473


🏁 Script executed:

# Get more context showing the condition/format check before lines 1100-1106
sed -n '1090,1115p' src/megatron/bridge/training/checkpointing.py | cat -n

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 1159


🏁 Script executed:

# Search for functions that save FSDP DTensor checkpoints
rg -n "def.*save.*fsdp.*dtensor|def.*fsdp.*dtensor.*save" -i src/megatron/bridge/training/checkpointing.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 53


🏁 Script executed:

# Look at the preprocess_fsdp_dtensor_state_dict function definition
sed -n '1178,1230p' src/megatron/bridge/training/checkpointing.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 1995


🏁 Script executed:

# Get the function name and full context around lines 1098-1106
sed -n '1050,1065p' src/megatron/bridge/training/checkpointing.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 1072


🏁 Script executed:

# Check what function signature and docstring contains this save logic
sed -n '1030,1060p' src/megatron/bridge/training/checkpointing.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 1769


🏁 Script executed:

# Search for when _generate_model_state_dict is called
rg -n "_generate_model_state_dict" src/megatron/bridge/training/checkpointing.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 332


🏁 Script executed:

# Check the context around those calls to see if they specify fsdp_dtensor format
sed -n '625,645p' src/megatron/bridge/training/checkpointing.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 1246


🏁 Script executed:

# Check the context around line 1272
sed -n '1260,1280p' src/megatron/bridge/training/checkpointing.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 1077


🏁 Script executed:

# Look at the save flow for FSDP DTensor to confirm state_dict_for_save_checkpoint is used
sed -n '1140,1160p' src/megatron/bridge/training/checkpointing.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 1033


Switch to state_dict_for_save_checkpoint() to match saved DTensor checkpoints.

FSDP DTensor checkpoints are saved using state_dict_for_save_checkpoint() (via _generate_model_state_dict() when ckpt_format == "fsdp_dtensor"), but this load path uses sharded_state_dict() instead. This mismatch will cause key/type errors at load time.

🔧 Fix
-    model_state_dict = model_instance.sharded_state_dict()
+    model_state_dict = model_instance.state_dict_for_save_checkpoint()
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
     model_list = unwrap_model(model)
     model_instance = model_list[0]

     # Build model sharded state dict and preprocess for FSDP DTensor format
-    model_state_dict = model_instance.sharded_state_dict()
+    model_state_dict = model_instance.state_dict_for_save_checkpoint()
     state_dict = {"model": model_state_dict}
     state_dict = preprocess_fsdp_dtensor_state_dict(None, state_dict, model_instance)
🤖 Prompt for AI Agents
In `@src/megatron/bridge/training/checkpointing.py` around lines 1386 - 1393, The
code builds a model state dict using model_instance.sharded_state_dict(), but
DTensor checkpoints are saved with state_dict_for_save_checkpoint() so loading
will mismatch; replace the call to model_instance.sharded_state_dict() with
model_instance.state_dict_for_save_checkpoint() (or use the same helper path
used by _generate_model_state_dict when ckpt_format == "fsdp_dtensor") and then
pass that result into preprocess_fsdp_dtensor_state_dict(...) so the produced
state_dict matches the saved DTensor checkpoint format; keep references to
unwrap_model and model_instance so the change is localized to where
model_state_dict is constructed.

+    # Load using PyTorch DCP
+    state_dict = _load_fsdp_dtensor_state_dict(checkpoint_path, state_dict)
+
+    if return_state_dict:
+        # Convert DTensors to full tensors for export
+        return _convert_dtensor_state_dict_to_full(state_dict)
Comment on lines +1394 to +1399

⚠️ Potential issue | 🟠 Major

strict is ignored for DTensor loads.

_load_model_weights_fsdp_dtensor accepts strict, but the call to _load_fsdp_dtensor_state_dict doesn’t pass it, so partial-load is always allowed. Forward the flag to honor the API contract.

🐛 Suggested fix
-    state_dict = _load_fsdp_dtensor_state_dict(checkpoint_path, state_dict)
+    state_dict = _load_fsdp_dtensor_state_dict(checkpoint_path, state_dict, strict=strict)
🤖 Prompt for AI Agents
In `@src/megatron/bridge/training/checkpointing.py` around lines 1394 - 1399, The
call to _load_fsdp_dtensor_state_dict from _load_model_weights_fsdp_dtensor
ignores the strict parameter, so partial-load behavior is always allowed;
forward the strict argument when invoking _load_fsdp_dtensor_state_dict (and any
other internal DTensor-loading helpers called in the same flow) so that the
strict flag passed into _load_model_weights_fsdp_dtensor is honored, and ensure
the return_state_dict branch that converts DTensors to full tensors still
receives the state_dict produced with the strict semantics.


+    # Load the state dict into the model
+    _load_model_state_dict(model_instance, state_dict["model"], strict)
+
+    if torch.distributed.is_initialized():
+        torch.distributed.barrier()


+def load_model_weights(
+    model: list[MegatronModule],
+    checkpoint_path: str,
+    *,
+    fully_parallel_load: bool = False,
+    strict: bool = True,
+    return_state_dict: bool = False,
+) -> Optional[StateDict]:
+    """Load only model weights from a checkpoint.
+
+    Simple API for loading pretrained model weights without optimizer state,
+    RNG state, or iteration tracking. Supports both ``torch_dist`` and ``fsdp_dtensor``
+    checkpoint formats.
+
+    This function automatically:
+    - Resolves top-level checkpoint directories to the latest iteration
+    - Detects the checkpoint format
+    - Loads only model weights (no optimizer, RNG, or training state)
+
+    Args:
+        model: The model(s) to load weights into.
+        checkpoint_path: Path to checkpoint. Can be either:
+            - A top-level checkpoint directory (loads latest iteration)
+            - A specific iteration directory (e.g., ``/path/to/checkpoint/iter_0000005``)
+        fully_parallel_load: Apply full load parallelization across data parallel ranks.
+            Only supported for ``torch_dist`` format; ignored for other formats.
+        strict: Whether to enforce strict state dict loading.
+        return_state_dict: If True, return the state dict instead of loading into model.
+
+    Returns:
+        If return_state_dict is True, returns the model state dict.
+        Otherwise returns None.
+
+    Raises:
+        FileNotFoundError: If no valid checkpoint is found at the path.
+        NotImplementedError: If the checkpoint format is not supported.
+
+    Example:
+        >>> # Load from a specific iteration
+        >>> load_model_weights(model, "/checkpoints/iter_0000005")
+
+        >>> # Load latest checkpoint from a directory
+        >>> load_model_weights(model, "/checkpoints")
+
+        >>> # Get state dict for export
+        >>> state_dict = load_model_weights(model, "/checkpoints", return_state_dict=True)
+    """
+    # Resolve to specific iteration directory
+    resolved_path = resolve_checkpoint_path(checkpoint_path)
+    ckpt_format = _get_checkpoint_format(resolved_path)
+
+    if ckpt_format == "torch_dist":
+        return _load_model_weights_from_checkpoint(
+            resolved_path,
+            model,
+            fully_parallel_load=fully_parallel_load,
+            strict=strict,
+            return_state_dict=return_state_dict,
+        )
+    elif ckpt_format == "fsdp_dtensor":
+        if fully_parallel_load:
+            print_rank_0("Warning: fully_parallel_load is not supported for fsdp_dtensor format, ignoring")
+        return _load_model_weights_fsdp_dtensor(
+            model, resolved_path, strict=strict, return_state_dict=return_state_dict
+        )
+    else:
+        raise NotImplementedError(
+            f"Checkpoint format '{ckpt_format}' is not supported for load_model_weights. "
+            f"Supported formats: 'torch_dist', 'fsdp_dtensor'"
+        )


 def load_checkpoint(
     state: GlobalState,
     model: list[MegatronModule],
@@ -2243,24 +2422,8 @@ def _load_fsdp_dtensor_base_checkpoint(
     state_dict = preprocess_fsdp_dtensor_state_dict(cfg, state_dict, model[0])

     checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
-    fs_storage_reader = torch.distributed.checkpoint.FileSystemReader(checkpoint_name)
-
-    # Configure partial loading based on strict_fsdp_dtensor_load setting
-    allow_partial_load = not getattr(ckpt_cfg, "strict_fsdp_dtensor_load", False)
-    if allow_partial_load:
-        state_dict_metadata = fs_storage_reader.read_metadata().state_dict_metadata
-        rank = torch.distributed.get_rank()
-        import time as _time
-
-        _time.sleep(rank * 0.001)  # Prevent log overlap across ranks
-        print_diff_in_state_dicts(state_dict_metadata, state_dict)
-
-    planner = torch.distributed.checkpoint.default_planner.DefaultLoadPlanner(allow_partial_load=allow_partial_load)
-    torch.distributed.checkpoint.load_state_dict(
-        state_dict=state_dict,
-        storage_reader=fs_storage_reader,
-        planner=planner,
-    )
+    strict = getattr(ckpt_cfg, "strict_fsdp_dtensor_load", False)
+    state_dict = _load_fsdp_dtensor_state_dict(checkpoint_name, state_dict, strict=strict)

     # Restore raw state dicts to maintain original structure for the rest of the load process
     if raw_optimizer_state_dict is not None:
14 changes: 4 additions & 10 deletions src/megatron/bridge/training/model_load_save.py
@@ -265,9 +265,7 @@ def build_and_load_model(
         The model instance with loaded weights if return_state_dict is False,
         otherwise returns a dictionary containing the full, unsharded model state_dict.
     """
-    from megatron.bridge.training.checkpointing import (
-        _load_model_weights_from_checkpoint,
-    )
+    from megatron.bridge.training.checkpointing import load_model_weights
     from megatron.bridge.training.mlm_compat.arguments import _tokenizer_config_from_args
     from megatron.bridge.training.mlm_compat.model import _get_model, _gpt_provider, _mamba_provider
     from megatron.bridge.training.post_training.checkpointing import has_modelopt_state
@@ -324,15 +322,11 @@ def _load_checkpoint():

         load_modelopt_state(model, checkpoint_path)

-        maybe_state_dict = _load_model_weights_from_checkpoint(
-            checkpoint_path, model, return_state_dict=return_state_dict
-        )
-
+        result = load_model_weights(model, checkpoint_path, return_state_dict=return_state_dict)
         if return_state_dict:
             del model
-            return maybe_state_dict
-        else:
-            return model
+            return result
+        return model

     if skip_temp_dist_context:
         return _load_checkpoint()