Added support for saving shard as HuggingFace safetensors for Megatron backend (WIP) #658
Conversation
Code Review
This pull request adds functionality to convert Megatron checkpoints to the HuggingFace safetensors format. It introduces a new entrypoint main_load and an example script to perform this conversion. The changes look good overall, but I have a few suggestions to improve maintainability and code quality.
My main feedback is to refactor the new main_load.py entrypoint to avoid significant code duplication with the existing training entrypoint. I've also pointed out some leftover debugging code that should be removed, and made a few suggestions for improving code clarity and consistency in trainer.py and the new shell script. Addressing these points will make the new functionality more robust and easier to maintain.
| self.print("Saving checkpoint function began: ") | ||
| # import pdb | ||
| # pdb.set_trace() | ||
|
|
Leftover debugging code: the commented-out `pdb` import and breakpoint should be removed before merging.
| self.print("Loading checkpoint function began: ") | ||
| # import pdb | ||
| # pdb.set_trace() | ||
|
|
Same here: the commented-out `pdb` debugging statements should be removed.
| """ | ||
| Main entrypoint for training. | ||
| """ | ||
|
|
||
| from ray.util.placement_group import placement_group, PlacementGroup | ||
|
|
||
| from transformers import AutoTokenizer, PreTrainedTokenizerBase | ||
| from skyrl_train.dataset import PromptDataset | ||
| from skyrl_train.utils import validate_cfg | ||
|
|
||
| from skyrl_train.trainer import RayPPOTrainer | ||
| from skyrl_train.inference_engines.inference_engine_client import InferenceEngineClient | ||
| from skyrl_train.inference_engines.remote_inference_engine import create_remote_inference_engines | ||
| from skyrl_train.utils.utils import initialize_ray, get_ray_pg_ready_with_timeout | ||
| from skyrl_train.utils.constants import SKYRL_RAY_PG_TIMEOUT_IN_S | ||
| from skyrl_train.generators.base import GeneratorInterface | ||
| from omegaconf import OmegaConf, DictConfig | ||
| from pathlib import Path | ||
| import ray | ||
|
|
||
| import os | ||
| import hydra | ||
| from loguru import logger | ||
| from skyrl_train.utils.tracking import Tracking | ||
| import multiprocessing as mp | ||
|
|
||
| # NOTE (sumanthrh): We use ray heavily and thus disable `fork` start method. | ||
| # forking within ray leads to undefined behaviour and often causes hard to debug | ||
| # memory leaks. See: https://docs.ray.io/en/latest/ray-core/patterns/fork-new-processes.html | ||
| # A common culprit is Pytorch dataloaders which use `fork` by default. | ||
| mp.set_start_method("spawn", force=True) | ||
|
|
||
| config_dir = str(Path(__file__).parent.parent / "config") | ||
| __all__ = ["BasePPOExp", "config_dir"] | ||
|
|
||
|
|
||
| def create_ray_wrapped_inference_engines_from_config(cfg: DictConfig, colocate_pg, tokenizer: PreTrainedTokenizerBase): | ||
| from skyrl_train.inference_engines.ray_wrapped_inference_engine import create_ray_wrapped_inference_engines | ||
|
|
||
| engine_kwargs = { | ||
| "num_inference_engines": cfg.generator.num_inference_engines, | ||
| "tensor_parallel_size": cfg.generator.inference_engine_tensor_parallel_size, | ||
| "pipeline_parallel_size": cfg.generator.inference_engine_pipeline_parallel_size, | ||
| "model_dtype": cfg.generator.model_dtype, | ||
| "pretrain": cfg.trainer.policy.model.path, | ||
| "seed": cfg.trainer.seed, | ||
| "vllm_v1_disable_multiproc": cfg.generator.vllm_v1_disable_multiproc, | ||
| "enable_prefix_caching": cfg.generator.enable_prefix_caching, | ||
| "enforce_eager": cfg.generator.enforce_eager, | ||
| "expert_parallel_size": cfg.generator.inference_engine_expert_parallel_size, | ||
| "data_parallel_size": cfg.generator.inference_engine_data_parallel_size, | ||
| "shared_pg": colocate_pg, | ||
| "gpu_memory_utilization": cfg.generator.gpu_memory_utilization, | ||
| "inference_engine_enable_sleep": cfg.trainer.placement.colocate_all, | ||
| "async_engine": cfg.generator.async_engine, | ||
| "max_num_batched_tokens": cfg.generator.max_num_batched_tokens, | ||
| "max_num_seqs": cfg.generator.max_num_seqs, | ||
| "tokenizer": tokenizer, | ||
| "backend": cfg.generator.backend, | ||
| "engine_init_kwargs": cfg.generator.engine_init_kwargs, | ||
| } | ||
|
|
||
| # Conditionally add LoRA parameters if LoRA is enabled | ||
| if cfg.trainer.policy.model.lora.rank > 0: | ||
| engine_kwargs["enable_lora"] = True | ||
| engine_kwargs["max_lora_rank"] = cfg.trainer.policy.model.lora.rank | ||
| engine_kwargs["sleep_level"] = 1 | ||
| engine_kwargs["max_loras"] = 1 | ||
|
|
||
| if (rope_scaling := cfg.generator.get("rope_scaling", None)) is not None: | ||
| engine_kwargs["rope_scaling"] = rope_scaling | ||
| if (rope_theta := cfg.generator.get("rope_theta", None)) is not None: | ||
| engine_kwargs["rope_theta"] = rope_theta | ||
|
|
||
| return create_ray_wrapped_inference_engines(**engine_kwargs) | ||
|
|
||
|
|
||
| def create_remote_inference_engines_from_config(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase): | ||
| # TODO(tgriggs): We may want a separate config for the model name in case it's different from the name used in the OpenAI API | ||
| return create_remote_inference_engines( | ||
| urls=cfg.generator.remote_inference_engine_urls, | ||
| model_name=cfg.trainer.policy.model.path, | ||
| engine_backend=cfg.generator.backend, | ||
| tokenizer=tokenizer, | ||
| tensor_parallel_size=cfg.generator.inference_engine_tensor_parallel_size, | ||
| pipeline_parallel_size=cfg.generator.inference_engine_pipeline_parallel_size, | ||
| data_parallel_size=cfg.generator.inference_engine_data_parallel_size, | ||
| expert_parallel_size=cfg.generator.inference_engine_expert_parallel_size, | ||
| ) | ||
|
|
||
|
|
||
| class BasePPOExp: | ||
| def __init__(self, cfg: DictConfig): | ||
| """ | ||
| Initializes a PPO experiment. | ||
| The `cfg` passed here will be the final config from Hydra, including CLI overrides. | ||
| """ | ||
| self.cfg = cfg | ||
| self.tokenizer = self.get_tokenizer() | ||
| self.train_dataset = self.get_train_dataset() | ||
| self.eval_dataset = self.get_eval_dataset() | ||
| self.colocate_pg = self.get_colocate_pg() | ||
|
|
||
| @staticmethod | ||
| def get_cfg_as_str(dict_cfg: DictConfig) -> str: | ||
| return OmegaConf.to_yaml(dict_cfg) | ||
|
|
||
| def get_tokenizer(self, padding_side="left"): | ||
| """Initializes a tokenizer for the given model.""" | ||
| tokenizer = AutoTokenizer.from_pretrained( | ||
| self.cfg.trainer.policy.model.path, | ||
| trust_remote_code=True, | ||
| use_fast=not self.cfg.trainer.disable_fast_tokenizer, | ||
| ) | ||
| tokenizer.padding_side = padding_side | ||
| if tokenizer.pad_token is None: | ||
| tokenizer.pad_token = tokenizer.eos_token | ||
| tokenizer.pad_token_id = tokenizer.eos_token_id | ||
| return tokenizer | ||
|
|
||
| def get_train_dataset(self): | ||
| """Initializes the training dataset. | ||
| Returns: | ||
| PromptDataset: The training dataset. | ||
| """ | ||
| prompts_dataset = PromptDataset( | ||
| datasets=self.cfg.data.train_data, | ||
| tokenizer=self.tokenizer, | ||
| max_prompt_length=self.cfg.trainer.max_prompt_length, | ||
| num_workers=8, | ||
| ) | ||
| # make sure the dataset is large enough to train on | ||
| assert ( | ||
| len(prompts_dataset) >= self.cfg.trainer.train_batch_size | ||
| ), f"dataset should be atleast as large as `train_batch_size` {self.cfg.trainer.train_batch_size}, got size {len(prompts_dataset)}" | ||
| return prompts_dataset | ||
|
|
||
| def get_eval_dataset(self): | ||
| """Initializes the evaluation dataset. | ||
| Returns: | ||
| PromptDataset: The evaluation dataset. | ||
| """ | ||
| if self.cfg.trainer.eval_interval > 0 and self.cfg.data.val_data: | ||
| prompts_dataset = PromptDataset( | ||
| datasets=self.cfg.data.val_data, | ||
| tokenizer=self.tokenizer, | ||
| max_prompt_length=self.cfg.trainer.max_prompt_length, | ||
| num_workers=8, | ||
| ) | ||
| return prompts_dataset | ||
| return None | ||
|
|
||
| def get_colocate_pg(self, timeout: int = SKYRL_RAY_PG_TIMEOUT_IN_S) -> PlacementGroup: | ||
| """Initializes a placement group for colocated training. | ||
| A single placement group that packs all the inference engines together is created. | ||
| Args: | ||
| timeout (int): The timeout for the placement group to be ready. | ||
| Returns: | ||
| PlacementGroup: The placement group for colocated training. | ||
| """ | ||
| if self.cfg.trainer.placement.colocate_all: | ||
| pg = placement_group( | ||
| [{"GPU": 1, "CPU": 1}] | ||
| * self.cfg.generator.num_inference_engines | ||
| * self.cfg.generator.inference_engine_tensor_parallel_size | ||
| * self.cfg.generator.inference_engine_pipeline_parallel_size | ||
| * self.cfg.generator.inference_engine_data_parallel_size, | ||
| strategy="PACK", | ||
| ) | ||
| get_ray_pg_ready_with_timeout(pg, timeout=timeout) | ||
| return pg | ||
| else: | ||
| return None | ||
|
|
||
| def get_generator(self, cfg, tokenizer, inference_engine_client): | ||
| """Initializes the generator. | ||
| Returns: | ||
| GeneratorInterface: The generator. | ||
| """ | ||
| from skyrl_train.generators.skyrl_gym_generator import SkyRLGymGenerator | ||
|
|
||
| return SkyRLGymGenerator( | ||
| generator_cfg=cfg.generator, | ||
| skyrl_gym_cfg=cfg.environment.skyrl_gym, | ||
| inference_engine_client=inference_engine_client, | ||
| tokenizer=tokenizer, | ||
| model_name=cfg.trainer.policy.model.path, | ||
| ) | ||
|
|
||
| def get_trainer( | ||
| self, | ||
| cfg, | ||
| tracker, | ||
| tokenizer, | ||
| train_dataset, | ||
| eval_dataset, | ||
| inference_engine_client, | ||
| generator: GeneratorInterface, | ||
| colocate_pg, | ||
| ): | ||
| """Initializes the trainer. | ||
| Returns: | ||
| RayPPOTrainer: The trainer. | ||
| """ | ||
| return RayPPOTrainer( | ||
| cfg=cfg, | ||
| tracker=tracker, | ||
| tokenizer=tokenizer, | ||
| train_dataset=train_dataset, | ||
| eval_dataset=eval_dataset, | ||
| inference_engine_client=inference_engine_client, | ||
| generator=generator, | ||
| colocate_pg=colocate_pg, | ||
| ) | ||
|
|
||
| def get_tracker(self): | ||
| """Initializes the tracker for experiment tracking. | ||
| Returns: | ||
| Tracking: The tracker. | ||
| """ | ||
| return Tracking( | ||
| project_name=self.cfg.trainer.project_name, | ||
| experiment_name=self.cfg.trainer.run_name, | ||
| backends=self.cfg.trainer.logger, | ||
| config=self.cfg, | ||
| ) | ||
|
|
||
| def _setup_trainer(self): | ||
| """Setup and return the trainer. | ||
| Instantiates the trainer and all the associated models for training. | ||
| Returns: | ||
| RayPPOTrainer: The trainer. | ||
| """ | ||
| logger.info(self.get_cfg_as_str(self.cfg)) | ||
| os.makedirs(self.cfg.trainer.export_path, exist_ok=True) | ||
| os.makedirs(self.cfg.trainer.ckpt_path, exist_ok=True) | ||
|
|
||
| if self.cfg.trainer.strategy == "deepspeed": | ||
| from skyrl_train.workers.deepspeed.deepspeed_worker import ( | ||
| PolicyWorker, | ||
| CriticWorker, | ||
| RefWorker, | ||
| ) | ||
| elif self.cfg.trainer.strategy in ("fsdp", "fsdp2"): | ||
| from skyrl_train.workers.fsdp.fsdp_worker import PolicyWorker, CriticWorker, RefWorker | ||
| elif self.cfg.trainer.strategy == "megatron": | ||
| from skyrl_train.workers.megatron.megatron_worker import PolicyWorker, CriticWorker, RefWorker | ||
| else: | ||
| raise ValueError(f"Unknown strategy type: {self.cfg.trainer.strategy}") | ||
|
|
||
| # NOTE (sumanthrh): Instantiate tracker before trainer init. | ||
| # We have custom validation before this step to give better error messages. | ||
| tracker = self.get_tracker() | ||
|
|
||
| tokenizer = self.tokenizer | ||
| if self.cfg.generator.run_engines_locally: | ||
| inference_engines = create_ray_wrapped_inference_engines_from_config(self.cfg, self.colocate_pg, tokenizer) | ||
| else: | ||
| inference_engines = create_remote_inference_engines_from_config(self.cfg, tokenizer) | ||
|
|
||
| inference_engine_client = InferenceEngineClient(inference_engines, tokenizer, self.cfg) | ||
|
|
||
| generator: GeneratorInterface = self.get_generator(self.cfg, tokenizer, inference_engine_client) | ||
|
|
||
| trainer = self.get_trainer( | ||
| cfg=self.cfg, | ||
| tracker=tracker, | ||
| tokenizer=tokenizer, | ||
| train_dataset=self.train_dataset, | ||
| eval_dataset=self.eval_dataset, | ||
| inference_engine_client=inference_engine_client, | ||
| generator=generator, | ||
| colocate_pg=self.colocate_pg, | ||
| ) | ||
|
|
||
| # Build the models | ||
| trainer.build_models(PolicyWorker, CriticWorker, RefWorker) | ||
| return trainer | ||
|
|
||
| def run(self): | ||
| trainer = self._setup_trainer() | ||
| # Start the training loop | ||
| trainer.get_model() | ||
|
|
||
|
|
||
|
|
||
| @ray.remote(num_cpus=1) | ||
| def skyrl_entrypoint(cfg: DictConfig): | ||
| # make sure that the training loop is not run on the head node. | ||
| exp = BasePPOExp(cfg) | ||
| exp.run() | ||
|
|
||
|
|
||
| @hydra.main(config_path=config_dir, config_name="ppo_base_config", version_base=None) | ||
| def main(cfg: DictConfig) -> None: | ||
| # validate the arguments | ||
| validate_cfg(cfg) | ||
|
|
||
| initialize_ray(cfg) | ||
| ray.get(skyrl_entrypoint.remote(cfg)) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() |
This new file main_load.py appears to be almost a complete duplicate of the existing training entrypoint. This introduces significant code duplication, which will make future maintenance harder.
Instead of creating a new entrypoint file, consider adding a new mode or command-line argument to the existing entrypoint. For example, you could add a --task argument that could be train or convert_checkpoint. This would allow you to reuse the majority of the setup code and just change the final action (e.g., calling trainer.train() vs. a renamed trainer.get_model()).
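A minimal sketch of that approach, assuming a hypothetical `cfg.trainer.task` config field (not part of this PR) and reusing the setup code already in this file:

```python
@ray.remote(num_cpus=1)
def skyrl_entrypoint(cfg: DictConfig):
    exp = BasePPOExp(cfg)
    trainer = exp._setup_trainer()

    # Dispatch on a hypothetical `task` field instead of duplicating the whole entrypoint.
    task = cfg.trainer.get("task", "train")
    if task == "train":
        trainer.train()
    elif task == "convert_checkpoint":
        trainer.get_model()  # or `load_checkpoint_and_save_to_hf()` after the rename suggested below
    else:
        raise ValueError(f"Unknown task: {task}")
```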
```bash
trainer.resume_mode=null \
trainer.ckpt_path="$HOME/ckpts/gsm8k_1.5B_ckpt" \
trainer.resume_mode="latest" \
```
The trainer.resume_mode is set to null on line 62 and then immediately overridden to "latest" on line 64. The first setting is redundant and can be removed to avoid confusion.
Suggested change:

```diff
-trainer.resume_mode=null \
 trainer.ckpt_path="$HOME/ckpts/gsm8k_1.5B_ckpt" \
 trainer.resume_mode="latest" \
```
skyrl-train/skyrl_train/trainer.py (outdated)

```python
export_root = getattr(self.cfg.trainer, "export_path", None)
if not export_root:
    import datetime
```

```python
home_dir = os.path.expanduser("~")
export_root = os.path.join(home_dir, "exports", timestamp)
```
This section uses os.path.expanduser and os.path.join for path manipulation. Other parts of the codebase, like skyrl-train/skyrl_train/entrypoints/main_load.py, use pathlib.Path. For consistency and better readability, it would be better to use pathlib.Path here as well. The Path object is already imported in this file.
Suggested change:

```diff
-home_dir = os.path.expanduser("~")
-export_root = os.path.join(home_dir, "exports", timestamp)
+home_dir = Path.home()
+export_root = home_dir / "exports" / timestamp
```
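Putting the two quoted fragments together inside the trainer method, the fallback could then read as follows (a sketch: the `timestamp` construction is an assumption, since that line is not visible in the diff):

```python
from datetime import datetime
from pathlib import Path

export_root = getattr(self.cfg.trainer, "export_path", None)
if not export_root:
    # Assumed timestamp format; the actual line is not shown in the diff hunk.
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    export_root = Path.home() / "exports" / timestamp
```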
skyrl-train/skyrl_train/trainer.py (outdated)

```python
        finally:
            model.offload_to_cpu()

    def get_model(self):
```
The method name get_model is a bit misleading. It doesn't just "get" a model; it performs a sequence of operations: initializing weights, loading a checkpoint, and then saving the models in Hugging Face format. A more descriptive name like load_checkpoint_and_save_to_hf would better reflect its purpose and improve code clarity. Remember to update the call site in skyrl-train/skyrl_train/entrypoints/main_load.py as well.
Suggested change:

```diff
-    def get_model(self):
+    def load_checkpoint_and_save_to_hf(self):
```
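For clarity, the renamed method's shape might look like the sketch below; the helper names are placeholders inferred from the sequence described above, not the PR's actual API:

```python
def load_checkpoint_and_save_to_hf(self):
    """Initialize weights, restore the Megatron checkpoint, and export to HF safetensors."""
    # Placeholder helpers: the real implementations live elsewhere in trainer.py.
    self._init_model_weights()
    self._load_latest_checkpoint(self.cfg.trainer.ckpt_path)
    self._save_models_to_hf(self.cfg.trainer.export_path)
```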
Fixed the other suggestions; have not changed the entrypoint.

- Added a new runnable script, `main_load`, which can be referenced with `uv pip`.
- Added support for the new functions in `worker.py`.
- Provided an example implementation of a working conversion using a new script, `skyrl-train/examples/gsm8k/convert_megatron_to_hf.sh`.

To run the script, copy the original `.sh` file used for training and change the entrypoint from `skyrl_train.entrypoints.main_base` to `skyrl_train.entrypoints.main_load`.

Please comment and give feedback.
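For example, the only required change in the copied launch script is the entrypoint module; a sketch assuming a `uv run python -m` launch pattern (the config overrides shown are illustrative, taken from the script excerpt above):

```bash
# Before (training):
uv run python -m skyrl_train.entrypoints.main_base \
  trainer.ckpt_path="$HOME/ckpts/gsm8k_1.5B_ckpt" \
  trainer.resume_mode="latest"

# After (convert the Megatron checkpoint to HuggingFace safetensors):
uv run python -m skyrl_train.entrypoints.main_load \
  trainer.ckpt_path="$HOME/ckpts/gsm8k_1.5B_ckpt" \
  trainer.resume_mode="latest"
```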