Conversation


@zhenga1 (Contributor) commented Nov 12, 2025

Added a new runnable script, main_load, which can be referenced with uv pip.

Added support for new functions in worker.py, and provided an example implementation of a working conversion via a new script, skyrl-train/examples/gsm8k/convert_megatron_to_hf.sh.

To run the script, copy the original .sh file used for training and change the entrypoint from skyrl_train.entrypoints.main_base to skyrl_train.entrypoints.main_load.

Please comment and give feedback.


@gemini-code-assist bot left a comment


Code Review

This pull request adds functionality to convert Megatron checkpoints to the HuggingFace safetensors format. It introduces a new entrypoint main_load and an example script to perform this conversion. The changes look good overall, but I have a few suggestions to improve maintainability and code quality.

My main feedback is to refactor the new main_load.py entrypoint to avoid significant code duplication with the existing training entrypoint. I've also pointed out some leftover debugging code that should be removed, and made a few suggestions for improving code clarity and consistency in trainer.py and the new shell script. Addressing these points will make the new functionality more robust and easier to maintain.

Comment on lines 143 to 146
self.print("Saving checkpoint function began: ")
# import pdb
# pdb.set_trace()


critical

This appears to be leftover debugging code, including a print statement and commented-out pdb calls. This should be removed before merging.

Comment on lines 210 to 213
self.print("Loading checkpoint function began: ")
# import pdb
# pdb.set_trace()


critical

This appears to be leftover debugging code, including a print statement and commented-out pdb calls. This should be removed before merging.

Comment on lines 1 to 315
"""
Entrypoint for loading a trained checkpoint and saving the model in HuggingFace (safetensors) format.
"""

from ray.util.placement_group import placement_group, PlacementGroup

from transformers import AutoTokenizer, PreTrainedTokenizerBase
from skyrl_train.dataset import PromptDataset
from skyrl_train.utils import validate_cfg

from skyrl_train.trainer import RayPPOTrainer
from skyrl_train.inference_engines.inference_engine_client import InferenceEngineClient
from skyrl_train.inference_engines.remote_inference_engine import create_remote_inference_engines
from skyrl_train.utils.utils import initialize_ray, get_ray_pg_ready_with_timeout
from skyrl_train.utils.constants import SKYRL_RAY_PG_TIMEOUT_IN_S
from skyrl_train.generators.base import GeneratorInterface
from omegaconf import OmegaConf, DictConfig
from pathlib import Path
import ray

import os
import hydra
from loguru import logger
from skyrl_train.utils.tracking import Tracking
import multiprocessing as mp

# NOTE (sumanthrh): We use ray heavily and thus disable `fork` start method.
# forking within ray leads to undefined behaviour and often causes hard to debug
# memory leaks. See: https://docs.ray.io/en/latest/ray-core/patterns/fork-new-processes.html
# A common culprit is Pytorch dataloaders which use `fork` by default.
mp.set_start_method("spawn", force=True)

config_dir = str(Path(__file__).parent.parent / "config")
__all__ = ["BasePPOExp", "config_dir"]


def create_ray_wrapped_inference_engines_from_config(cfg: DictConfig, colocate_pg, tokenizer: PreTrainedTokenizerBase):
from skyrl_train.inference_engines.ray_wrapped_inference_engine import create_ray_wrapped_inference_engines

engine_kwargs = {
"num_inference_engines": cfg.generator.num_inference_engines,
"tensor_parallel_size": cfg.generator.inference_engine_tensor_parallel_size,
"pipeline_parallel_size": cfg.generator.inference_engine_pipeline_parallel_size,
"model_dtype": cfg.generator.model_dtype,
"pretrain": cfg.trainer.policy.model.path,
"seed": cfg.trainer.seed,
"vllm_v1_disable_multiproc": cfg.generator.vllm_v1_disable_multiproc,
"enable_prefix_caching": cfg.generator.enable_prefix_caching,
"enforce_eager": cfg.generator.enforce_eager,
"expert_parallel_size": cfg.generator.inference_engine_expert_parallel_size,
"data_parallel_size": cfg.generator.inference_engine_data_parallel_size,
"shared_pg": colocate_pg,
"gpu_memory_utilization": cfg.generator.gpu_memory_utilization,
"inference_engine_enable_sleep": cfg.trainer.placement.colocate_all,
"async_engine": cfg.generator.async_engine,
"max_num_batched_tokens": cfg.generator.max_num_batched_tokens,
"max_num_seqs": cfg.generator.max_num_seqs,
"tokenizer": tokenizer,
"backend": cfg.generator.backend,
"engine_init_kwargs": cfg.generator.engine_init_kwargs,
}

# Conditionally add LoRA parameters if LoRA is enabled
if cfg.trainer.policy.model.lora.rank > 0:
engine_kwargs["enable_lora"] = True
engine_kwargs["max_lora_rank"] = cfg.trainer.policy.model.lora.rank
engine_kwargs["sleep_level"] = 1
engine_kwargs["max_loras"] = 1

if (rope_scaling := cfg.generator.get("rope_scaling", None)) is not None:
engine_kwargs["rope_scaling"] = rope_scaling
if (rope_theta := cfg.generator.get("rope_theta", None)) is not None:
engine_kwargs["rope_theta"] = rope_theta

return create_ray_wrapped_inference_engines(**engine_kwargs)


def create_remote_inference_engines_from_config(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase):
# TODO(tgriggs): We may want a separate config for the model name in case it's different from the name used in the OpenAI API
return create_remote_inference_engines(
urls=cfg.generator.remote_inference_engine_urls,
model_name=cfg.trainer.policy.model.path,
engine_backend=cfg.generator.backend,
tokenizer=tokenizer,
tensor_parallel_size=cfg.generator.inference_engine_tensor_parallel_size,
pipeline_parallel_size=cfg.generator.inference_engine_pipeline_parallel_size,
data_parallel_size=cfg.generator.inference_engine_data_parallel_size,
expert_parallel_size=cfg.generator.inference_engine_expert_parallel_size,
)


class BasePPOExp:
def __init__(self, cfg: DictConfig):
"""
Initializes a PPO experiment.
The `cfg` passed here will be the final config from Hydra, including CLI overrides.
"""
self.cfg = cfg
self.tokenizer = self.get_tokenizer()
self.train_dataset = self.get_train_dataset()
self.eval_dataset = self.get_eval_dataset()
self.colocate_pg = self.get_colocate_pg()

@staticmethod
def get_cfg_as_str(dict_cfg: DictConfig) -> str:
return OmegaConf.to_yaml(dict_cfg)

def get_tokenizer(self, padding_side="left"):
"""Initializes a tokenizer for the given model."""
tokenizer = AutoTokenizer.from_pretrained(
self.cfg.trainer.policy.model.path,
trust_remote_code=True,
use_fast=not self.cfg.trainer.disable_fast_tokenizer,
)
tokenizer.padding_side = padding_side
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
return tokenizer

def get_train_dataset(self):
"""Initializes the training dataset.
Returns:
PromptDataset: The training dataset.
"""
prompts_dataset = PromptDataset(
datasets=self.cfg.data.train_data,
tokenizer=self.tokenizer,
max_prompt_length=self.cfg.trainer.max_prompt_length,
num_workers=8,
)
# make sure the dataset is large enough to train on
assert (
len(prompts_dataset) >= self.cfg.trainer.train_batch_size
), f"dataset should be atleast as large as `train_batch_size` {self.cfg.trainer.train_batch_size}, got size {len(prompts_dataset)}"
return prompts_dataset

def get_eval_dataset(self):
"""Initializes the evaluation dataset.
Returns:
PromptDataset: The evaluation dataset.
"""
if self.cfg.trainer.eval_interval > 0 and self.cfg.data.val_data:
prompts_dataset = PromptDataset(
datasets=self.cfg.data.val_data,
tokenizer=self.tokenizer,
max_prompt_length=self.cfg.trainer.max_prompt_length,
num_workers=8,
)
return prompts_dataset
return None

def get_colocate_pg(self, timeout: int = SKYRL_RAY_PG_TIMEOUT_IN_S) -> PlacementGroup:
"""Initializes a placement group for colocated training.
A single placement group that packs all the inference engines together is created.
Args:
timeout (int): The timeout for the placement group to be ready.
Returns:
PlacementGroup: The placement group for colocated training.
"""
if self.cfg.trainer.placement.colocate_all:
pg = placement_group(
[{"GPU": 1, "CPU": 1}]
* self.cfg.generator.num_inference_engines
* self.cfg.generator.inference_engine_tensor_parallel_size
* self.cfg.generator.inference_engine_pipeline_parallel_size
* self.cfg.generator.inference_engine_data_parallel_size,
strategy="PACK",
)
get_ray_pg_ready_with_timeout(pg, timeout=timeout)
return pg
else:
return None

def get_generator(self, cfg, tokenizer, inference_engine_client):
"""Initializes the generator.
Returns:
GeneratorInterface: The generator.
"""
from skyrl_train.generators.skyrl_gym_generator import SkyRLGymGenerator

return SkyRLGymGenerator(
generator_cfg=cfg.generator,
skyrl_gym_cfg=cfg.environment.skyrl_gym,
inference_engine_client=inference_engine_client,
tokenizer=tokenizer,
model_name=cfg.trainer.policy.model.path,
)

def get_trainer(
self,
cfg,
tracker,
tokenizer,
train_dataset,
eval_dataset,
inference_engine_client,
generator: GeneratorInterface,
colocate_pg,
):
"""Initializes the trainer.
Returns:
RayPPOTrainer: The trainer.
"""
return RayPPOTrainer(
cfg=cfg,
tracker=tracker,
tokenizer=tokenizer,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
inference_engine_client=inference_engine_client,
generator=generator,
colocate_pg=colocate_pg,
)

def get_tracker(self):
"""Initializes the tracker for experiment tracking.
Returns:
Tracking: The tracker.
"""
return Tracking(
project_name=self.cfg.trainer.project_name,
experiment_name=self.cfg.trainer.run_name,
backends=self.cfg.trainer.logger,
config=self.cfg,
)

def _setup_trainer(self):
"""Setup and return the trainer.
Instantiates the trainer and all the associated models for training.
Returns:
RayPPOTrainer: The trainer.
"""
logger.info(self.get_cfg_as_str(self.cfg))
os.makedirs(self.cfg.trainer.export_path, exist_ok=True)
os.makedirs(self.cfg.trainer.ckpt_path, exist_ok=True)

if self.cfg.trainer.strategy == "deepspeed":
from skyrl_train.workers.deepspeed.deepspeed_worker import (
PolicyWorker,
CriticWorker,
RefWorker,
)
elif self.cfg.trainer.strategy in ("fsdp", "fsdp2"):
from skyrl_train.workers.fsdp.fsdp_worker import PolicyWorker, CriticWorker, RefWorker
elif self.cfg.trainer.strategy == "megatron":
from skyrl_train.workers.megatron.megatron_worker import PolicyWorker, CriticWorker, RefWorker
else:
raise ValueError(f"Unknown strategy type: {self.cfg.trainer.strategy}")

# NOTE (sumanthrh): Instantiate tracker before trainer init.
# We have custom validation before this step to give better error messages.
tracker = self.get_tracker()

tokenizer = self.tokenizer
if self.cfg.generator.run_engines_locally:
inference_engines = create_ray_wrapped_inference_engines_from_config(self.cfg, self.colocate_pg, tokenizer)
else:
inference_engines = create_remote_inference_engines_from_config(self.cfg, tokenizer)

inference_engine_client = InferenceEngineClient(inference_engines, tokenizer, self.cfg)

generator: GeneratorInterface = self.get_generator(self.cfg, tokenizer, inference_engine_client)

trainer = self.get_trainer(
cfg=self.cfg,
tracker=tracker,
tokenizer=tokenizer,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
inference_engine_client=inference_engine_client,
generator=generator,
colocate_pg=self.colocate_pg,
)

# Build the models
trainer.build_models(PolicyWorker, CriticWorker, RefWorker)
return trainer

def run(self):
trainer = self._setup_trainer()
# Load the checkpoint and save the model in HF safetensors format
trainer.get_model()



@ray.remote(num_cpus=1)
def skyrl_entrypoint(cfg: DictConfig):
# make sure that the training loop is not run on the head node.
exp = BasePPOExp(cfg)
exp.run()


@hydra.main(config_path=config_dir, config_name="ppo_base_config", version_base=None)
def main(cfg: DictConfig) -> None:
# validate the arguments
validate_cfg(cfg)

initialize_ray(cfg)
ray.get(skyrl_entrypoint.remote(cfg))


if __name__ == "__main__":
main()

high

This new file main_load.py appears to be almost a complete duplicate of the existing training entrypoint. This introduces significant code duplication, which will make future maintenance harder.

Instead of creating a new entrypoint file, consider adding a new mode or command-line argument to the existing entrypoint. For example, you could add a --task argument that could be train or convert_checkpoint. This would allow you to reuse the majority of the setup code and just change the final action (e.g., calling trainer.train() vs. a renamed trainer.get_model()).
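For illustration, a minimal sketch of that approach inside the existing BasePPOExp.run(), assuming a hypothetical trainer.task config key (the key name and its values are placeholders, not existing options):

    def run(self):
        trainer = self._setup_trainer()
        # Hypothetical `trainer.task` key: "train" (default) or "convert_checkpoint".
        task = self.cfg.trainer.get("task", "train")
        if task == "train":
            # Existing behaviour of main_base.py: run the training loop.
            trainer.train()
        elif task == "convert_checkpoint":
            # Behaviour added in this PR: load the checkpoint and export HF-format weights.
            trainer.get_model()
        else:
            raise ValueError(f"Unknown trainer.task: {task!r}")

The conversion could then be triggered from the same training script with a single override (e.g. trainer.task=convert_checkpoint) instead of switching entrypoint modules.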

Comment on lines 62 to 64
trainer.resume_mode=null \
trainer.ckpt_path="$HOME/ckpts/gsm8k_1.5B_ckpt" \
trainer.resume_mode="latest"\

medium

The trainer.resume_mode is set to null on line 62 and then immediately overridden to "latest" on line 64. The first setting is redundant and can be removed to avoid confusion.

Suggested change
-trainer.resume_mode=null \
-trainer.ckpt_path="$HOME/ckpts/gsm8k_1.5B_ckpt" \
-trainer.resume_mode="latest"\
+trainer.ckpt_path="$HOME/ckpts/gsm8k_1.5B_ckpt" \
+trainer.resume_mode="latest"\


export_root = getattr(self.cfg.trainer, "export_path", None)
if not export_root:
import datetime

medium

The import datetime statement is inside the save_to_hf method. According to Python's style guide (PEP 8), imports should usually be at the top of the file. Please move this import to the top-level of the module for better readability and to avoid potential issues.
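For example, a small sketch of the suggested layout (the timestamp construction shown here is illustrative, not taken from the PR):

import datetime  # module level, alongside the other imports

...

        export_root = getattr(self.cfg.trainer, "export_path", None)
        if not export_root:
            # datetime is now available without a local import
            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")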

Comment on lines 159 to 160
home_dir = os.path.expanduser("~")
export_root = os.path.join(home_dir, "exports", timestamp)

medium

This section uses os.path.expanduser and os.path.join for path manipulation. Other parts of the codebase, like skyrl-train/skyrl_train/entrypoints/main_load.py, use pathlib.Path. For consistency and better readability, it would be better to use pathlib.Path here as well. The Path object is already imported in this file.

Suggested change
-home_dir = os.path.expanduser("~")
-export_root = os.path.join(home_dir, "exports", timestamp)
+home_dir = Path.home()
+export_root = home_dir / "exports" / timestamp

finally:
model.offload_to_cpu()

def get_model(self):

medium

The method name get_model is a bit misleading. It doesn't just "get" a model; it performs a sequence of operations: initializing weights, loading a checkpoint, and then saving the models in Hugging Face format. A more descriptive name like load_checkpoint_and_save_to_hf would better reflect its purpose and improve code clarity. Remember to update the call site in skyrl-train/skyrl_train/entrypoints/main_load.py as well.

Suggested change
-def get_model(self):
+def load_checkpoint_and_save_to_hf(self):
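The call site in skyrl-train/skyrl_train/entrypoints/main_load.py would then change accordingly (a sketch of the updated run method):

    def run(self):
        trainer = self._setup_trainer()
        # Load the Megatron checkpoint and export HF-format (safetensors) weights
        trainer.load_checkpoint_and_save_to_hf()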


@zhenga1 (Author) commented Nov 14, 2025

Fixed the other suggestions. I have not changed the entrypoint main_base.py, since there may be many training scripts that already reference main_base.py. In other words, main_load.py and main_base.py both still exist.
