1 change: 1 addition & 0 deletions pyproject.toml
@@ -129,6 +129,7 @@ test = [
"pytest-runner>=6.0.1",
"pygithub",
"click",
"multi-storage-client",
]
dev = [
"pre-commit>=3.6.0",
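The new multi-storage-client entry in the test extras backs the MSC code paths added below. A minimal pytest guard for MSC-dependent tests, assuming the distribution imports as multistorageclient (an assumption, not taken from this diff):

# Hypothetical guard: skip tests that need the optional multi-storage-client
# package when it is not installed. The import name "multistorageclient" is
# an assumption, not confirmed by this diff.
import pytest

msc = pytest.importorskip("multistorageclient")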
217 changes: 34 additions & 183 deletions src/megatron/bridge/training/checkpointing.py
@@ -21,15 +21,13 @@
import sys
import threading
from enum import Enum, auto
from functools import lru_cache
from logging import getLogger
from pathlib import Path
from time import time
from typing import Any, Callable, Literal, Optional, Union

import numpy as np
import torch
import yaml
from megatron.core import dist_checkpointing, mpu, tensor_parallel
from megatron.core.dist_checkpointing.mapping import ShardedObject, ShardedStateDict
from megatron.core.dist_checkpointing.serialization import (
@@ -42,6 +40,7 @@
FullyParallelLoadStrategyWrapper,
FullyParallelSaveStrategyWrapper,
)
from megatron.core.msc_utils import MultiStorageClientFeature
from megatron.core.num_microbatches_calculator import update_num_microbatches
from megatron.core.optimizer import MegatronOptimizer
from megatron.core.rerun_state_machine import get_rerun_state_machine
@@ -52,10 +51,20 @@
from megatron.bridge.training.config import CheckpointConfig
from megatron.bridge.training.state import GlobalState, TrainState
from megatron.bridge.training.utils import wandb_utils
from megatron.bridge.training.utils.checkpoint_utils import (
checkpoint_exists,
ensure_directory_exists,
file_exists,
get_checkpoint_name,
get_checkpoint_run_config_filename,
get_checkpoint_tracker_filename,
get_checkpoint_train_state_filename,
read_run_config,
read_train_state,
)
from megatron.bridge.training.utils.log_utils import append_to_progress_log
from megatron.bridge.utils.common_utils import (
get_rank_safe,
get_world_size_safe,
is_last_rank,
print_rank_0,
unwrap_model,
@@ -78,9 +87,7 @@
except Exception:
has_nvidia_modelopt = False

TRAIN_STATE_FILE = "train_state.pt"
TRACKER_PREFIX = "latest"
CONFIG_FILE = "run_config.yaml"
_CHECKPOINT_VERSION = None

logger = getLogger(__name__)
@@ -114,40 +121,6 @@ def get_checkpoint_version() -> Optional[float]:
return _CHECKPOINT_VERSION


def ensure_directory_exists(filename: str, check_parent: bool = True) -> None:
"""Ensure that the directory for a given filename exists.

Args:
filename: The path whose directory should be checked/created.
check_parent: If True (default), checks the parent directory of the filename.
If False, treats the filename itself as the directory path.
"""
dirname = os.path.dirname(filename) if check_parent else filename
os.makedirs(dirname, exist_ok=True)


def get_checkpoint_name(checkpoints_path: str, iteration: int, release: bool = False) -> str:
"""Determine the directory name for a specific checkpoint.

Constructs the path based on iteration number or release flag.

Args:
checkpoints_path: Base directory where checkpoints are stored.
iteration: The training iteration number.
release: If True, uses 'release' as the directory name instead of iteration.

Returns:
The full path to the checkpoint directory.
"""
if release:
directory = "release"
else:
directory = "iter_{:07d}".format(iteration)

common_path = os.path.join(checkpoints_path, directory)
return common_path


def find_checkpoint_rank_0(checkpoints_path: str, iteration: int, release: bool = False) -> Optional[str]:
"""Find the checkpoint directory for a given iteration, assuming distributed checkpoints.

@@ -169,105 +142,6 @@ def find_checkpoint_rank_0(checkpoints_path: str, iteration: int, release: bool
return None


def get_checkpoint_train_state_filename(checkpoints_path: str, prefix: Optional[str] = None) -> str:
"""Get the filename for the train state tracker file.

This file typically stores metadata about the latest checkpoint, like the iteration number.

Args:
checkpoints_path: Base directory where checkpoints are stored.
prefix: Optional prefix (e.g., 'latest') to prepend to the filename.

Returns:
The full path to the train state tracker file.
"""
if prefix is None:
return os.path.join(checkpoints_path, TRAIN_STATE_FILE)
else:
return os.path.join(checkpoints_path, f"{prefix}_{TRAIN_STATE_FILE}")


def get_checkpoint_run_config_filename(checkpoints_path: str) -> str:
"""Get the filename for the run configuration file within a checkpoint directory.

Args:
checkpoints_path: Base directory where checkpoints are stored.

Returns:
The full path to the run configuration file (e.g., run_config.yaml).
"""
return os.path.join(checkpoints_path, CONFIG_FILE)


def get_checkpoint_tracker_filename(checkpoints_path: str) -> str:
"""Tracker file rescords the latest chckpoint during training to restart from.

Supports checkpoints produced by Megatron-LM.

Args:
checkpoints_path: Base directory where checkpoints are stored.

Returns:
The full path to the checkpoint tracker file (e.g., latest_checkpointed_iteration.txt).
"""
return os.path.join(checkpoints_path, "latest_checkpointed_iteration.txt")


def checkpoint_exists(checkpoints_path: Optional[str]) -> bool:
"""Check if a checkpoint directory exists.

Args:
checkpoints_path: Path to the potential checkpoint directory.

Returns:
True if the path exists, False otherwise.
"""
if checkpoints_path is None:
return False

train_state_filename = os.path.join(checkpoints_path, f"{TRACKER_PREFIX}_{TRAIN_STATE_FILE}")
if os.path.exists(train_state_filename):
return True

# Fallback to the Megatron-LM tracker file
path = get_checkpoint_tracker_filename(checkpoints_path)
return os.path.isfile(path)


@lru_cache()
def read_train_state(train_state_filename: str) -> TrainState:
"""Read the train state metadata from a YAML file (rank 0 only).

Reads the file on rank 0 and broadcasts the result to other ranks.

Args:
train_state_filename: Path to the train state file.

Returns:
An initialized TrainState object.
"""
state_obj = [None]
if get_rank_safe() == 0:
try:
state_dict = torch.load(train_state_filename, map_location="cpu")
ts = TrainState()
ts.load_state_dict(state_dict)
state_obj[0] = ts
except Exception as e:
error_msg = f"ERROR: Unable to load train state file {train_state_filename}: {e}"
sys.stderr.write(error_msg + "\n")
state_obj[0] = {"error": True, "msg": error_msg}

if torch.distributed.is_initialized():
print_rank_0(f"Broadcasting TrainState from rank 0 to all {get_world_size_safe()} ranks")
torch.distributed.broadcast_object_list(state_obj, src=0)

if isinstance(state_obj[0], dict) and state_obj[0].get("error", False):
raise RuntimeError(state_obj[0]["msg"])

return state_obj[0]


def read_metadata(tracker_filename: str) -> tuple[int, bool]:
"""Read the metadata from the Megatron-LM tracker file.

@@ -280,7 +154,13 @@ def read_metadata(tracker_filename: str) -> tuple[int, bool]:
iteration = 0
release = False

with open(tracker_filename, "r") as f:
if MultiStorageClientFeature.is_enabled():
msc = MultiStorageClientFeature.import_package()
open_file = msc.open
else:
open_file = open

with open_file(tracker_filename, "r") as f:
metastring = f.read().strip()
try:
iteration = int(metastring)
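The gated-open pattern above is the general recipe for reading small text files from either local disk or MSC-backed storage. A minimal sketch reusing only the calls shown in this hunk (the helper name is hypothetical):

# Sketch: MSC-aware text read, mirroring the gate in read_metadata above.
from megatron.core.msc_utils import MultiStorageClientFeature


def read_text_file(path: str) -> str:
    open_file = open
    if MultiStorageClientFeature.is_enabled():
        msc = MultiStorageClientFeature.import_package()
        open_file = msc.open  # used exactly like builtins.open in this diff
    with open_file(path, "r") as f:
        return f.read().strip()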
@@ -316,43 +196,6 @@ def read_metadata(tracker_filename: str) -> tuple[int, bool]:
return max_iter, release


@lru_cache()
def read_run_config(run_config_filename: str) -> dict[str, Any]:
"""Read the run configuration from a YAML file (rank 0 only).

Reads the file on rank 0 and broadcasts the result to other ranks.

Args:
run_config_filename: Path to the run config YAML file.

Returns:
A dictionary containing the run configuration.

Raises:
RuntimeError: If reading the config file fails on rank 0.
"""
config_obj = [None]

if get_rank_safe() == 0:
try:
with open(run_config_filename, "r") as f:
config_dict = yaml.safe_load(f)
config_obj[0] = config_dict
except Exception as e:
error_msg = f"ERROR: Unable to load config file {run_config_filename}: {e}"
sys.stderr.write(error_msg + "\n")
config_obj[0] = {"error": True, "msg": error_msg}

if torch.distributed.is_initialized():
print_rank_0(f"Broadcasting config from rank 0 to all {get_world_size_safe()} ranks")
torch.distributed.broadcast_object_list(config_obj, src=0)

if isinstance(config_obj[0], dict) and config_obj[0].get("error", False):
raise RuntimeError(config_obj[0]["msg"])

return config_obj[0]


def _extract_megatron_lm_args_from_state_dict(state_dict: dict[str, Any]) -> dict[str, Any]:
"""Extract and convert legacy Megatron-LM args from checkpoint state_dict to Megatron-Bridge config format.

@@ -735,8 +578,14 @@ def train_state_finalize_fn() -> None:
train_state_dict["floating_point_operations_so_far"] = torch.tensor(
num_floating_point_operations_so_far, dtype=torch.float32
)
torch.save(train_state_dict, train_state_local_filename)
shutil.copy(train_state_local_filename, train_state_global_filename)
if MultiStorageClientFeature.is_enabled():
msc = MultiStorageClientFeature.import_package()
msc.torch.save(train_state_dict, train_state_local_filename)
msc.torch.save(train_state_dict, train_state_global_filename)
else:
torch.save(train_state_dict, train_state_local_filename)
shutil.copy(train_state_local_filename, train_state_global_filename)

cfg.to_yaml(config_filename)

tp_rank = (tensor_rank if tensor_rank is not None else mpu.get_tensor_model_parallel_rank()) + 1
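shutil.copy only works between local filesystem paths, so when MSC is enabled the train state is saved twice instead of saved once and copied. A condensed sketch of that branch (the helper name is hypothetical):

# Sketch: save an object to a local and a global location, MSC-aware.
import shutil

import torch
from megatron.core.msc_utils import MultiStorageClientFeature


def save_local_and_global(obj, local_path: str, global_path: str) -> None:
    if MultiStorageClientFeature.is_enabled():
        msc = MultiStorageClientFeature.import_package()
        # Object storage has no local copy semantics; write both targets.
        msc.torch.save(obj, local_path)
        msc.torch.save(obj, global_path)
    else:
        torch.save(obj, local_path)
        shutil.copy(local_path, global_path)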
@@ -1128,7 +977,8 @@ def _load_checkpoint_from_path(
# If that fails, we are loading from a Megatron-LM checkpoint, so extract the corresponding values
# from args in the state_dict
run_config_filename = get_checkpoint_run_config_filename(checkpoint_name)
if os.path.exists(run_config_filename):

if file_exists(run_config_filename):
run_config = read_run_config(run_config_filename)
else:
# Fallback to legacy Megatron-LM args extraction
@@ -1274,7 +1124,8 @@ def _load_checkpoint_from_path(
# Try to read train_state.pt from checkpoint directory
# If it doesn't exist (checkpoint generated by Megatron-LM), create from available information
train_state_filename = get_checkpoint_train_state_filename(checkpoint_name)
if os.path.exists(train_state_filename):

if file_exists(train_state_filename):
state.train_state = read_train_state(train_state_filename)
else:
# Legacy Megatron-LM checkpoint - create TrainState from checkpoint iteration
@@ -1597,7 +1448,7 @@ def _get_non_persistent_iteration(
return -1
elif non_persistent_ckpt_type == "global":
train_state_filename = get_checkpoint_train_state_filename(non_persistent_global_dir, prefix=TRACKER_PREFIX)
if os.path.isfile(train_state_filename):
if file_exists(train_state_filename):
train_state = read_train_state(train_state_filename)
iteration = train_state.step
# if train_state.release:
@@ -1702,14 +1553,14 @@ def _load_base_checkpoint(
tracker_filename = "because load directory is not defined"
if load_dir is not None:
tracker_filename = get_checkpoint_train_state_filename(load_dir, prefix=TRACKER_PREFIX)
if os.path.isfile(tracker_filename):
if file_exists(tracker_filename):
train_state = read_train_state(tracker_filename)
iteration = train_state.step
# release = train_state.release
else:
# Fallback to legacy Megatron-LM tracker file format
legacy_tracker_filename = get_checkpoint_tracker_filename(load_dir)
if os.path.isfile(legacy_tracker_filename):
if file_exists(legacy_tracker_filename):
print_rank_0(f"Loading from legacy Megatron-LM checkpoint format: {legacy_tracker_filename}")
iteration, release = read_metadata(legacy_tracker_filename)
tracker_filename = legacy_tracker_filename # Update for error messages
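The path checks in this file now route through the helpers consolidated in megatron.bridge.training.utils.checkpoint_utils, so one MSC-aware implementation covers every call site. A usage sketch with the imported names (the checkpoint directory is hypothetical):

from megatron.bridge.training.utils.checkpoint_utils import (
    checkpoint_exists,
    file_exists,
    get_checkpoint_run_config_filename,
    get_checkpoint_train_state_filename,
    read_run_config,
    read_train_state,
)

ckpt_dir = "/results/my_run/checkpoints"  # hypothetical path
if checkpoint_exists(ckpt_dir):
    run_config_file = get_checkpoint_run_config_filename(ckpt_dir)
    if file_exists(run_config_file):
        run_config = read_run_config(run_config_file)
    train_state_file = get_checkpoint_train_state_filename(ckpt_dir)
    if file_exists(train_state_file):
        train_state = read_train_state(train_state_file)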
7 changes: 5 additions & 2 deletions src/megatron/bridge/training/model_load_save.py
@@ -31,6 +31,7 @@
from megatron.bridge.training.config import CheckpointConfig, ConfigContainer, LoggerConfig
from megatron.bridge.training.state import GlobalState
from megatron.bridge.training.tokenizers.tokenizer import MegatronTokenizer, build_tokenizer
from megatron.bridge.training.utils.checkpoint_utils import file_exists


logger = logging.getLogger(__name__)
@@ -133,7 +134,8 @@ def load_tokenizer(checkpoint_path: str) -> MegatronTokenizer:
from megatron.bridge.utils.instantiate_utils import instantiate

run_config_filename = get_checkpoint_run_config_filename(checkpoint_path)
if os.path.exists(run_config_filename):

if file_exists(run_config_filename):
run_config = read_run_config(run_config_filename)
mbridge_ckpt = True
else:
@@ -194,7 +196,8 @@ def load_megatron_model(
from megatron.bridge.utils.instantiate_utils import instantiate

run_config_filename = get_checkpoint_run_config_filename(checkpoint_path)
if os.path.exists(run_config_filename):

if file_exists(run_config_filename):
run_config = read_run_config(run_config_filename)
mbridge_ckpt = True
else:
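With file_exists handling both local and MSC-backed paths, load_tokenizer and load_megatron_model can detect a Megatron-Bridge run_config.yaml wherever the checkpoint lives. A usage sketch (the paths and the msc:// URL form are assumptions, not confirmed by this diff):

from megatron.bridge.training.model_load_save import load_tokenizer

# Local checkpoint directory (hypothetical path).
tokenizer = load_tokenizer("/results/my_run/checkpoints/iter_0001000")

# With MSC enabled, the same call may target object storage; the msc://
# URL form here is an assumption.
tokenizer = load_tokenizer("msc://my-profile/checkpoints/iter_0001000")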