diff --git a/skyrl-train/examples/gsm8k/convert_megatron_to_hf.sh b/skyrl-train/examples/gsm8k/convert_megatron_to_hf.sh
new file mode 100644
index 000000000..d15041aef
--- /dev/null
+++ b/skyrl-train/examples/gsm8k/convert_megatron_to_hf.sh
@@ -0,0 +1,64 @@
+set -x
+
+# Colocated GRPO model conversion pipeline from Megatron to Hugging Face.
+
+# This assumes you have already finished training Qwen2.5-1.5B-Instruct on GSM8K and have the Megatron checkpoints saved.
+
+## TRAINING SCRIPT ##
+# uv run examples/gsm8k/gsm8k_dataset.py --output_dir $HOME/data/gsm8k
+# export WANDB_API_KEY=
+# bash examples/gsm8k/run_gsm8k.sh
+
+# NOTE (sumanthrh): `micro_train_batch_size_per_gpu` and `micro_forward_batch_size_per_gpu` can be tuned
+
+# Then you can run this conversion script with:
+
+# bash examples/gsm8k/convert_megatron_to_hf.sh
+
+: "${DATA_DIR:="$HOME/data/gsm8k"}"
+: "${NUM_GPUS:=1}"
+: "${LOGGER:=wandb}" # change to "console" to print to stdout
+
+: "${INFERENCE_BACKEND:=vllm}"
+
+
+uv run --isolated --extra mcore -m skyrl_train.entrypoints.main_load \
+  data.train_data="['$DATA_DIR/train.parquet']" \
+  data.val_data="['$DATA_DIR/validation.parquet']" \
+  trainer.algorithm.advantage_estimator="grpo" \
+  trainer.policy.model.path="Qwen/Qwen2.5-1.5B-Instruct" \
+  trainer.placement.colocate_all=true \
+  trainer.strategy=megatron \
+  trainer.placement.policy_num_gpus_per_node=$NUM_GPUS \
+  trainer.placement.critic_num_gpus_per_node=$NUM_GPUS \
+  trainer.placement.ref_num_gpus_per_node=$NUM_GPUS \
+  generator.num_inference_engines=$NUM_GPUS \
+  generator.inference_engine_tensor_parallel_size=1 \
+  trainer.epochs=20 \
+  trainer.eval_batch_size=1024 \
+  trainer.eval_before_train=true \
+  trainer.eval_interval=5 \
+  trainer.update_epochs_per_batch=1 \
+  trainer.train_batch_size=1024 \
+  trainer.policy_mini_batch_size=256 \
+  trainer.micro_forward_batch_size_per_gpu=1 \
+  trainer.micro_train_batch_size_per_gpu=1 \
+  trainer.ckpt_interval=1 \
+  trainer.max_prompt_length=512 \
+  generator.sampling_params.max_generate_length=1024 \
+  trainer.policy.optimizer_config.lr=1.0e-6 \
+  trainer.algorithm.use_kl_loss=true \
+  generator.backend=$INFERENCE_BACKEND \
+  generator.run_engines_locally=true \
+  generator.weight_sync_backend=nccl \
+  generator.async_engine=true \
+  generator.batched=true \
+  environment.env_class=gsm8k \
+  generator.n_samples_per_prompt=5 \
+  generator.gpu_memory_utilization=0.8 \
+  trainer.logger="$LOGGER" \
+  trainer.project_name="gsm8k" \
+  trainer.run_name="gsm8k_test" \
+  trainer.ckpt_path="$HOME/ckpts/gsm8k_1.5B_ckpt" \
+  trainer.resume_mode="latest" \
+  $@
\ No newline at end of file
diff --git a/skyrl-train/skyrl_train/distributed/megatron/megatron_strategy.py b/skyrl-train/skyrl_train/distributed/megatron/megatron_strategy.py
deleted file mode 100644
index 58402a3a0..000000000
--- a/skyrl-train/skyrl_train/distributed/megatron/megatron_strategy.py
+++ /dev/null
@@ -1,278 +0,0 @@
-import os
-import random
-from datetime import timedelta
-from typing import List, Union, Optional
-from jaxtyping import Float
-
-import numpy as np
-import torch
-import torch.nn as nn
-from torch import optim
-from torch import distributed as dist
-
-from skyrl_train.distributed.strategy import DistributedStrategy
-from skyrl_train.distributed.utils import ModelOrModelOptimPair
-from skyrl_train.utils.io import io
-from skyrl_train.workers.megatron.megatron_model_wrapper import MegatronModelWrapper
-import megatron.core.parallel_state as mpu
-from
skyrl_train.distributed.megatron.megatron_utils import ( - offload_megatron_model_to_cpu, - load_megatron_model_to_gpu, - offload_megatron_optimizer, - load_megatron_optimizer, -) - -from megatron.core.dist_checkpointing.strategies import base as ckpt_base -from megatron.core.dist_checkpointing.strategies.async_utils import AsyncCallsQueue - -from megatron.core import dist_checkpointing -from megatron.core.dist_checkpointing.serialization import ( - get_default_load_sharded_strategy, - get_default_save_sharded_strategy, -) -from megatron.core.dist_checkpointing.strategies.fully_parallel import ( - FullyParallelLoadStrategyWrapper, - FullyParallelSaveStrategyWrapper, -) -from transformers import PreTrainedTokenizer -from megatron.core.optimizer import DistributedOptimizer -from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler - - -class MegatronStrategy(DistributedStrategy): - """ - The strategy for training with Megatron. - """ - - def __init__( - self, - megatron_config, - optimizer_config=None, - seed: int = 42, - ) -> None: - super().__init__() - self.megatron_config = megatron_config - self.optimizer_config = optimizer_config - self.seed = seed - self.hf_config = None # Set by the megatron worker once configs are initialized. - - # NOTE: Set Megatron dist checkpoint async backend to persistent to avoid `os.fork()`-ing - # short-lived background workers, which does not work well with Ray. - ckpt_base.async_calls = AsyncCallsQueue(persistent=True) - - def set_seed(self, seed: int) -> None: - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - if torch.cuda.device_count() > 0: - from megatron.core import tensor_parallel - - tensor_parallel.model_parallel_cuda_manual_seed(seed) - - def setup_distributed(self, timeout=timedelta(minutes=30)) -> None: - local_rank = int(os.environ.get("LOCAL_RANK", "-1")) - if local_rank != -1: - torch.cuda.set_device(local_rank) - - mpu.initialize_model_parallel( - tensor_model_parallel_size=self.megatron_config.tensor_model_parallel_size, - pipeline_model_parallel_size=self.megatron_config.pipeline_model_parallel_size, - pipeline_model_parallel_split_rank=None, - expert_model_parallel_size=self.megatron_config.expert_model_parallel_size, - expert_tensor_parallel_size=self.megatron_config.expert_tensor_parallel_size, - use_sharp=False, - context_parallel_size=self.megatron_config.context_parallel_size, - nccl_communicator_config_path=None, - ) - self.set_seed(self.seed) - self.world_size = dist.get_world_size() - - def offload_to_cpu( - self, model, optimizer, pin_memory=True, non_blocking=True, offload_optimizer=True, offload_model=True - ): - """ - Offload model weights and optimizer to CPU memory. 
- """ - if offload_model: - offload_megatron_model_to_cpu(model) - if optimizer and offload_optimizer: - offload_megatron_optimizer(optimizer) - torch.cuda.synchronize() - torch.cuda.empty_cache() - - def backload_to_gpu(self, model, optimizer, non_blocking=True, backload_optimizer=True, backload_model=True): - """Reload model weights back to GPU.""" - if backload_model: - load_megatron_model_to_gpu(model) - if optimizer and backload_optimizer: - load_megatron_optimizer(optimizer) - torch.cuda.synchronize() - - def backward(self, loss: torch.Tensor, model, optimizer: optim.Optimizer, **kwargs) -> None: - raise NotImplementedError() - - def optimizer_step( - self, - optimizer: optim.Optimizer, - model, - scheduler, - name="model", - **kwargs, - ) -> Optional[Float[torch.Tensor, "1"]]: - """Perform optimizer step""" - _, grad_norm, _ = optimizer.step() - scheduler.step(1) - optimizer.zero_grad() - return grad_norm - - def prepare( - self, *models_or_model_optim_pairs: ModelOrModelOptimPair - ) -> Union[List[ModelOrModelOptimPair], ModelOrModelOptimPair]: - raise NotImplementedError() - - def save_checkpoint( - self, - model: MegatronModelWrapper, - ckpt_dir: str, - node_local_rank: int, - optimizer: Optional[DistributedOptimizer] = None, - scheduler: Optional[OptimizerParamScheduler] = None, - tokenizer: Optional[PreTrainedTokenizer] = None, - ): - # Extract base model. - model: List[nn.Module] = model.actor_module - assert len(model) == 1, "Megatron virtual pipeline parallel is not yet supported" - model = model[0] - if hasattr(model, "module"): - model = model.module - - # Create checkpoint directory if it doesn't exist. - if node_local_rank == 0: - io.makedirs(ckpt_dir, exist_ok=True) - - # All ranks wait for the checkpoint directory to be created before saving. - dist.barrier() - - # Collect the sharded state dicts for model and optimizer, and full state dict for the scheduler. - sharded_state_dict = {} - model_sharded_state_dict = model.sharded_state_dict() - sharded_state_dict["model"] = model_sharded_state_dict - if optimizer: - sharded_state_dict["optimizer"] = optimizer.sharded_state_dict(model_sharded_state_dict) - if scheduler: - sharded_state_dict["lr_scheduler"] = scheduler.state_dict() - - # Save RNG state. - sharded_state_dict["rng"] = self.get_rng_state() - - # Save the checkpoint across ranks in parallel. - save_strategy = get_default_save_sharded_strategy("torch_dist") - save_strategy = FullyParallelSaveStrategyWrapper( - save_strategy, mpu.get_data_parallel_group(with_context_parallel=True) - ) - - with io.local_work_dir(ckpt_dir) as work_dir: - # TODO(tgriggs): Support configurable async saves. - async_save_request = dist_checkpointing.save( - sharded_state_dict=sharded_state_dict, - checkpoint_dir=work_dir, - sharded_strategy=save_strategy, - async_sharded_save=False, - validate_access_integrity=True, - ) - assert async_save_request is None, "Async save is not yet supported for Megatron" - - # Only global rank 0 saves the Huggingface config and tokenizer. 
- if self.is_rank_0(): - hf_dir = os.path.join(work_dir, "huggingface") - self.save_hf_configs(self.hf_config, hf_dir, tokenizer) - - dist.barrier() - ckpt_base.async_calls.close() - ckpt_base.async_calls = AsyncCallsQueue(persistent=True) - self.print(f"Checkpoint successfully saved to {ckpt_dir}") - - def load_checkpoint( - self, - model: MegatronModelWrapper, - ckpt_dir: str, - optimizer: Optional[DistributedOptimizer] = None, - scheduler: Optional[OptimizerParamScheduler] = None, - load_module_strict: bool = True, - load_optimizer_states: bool = True, - load_lr_scheduler_states: bool = True, - ): - if not ckpt_dir or not io.exists(ckpt_dir): - raise FileNotFoundError(f"Checkpoint directory not found: {ckpt_dir}") - - # Extract base model. - model: List[nn.Module] = model.actor_module - assert len(model) == 1, "Megatron virtual pipeline parallel is not yet supported" - unwrapped_model = model[0] - if hasattr(unwrapped_model, "module"): - unwrapped_model = unwrapped_model.module - - # Extract sharded state dicts. - sharded_state_dict = {} - model_sharded_state_dict = unwrapped_model.sharded_state_dict() - sharded_state_dict["model"] = model_sharded_state_dict - if optimizer and load_optimizer_states: - sharded_state_dict["optimizer"] = optimizer.sharded_state_dict(model_sharded_state_dict) - if scheduler and load_lr_scheduler_states: - sharded_state_dict["lr_scheduler"] = scheduler.state_dict() - - with io.local_read_dir(ckpt_dir) as read_dir: - # Load the checkpoint in parallel. - load_strategy = get_default_load_sharded_strategy(read_dir) - load_strategy = FullyParallelLoadStrategyWrapper( - load_strategy, mpu.get_data_parallel_group(with_context_parallel=True) - ) - state_dict = dist_checkpointing.load( - sharded_state_dict=sharded_state_dict, checkpoint_dir=read_dir, sharded_strategy=load_strategy - ) - - # Load the model, optimizer, and scheduler state dicts. - assert ( - "model" in state_dict - ), f"Model state dict not found in checkpoint loaded from {ckpt_dir}. Available keys: {state_dict.keys()}" - model[0].load_state_dict(state_dict["model"], strict=load_module_strict) - self.print("Loaded model state dict.") - - if optimizer and load_optimizer_states: - assert ( - "optimizer" in state_dict - ), f"Optimizer state dict not found in checkpoint loaded from {ckpt_dir}. Available keys: {state_dict.keys()}" - optimizer.load_state_dict(state_dict["optimizer"]) - self.print("Loaded optimizer state dict.") - - if scheduler and load_lr_scheduler_states: - assert ( - "lr_scheduler" in state_dict - ), f"LR scheduler state dict not found in checkpoint loaded from {ckpt_dir}. Available keys: {state_dict.keys()}" - scheduler.load_state_dict(state_dict["lr_scheduler"]) - self.print("Loaded LR scheduler state dict.") - - # Load RNG state, if present. - if "rng" in state_dict: - self.load_rng_state(state_dict["rng"]) - - return ckpt_dir, {} - - def save_hf_model(self, bridge, model: MegatronModelWrapper, output_dir: str, tokenizer=None, **kwargs) -> None: - # Create checkpoint directory if it doesn't exist. - if self.is_rank_0(): - io.makedirs(output_dir, exist_ok=True) - dist.barrier() - - # All ranks call into bridge. - with io.local_work_dir(output_dir) as work_dir: - bridge.save_weights(model.actor_module, work_dir) - self.print(f"Successfully saved HF safetensors model to {output_dir}") - - # Only rank 0 saves the Huggingface config and tokenizer. 
- if self.is_rank_0(): - self.save_hf_configs(self.hf_config, work_dir, tokenizer) - self.print(f"Successfully saved HF config and tokenizer to {output_dir}") - - dist.barrier() diff --git a/skyrl-train/skyrl_train/entrypoints/main_load.py b/skyrl-train/skyrl_train/entrypoints/main_load.py new file mode 100644 index 000000000..7f61b1c9c --- /dev/null +++ b/skyrl-train/skyrl_train/entrypoints/main_load.py @@ -0,0 +1,314 @@ +""" +Main entrypoint for training. +""" + +from ray.util.placement_group import placement_group, PlacementGroup + +from transformers import AutoTokenizer, PreTrainedTokenizerBase +from skyrl_train.dataset import PromptDataset +from skyrl_train.utils import validate_cfg + +from skyrl_train.trainer import RayPPOTrainer +from skyrl_train.inference_engines.inference_engine_client import InferenceEngineClient +from skyrl_train.inference_engines.remote_inference_engine import create_remote_inference_engines +from skyrl_train.utils.utils import initialize_ray, get_ray_pg_ready_with_timeout +from skyrl_train.utils.constants import SKYRL_RAY_PG_TIMEOUT_IN_S +from skyrl_train.generators.base import GeneratorInterface +from omegaconf import OmegaConf, DictConfig +from pathlib import Path +import ray + +import os +import hydra +from loguru import logger +from skyrl_train.utils.tracking import Tracking +import multiprocessing as mp + +# NOTE (sumanthrh): We use ray heavily and thus disable `fork` start method. +# forking within ray leads to undefined behaviour and often causes hard to debug +# memory leaks. See: https://docs.ray.io/en/latest/ray-core/patterns/fork-new-processes.html +# A common culprit is Pytorch dataloaders which use `fork` by default. +mp.set_start_method("spawn", force=True) + +config_dir = str(Path(__file__).parent.parent / "config") +__all__ = ["BasePPOExp", "config_dir"] + + +def create_ray_wrapped_inference_engines_from_config(cfg: DictConfig, colocate_pg, tokenizer: PreTrainedTokenizerBase): + from skyrl_train.inference_engines.ray_wrapped_inference_engine import create_ray_wrapped_inference_engines + + engine_kwargs = { + "num_inference_engines": cfg.generator.num_inference_engines, + "tensor_parallel_size": cfg.generator.inference_engine_tensor_parallel_size, + "pipeline_parallel_size": cfg.generator.inference_engine_pipeline_parallel_size, + "model_dtype": cfg.generator.model_dtype, + "pretrain": cfg.trainer.policy.model.path, + "seed": cfg.trainer.seed, + "vllm_v1_disable_multiproc": cfg.generator.vllm_v1_disable_multiproc, + "enable_prefix_caching": cfg.generator.enable_prefix_caching, + "enforce_eager": cfg.generator.enforce_eager, + "expert_parallel_size": cfg.generator.inference_engine_expert_parallel_size, + "data_parallel_size": cfg.generator.inference_engine_data_parallel_size, + "shared_pg": colocate_pg, + "gpu_memory_utilization": cfg.generator.gpu_memory_utilization, + "inference_engine_enable_sleep": cfg.trainer.placement.colocate_all, + "async_engine": cfg.generator.async_engine, + "max_num_batched_tokens": cfg.generator.max_num_batched_tokens, + "max_num_seqs": cfg.generator.max_num_seqs, + "tokenizer": tokenizer, + "backend": cfg.generator.backend, + "engine_init_kwargs": cfg.generator.engine_init_kwargs, + } + + # Conditionally add LoRA parameters if LoRA is enabled + if cfg.trainer.policy.model.lora.rank > 0: + engine_kwargs["enable_lora"] = True + engine_kwargs["max_lora_rank"] = cfg.trainer.policy.model.lora.rank + engine_kwargs["sleep_level"] = 1 + engine_kwargs["max_loras"] = 1 + + if (rope_scaling := cfg.generator.get("rope_scaling", 
None)) is not None: + engine_kwargs["rope_scaling"] = rope_scaling + if (rope_theta := cfg.generator.get("rope_theta", None)) is not None: + engine_kwargs["rope_theta"] = rope_theta + + return create_ray_wrapped_inference_engines(**engine_kwargs) + + +def create_remote_inference_engines_from_config(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase): + # TODO(tgriggs): We may want a separate config for the model name in case it's different from the name used in the OpenAI API + return create_remote_inference_engines( + urls=cfg.generator.remote_inference_engine_urls, + model_name=cfg.trainer.policy.model.path, + engine_backend=cfg.generator.backend, + tokenizer=tokenizer, + tensor_parallel_size=cfg.generator.inference_engine_tensor_parallel_size, + pipeline_parallel_size=cfg.generator.inference_engine_pipeline_parallel_size, + data_parallel_size=cfg.generator.inference_engine_data_parallel_size, + expert_parallel_size=cfg.generator.inference_engine_expert_parallel_size, + ) + + +class BasePPOExp: + def __init__(self, cfg: DictConfig): + """ + Initializes a PPO experiment. + + The `cfg` passed here will be the final config from Hydra, including CLI overrides. + """ + self.cfg = cfg + self.tokenizer = self.get_tokenizer() + self.train_dataset = self.get_train_dataset() + self.eval_dataset = self.get_eval_dataset() + self.colocate_pg = self.get_colocate_pg() + + @staticmethod + def get_cfg_as_str(dict_cfg: DictConfig) -> str: + return OmegaConf.to_yaml(dict_cfg) + + def get_tokenizer(self, padding_side="left"): + """Initializes a tokenizer for the given model.""" + tokenizer = AutoTokenizer.from_pretrained( + self.cfg.trainer.policy.model.path, + trust_remote_code=True, + use_fast=not self.cfg.trainer.disable_fast_tokenizer, + ) + tokenizer.padding_side = padding_side + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + tokenizer.pad_token_id = tokenizer.eos_token_id + return tokenizer + + def get_train_dataset(self): + """Initializes the training dataset. + + Returns: + PromptDataset: The training dataset. + """ + prompts_dataset = PromptDataset( + datasets=self.cfg.data.train_data, + tokenizer=self.tokenizer, + max_prompt_length=self.cfg.trainer.max_prompt_length, + num_workers=8, + ) + # make sure the dataset is large enough to train on + assert ( + len(prompts_dataset) >= self.cfg.trainer.train_batch_size + ), f"dataset should be atleast as large as `train_batch_size` {self.cfg.trainer.train_batch_size}, got size {len(prompts_dataset)}" + return prompts_dataset + + def get_eval_dataset(self): + """Initializes the evaluation dataset. + + Returns: + PromptDataset: The evaluation dataset. + """ + if self.cfg.trainer.eval_interval > 0 and self.cfg.data.val_data: + prompts_dataset = PromptDataset( + datasets=self.cfg.data.val_data, + tokenizer=self.tokenizer, + max_prompt_length=self.cfg.trainer.max_prompt_length, + num_workers=8, + ) + return prompts_dataset + return None + + def get_colocate_pg(self, timeout: int = SKYRL_RAY_PG_TIMEOUT_IN_S) -> PlacementGroup: + """Initializes a placement group for colocated training. + + A single placement group that packs all the inference engines together is created. + + Args: + timeout (int): The timeout for the placement group to be ready. + + Returns: + PlacementGroup: The placement group for colocated training. 
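+            Example (illustrative values, not from the config defaults): with 2 inference engines,
+            inference_engine_tensor_parallel_size=4, and pipeline/data parallel sizes of 1, the
+            placement group packs 2 * 4 * 1 * 1 = 8 single-GPU bundles.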
+ """ + if self.cfg.trainer.placement.colocate_all: + pg = placement_group( + [{"GPU": 1, "CPU": 1}] + * self.cfg.generator.num_inference_engines + * self.cfg.generator.inference_engine_tensor_parallel_size + * self.cfg.generator.inference_engine_pipeline_parallel_size + * self.cfg.generator.inference_engine_data_parallel_size, + strategy="PACK", + ) + get_ray_pg_ready_with_timeout(pg, timeout=timeout) + return pg + else: + return None + + def get_generator(self, cfg, tokenizer, inference_engine_client): + """Initializes the generator. + + Returns: + GeneratorInterface: The generator. + """ + from skyrl_train.generators.skyrl_gym_generator import SkyRLGymGenerator + + return SkyRLGymGenerator( + generator_cfg=cfg.generator, + skyrl_gym_cfg=cfg.environment.skyrl_gym, + inference_engine_client=inference_engine_client, + tokenizer=tokenizer, + model_name=cfg.trainer.policy.model.path, + ) + + def get_trainer( + self, + cfg, + tracker, + tokenizer, + train_dataset, + eval_dataset, + inference_engine_client, + generator: GeneratorInterface, + colocate_pg, + ): + """Initializes the trainer. + + Returns: + RayPPOTrainer: The trainer. + """ + return RayPPOTrainer( + cfg=cfg, + tracker=tracker, + tokenizer=tokenizer, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + inference_engine_client=inference_engine_client, + generator=generator, + colocate_pg=colocate_pg, + ) + + def get_tracker(self): + """Initializes the tracker for experiment tracking. + + Returns: + Tracking: The tracker. + """ + return Tracking( + project_name=self.cfg.trainer.project_name, + experiment_name=self.cfg.trainer.run_name, + backends=self.cfg.trainer.logger, + config=self.cfg, + ) + + def _setup_trainer(self): + """Setup and return the trainer. + + Instantiates the trainer and all the associated models for training. + + Returns: + RayPPOTrainer: The trainer. + """ + logger.info(self.get_cfg_as_str(self.cfg)) + os.makedirs(self.cfg.trainer.export_path, exist_ok=True) + os.makedirs(self.cfg.trainer.ckpt_path, exist_ok=True) + + if self.cfg.trainer.strategy == "deepspeed": + from skyrl_train.workers.deepspeed.deepspeed_worker import ( + PolicyWorker, + CriticWorker, + RefWorker, + ) + elif self.cfg.trainer.strategy in ("fsdp", "fsdp2"): + from skyrl_train.workers.fsdp.fsdp_worker import PolicyWorker, CriticWorker, RefWorker + elif self.cfg.trainer.strategy == "megatron": + from skyrl_train.workers.megatron.megatron_worker import PolicyWorker, CriticWorker, RefWorker + else: + raise ValueError(f"Unknown strategy type: {self.cfg.trainer.strategy}") + + # NOTE (sumanthrh): Instantiate tracker before trainer init. + # We have custom validation before this step to give better error messages. 
+        tracker = self.get_tracker()
+
+        tokenizer = self.tokenizer
+        if self.cfg.generator.run_engines_locally:
+            inference_engines = create_ray_wrapped_inference_engines_from_config(self.cfg, self.colocate_pg, tokenizer)
+        else:
+            inference_engines = create_remote_inference_engines_from_config(self.cfg, tokenizer)
+
+        inference_engine_client = InferenceEngineClient(inference_engines, tokenizer, self.cfg)
+
+        generator: GeneratorInterface = self.get_generator(self.cfg, tokenizer, inference_engine_client)
+
+        trainer = self.get_trainer(
+            cfg=self.cfg,
+            tracker=tracker,
+            tokenizer=tokenizer,
+            train_dataset=self.train_dataset,
+            eval_dataset=self.eval_dataset,
+            inference_engine_client=inference_engine_client,
+            generator=generator,
+            colocate_pg=self.colocate_pg,
+        )
+
+        # Build the models
+        trainer.build_models(PolicyWorker, CriticWorker, RefWorker)
+        return trainer
+
+    def run(self):
+        trainer = self._setup_trainer()
+        # Load the checkpoint and export the models in Hugging Face format
+        trainer.load_checkpoint_and_save_to_hf()
+
+
+@ray.remote(num_cpus=1)
+def skyrl_entrypoint(cfg: DictConfig):
+    # make sure that the entrypoint logic is not run on the head node.
+    exp = BasePPOExp(cfg)
+    exp.run()
+
+
+@hydra.main(config_path=config_dir, config_name="ppo_base_config", version_base=None)
+def main(cfg: DictConfig) -> None:
+    # validate the arguments
+    validate_cfg(cfg)
+
+    initialize_ray(cfg)
+    ray.get(skyrl_entrypoint.remote(cfg))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/skyrl-train/skyrl_train/trainer.py b/skyrl-train/skyrl_train/trainer.py
index 4705d1612..64525bcee 100644
--- a/skyrl-train/skyrl_train/trainer.py
+++ b/skyrl-train/skyrl_train/trainer.py
@@ -4,6 +4,7 @@
 import shutil
 from typing import Any, List, Optional, Dict, Tuple, Union
 from jaxtyping import Float
+import datetime
 from pathlib import Path
 import ray
 from ray import ObjectRef
@@ -123,6 +124,77 @@ async def eval(self) -> Dict[str, float]:
         )
         return eval_metrics
 
+    def save_to_hf(self, model: Optional[PPORayActorGroup], model_name: Optional[str] = None):
+        """
+        Save a single model (policy, critic, or ref) in Hugging Face format.
+
+        Args:
+            model (Optional[PPORayActorGroup]): The Ray actor group representing the model.
+            model_name (Optional[str]): Optional subdirectory name (e.g. "policy", "critic", "ref").
+                If None, the name is inferred automatically.
+
+        Notes:
+            - The export directory is created under cfg.trainer.export_path/global_step_{self.global_step}/{model_name}.
+            - Does nothing if model is None.
+            - This should be called *after* model checkpoint loading or at evaluation export time.
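+            - For example (illustrative values): with export_path="/ckpts/exports" and global_step=10,
+              the policy weights are written to /ckpts/exports/global_step_10/policy.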
+        """
+        if model is None:
+            logger.warning(f"Skipping HF export for {model_name or 'unknown model'} (model is None).")
+            return
+
+        # Infer the model name if it is not provided
+        if model_name is None:
+            if model is self.policy_model:
+                model_name = "policy"
+            elif model is self.critic_model:
+                model_name = "critic"
+            elif model is self.ref_model:
+                model_name = "ref"
+            else:
+                model_name = "unknown"
+
+        export_root = getattr(self.cfg.trainer, "export_path", None)
+        if not export_root:
+            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+            home_dir = Path.home()
+            export_root = home_dir / "exports" / timestamp
+            io.makedirs(export_root, exist_ok=True)
+            logger.warning(f"cfg.trainer.export_path not defined, using default export path: {export_root}")
+        export_dir = os.path.join(export_root, f"global_step_{self.global_step}", model_name)
+        io.makedirs(export_dir, exist_ok=True)
+
+        try:
+            model.backload_to_gpu()
+            ray.get(
+                model.async_run_ray_method(
+                    "pass_through",
+                    "save_hf_model",
+                    export_dir,
+                    self.tokenizer,
+                )
+            )
+            logger.info(f"Saved {model_name} model in Hugging Face format to: {export_dir}")
+        except Exception as e:
+            logger.error(f"Failed to save {model_name} model to HF format: {e}")
+        finally:
+            model.offload_to_cpu()
+
+    def load_checkpoint_and_save_to_hf(self):
+        """Load the training checkpoint state and export the policy, critic, and ref models in Hugging Face format."""
+        # Initialize weight sync state between policy model and inference engines.
+        with Timer("init_weight_sync_state"):
+            self.init_weight_sync_state()
+
+        # Load policy model to GPU before loading checkpoint.
+        if self.colocate_all:
+            self.policy_model.backload_to_gpu()
+
+        # Load checkpoint state if resumption is enabled.
+        with Timer("load_checkpoints"):
+            self.global_step = self.load_checkpoints()
+        self.save_to_hf(self.policy_model)
+        self.save_to_hf(self.critic_model)
+        self.save_to_hf(self.ref_model)
+
     def train(self):
         """
         Main training loop for PPO