Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions tests/trainer/trainer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
get_vision_language_model_path,
)
from trinity.buffer import get_buffer_reader
from trinity.cli.launcher import bench, both, explore, run, serve, train
from trinity.cli.launcher import bench, both, convert, explore, run, serve, train
from trinity.common.config import (
AlgorithmConfig,
BufferConfig,
Expand Down Expand Up @@ -98,7 +98,7 @@ def test_trainer(self):
eval_tasksets[0].repeat_times = 4
eval_tasksets[1].repeat_times = 4
self.config.trainer.save_interval = 4
self.config.trainer.save_hf_checkpoint = "always"
self.config.trainer.save_hf_checkpoint = "never"
if self.strategy == "megatron":
self.config.trainer.trainer_strategy = "megatron"
self.config.check_and_update()
Expand Down Expand Up @@ -144,12 +144,18 @@ def test_trainer(self):
)
self.assertGreater(len(os.listdir(os.path.join(checkpoint_step_4, "actor"))), 0)
self.assertGreater(len(os.listdir(os.path.join(checkpoint_step_8, "actor"))), 0)
self.assertGreater(
len(os.listdir(os.path.join(checkpoint_step_4, "actor", "huggingface"))), 0
)
self.assertGreater(
len(os.listdir(os.path.join(checkpoint_step_8, "actor", "huggingface"))), 0
)
hf_dir_step_4 = os.listdir(os.path.join(checkpoint_step_4, "actor", "huggingface"))
hf_dir_step_8 = os.listdir(os.path.join(checkpoint_step_8, "actor", "huggingface"))
self.assertGreater(len(hf_dir_step_4), 0)
self.assertGreater(len(hf_dir_step_8), 0)
self.assertNotIn("model.safetensors", hf_dir_step_4)
self.assertNotIn("model.safetensors", hf_dir_step_8)
# test checkpoint convert
convert(self.config.checkpoint_job_dir)
hf_dir_step_4 = os.listdir(os.path.join(checkpoint_step_4, "actor", "huggingface"))
hf_dir_step_8 = os.listdir(os.path.join(checkpoint_step_8, "actor", "huggingface"))
self.assertIn("model.safetensors", hf_dir_step_4)
self.assertIn("model.safetensors", hf_dir_step_8)
self.assertEqual(step_num, 8)
ray.init(ignore_reinit_error=True, namespace=self.config.ray_namespace)
# test bench mode
Expand Down
184 changes: 184 additions & 0 deletions trinity/cli/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import traceback
from pathlib import Path
from pprint import pprint
from typing import Optional

import ray

Expand Down Expand Up @@ -301,6 +302,171 @@ def debug(
)


class Converter:
    """Convert saved training checkpoints into huggingface format.

    ``convert`` walks a directory tree. Every directory whose name starts with
    ``global_step_`` is treated as one checkpoint: the weights are read from its
    ``actor`` sub-directory (megatron when ``actor/dist_ckpt`` exists, fsdp
    otherwise) and written to ``actor/huggingface``.

    Heavy dependencies (torch, transformers, accelerate, verl) are imported
    lazily inside the methods that need them.
    """

    def __init__(self, base_model_dir: Optional[str] = None):
        """Create a converter.

        Args:
            base_model_dir: Optional path of a base model. Its config is used
                to build the model skeleton when a checkpoint directory does
                not contain a loadable huggingface config of its own.
        """
        self.logger = get_logger(__name__)
        self.base_model_dir = base_model_dir
        # Empty model built from ``base_model_dir``; created on first use.
        self.base_model = None
        # Guards against initializing torch.distributed more than once.
        self._init_process_group = False
        # Megatron converter; created on first use and then reused.
        self.checkpoint_converter = None

    def init_base_model(self) -> bool:
        """Build ``self.base_model`` from ``base_model_dir`` if not done yet.

        Returns:
            True when a base model is available, False when no base model
            directory was given or building the model failed.
        """
        if not self.base_model_dir:
            return False
        if self.base_model is not None:
            return True
        try:
            self.base_model, _ = self._get_config_and_empty_model(self.base_model_dir)
        except Exception:
            # Callers only report that loading failed; keep the traceback here
            # so the underlying cause is not lost.
            self.logger.warning(
                f"Could not build base model from {self.base_model_dir}.",
                exc_info=True,
            )
            return False
        return True

    def init_process_group(self):
        """Initialize ``torch.distributed`` once for this converter.

        When ``WORLD_SIZE`` is not set in the environment, single-process
        defaults are written to ``os.environ`` before initialization.
        """
        if self._init_process_group:
            return

        import torch
        from verl.utils.device import get_nccl_backend
        from verl.utils.distributed import set_numa_affinity

        if "WORLD_SIZE" not in os.environ:
            os.environ["RANK"] = "0"
            os.environ["LOCAL_RANK"] = "0"
            os.environ["WORLD_SIZE"] = "1"
            os.environ["MASTER_ADDR"] = "localhost"
            os.environ["MASTER_PORT"] = "12355"

        set_numa_affinity()
        torch.distributed.init_process_group(get_nccl_backend())
        self._init_process_group = True

    def init_checkpoint_converter(self, checkpoint_dir) -> bool:
        """Create the megatron checkpoint converter if it does not exist yet.

        Args:
            checkpoint_dir: A checkpoint directory whose name starts with
                ``global_step_``.

        Returns:
            True when ``self.checkpoint_converter`` is ready, False when the
            directory name is invalid or the base model could not be loaded.
        """
        # NOTE(review): the converter is built from the first checkpoint's
        # actor directory and reused for every later checkpoint - confirm it
        # is not checkpoint-specific.
        if self.checkpoint_converter is not None:
            return True
        if not os.path.basename(checkpoint_dir).startswith("global_step_"):
            self.logger.error(f"Invalid checkpoint directory {checkpoint_dir}.")
            return False

        actor_ckpt_dir = os.path.join(checkpoint_dir, "actor")
        huggingface_dir = os.path.join(actor_ckpt_dir, "huggingface")
        if not os.path.exists(os.path.join(huggingface_dir, "config.json")):
            # No config saved with the checkpoint: write the base model's
            # config into the huggingface directory first.
            if not self.init_base_model():
                self.logger.error(
                    f"Failed to load base model from {self.base_model_dir}, "
                    "please check if the model exists."
                )
                return False
            self.base_model.config.save_pretrained(huggingface_dir)

        from trinity.common.models.utils import get_megatron_converter

        self.init_process_group()
        self.checkpoint_converter = get_megatron_converter(actor_ckpt_dir)
        return True

    def _get_config_and_empty_model(self, model_dir: str):
        """Load the config in ``model_dir`` and build a weight-less model.

        Args:
            model_dir: Directory containing a huggingface model config.

        Returns:
            A ``(model, auto_model_cls)`` tuple: the model instantiated from
            the config under ``init_empty_weights`` and the ``AutoModel*``
            class that was selected for it.

        Raises:
            NotImplementedError: If the first entry of the config's
                ``architectures`` matches none of the supported suffixes, or
                the config declares no architecture at all.
        """
        import torch
        import transformers
        from accelerate import init_empty_weights

        model_config = transformers.AutoConfig.from_pretrained(model_dir)

        # ``architectures`` may be missing or empty; treat that as "unknown"
        # instead of failing on the index with an unrelated error.
        architectures = getattr(model_config, "architectures", None) or []
        architecture = architectures[0] if architectures else ""

        if "ForTokenClassification" in architecture:
            from transformers import AutoModelForTokenClassification

            auto_model_cls = AutoModelForTokenClassification
        elif "ForCausalLM" in architecture:
            from transformers import AutoModelForCausalLM

            auto_model_cls = AutoModelForCausalLM
        elif "ForConditionalGeneration" in architecture:
            # Handle different transformers versions for Vision2Seq models
            from packaging import version

            if version.parse(transformers.__version__) >= version.parse("4.54.0"):
                # transformers >= 4.54.0 uses AutoModelForImageTextToText
                from transformers import AutoModelForImageTextToText

                auto_model_cls = AutoModelForImageTextToText
            else:
                # transformers < 4.54.0 uses AutoModelForVision2Seq
                from transformers import AutoModelForVision2Seq

                auto_model_cls = AutoModelForVision2Seq
        else:
            # The config is accessed by attribute everywhere else; subscripting
            # it here would raise instead of producing this message.
            raise NotImplementedError(f"Unknown architecture {model_config.architectures}")

        with init_empty_weights():
            model = auto_model_cls.from_config(model_config, dtype=torch.bfloat16)
        model.to_empty(device="cpu")

        return model, auto_model_cls

    def convert(self, checkpoint_dir: str) -> None:
        """Convert ``checkpoint_dir`` (or every checkpoint below it).

        A directory whose name starts with ``global_step_`` is converted
        directly; any other directory is searched recursively. Checkpoints
        whose ``actor/huggingface`` directory already loads with
        ``from_pretrained`` are left untouched. Failures are logged and the
        affected checkpoint is skipped.

        Args:
            checkpoint_dir: A checkpoint directory or a parent of checkpoints.
        """
        if not os.path.basename(checkpoint_dir).startswith("global_step_"):
            # recursive search (sorted so the processing order is stable)
            for sub_dir in sorted(os.listdir(checkpoint_dir)):
                sub_dir_path = os.path.join(checkpoint_dir, sub_dir)
                if os.path.isdir(sub_dir_path):
                    self.convert(sub_dir_path)
            return

        actor_ckpt_dir = os.path.join(checkpoint_dir, "actor")
        huggingface_dir = os.path.join(actor_ckpt_dir, "huggingface")
        model = None
        if os.path.exists(huggingface_dir):
            try:
                # ``model`` is kept when only the config loads, so it can be
                # reused as the skeleton for saving below.
                model, auto_model_cls = self._get_config_and_empty_model(huggingface_dir)
                auto_model_cls.from_pretrained(huggingface_dir)
            except Exception:
                # Not a complete huggingface checkpoint yet: convert it.
                pass
            else:
                # Already loadable, nothing to do.
                return

        if model is None:
            if not self.init_base_model():
                self.logger.error(
                    f"Failed to load base model from {self.base_model_dir}, "
                    "please check if the model exists."
                )
                return
            model = self.base_model

        self.logger.info(f"Converting {checkpoint_dir} to huggingface format...")
        dist_cpkt_dir = os.path.join(actor_ckpt_dir, "dist_ckpt")
        try:
            if os.path.exists(dist_cpkt_dir):  # megatron
                if not self.init_checkpoint_converter(checkpoint_dir):
                    return
                state_dict = self.checkpoint_converter.get_state_dict(actor_ckpt_dir)
            else:  # fsdp
                from trinity.common.models.utils import (
                    load_fsdp_state_dict_from_verl_checkpoint,
                )

                state_dict = load_fsdp_state_dict_from_verl_checkpoint(actor_ckpt_dir)
        except Exception:
            self.logger.error(
                f"Failed to convert {checkpoint_dir} to huggingface format.",
                exc_info=True,
            )
            return

        import torch

        state_dict = {k: v.to(torch.bfloat16) for k, v in state_dict.items()}
        model.save_pretrained(huggingface_dir, state_dict=state_dict)
        self.logger.info(f"Saved huggingface checkpoint to {huggingface_dir}")


def convert(checkpoint_dir: str, base_model_dir: Optional[str] = None) -> None:
    """Convert the checkpoint(s) under ``checkpoint_dir`` to huggingface format.

    If ``checkpoint_dir`` points inside a ``global_step_*`` directory (for
    example at its ``actor`` sub-directory), conversion starts from the nearest
    enclosing directory whose name starts with ``global_step_``. Otherwise the
    path is passed to ``Converter.convert`` unchanged.

    Args:
        checkpoint_dir: The path to the checkpoint directory.
        base_model_dir: Optional path to the base model, forwarded to
            ``Converter``.
    """
    target_dir = checkpoint_dir
    while not os.path.basename(target_dir).startswith("global_step_"):
        parent_dir = os.path.dirname(target_dir)
        if parent_dir == target_dir:
            # ``dirname`` stopped making progress (filesystem root or empty
            # string): no path component starts with "global_step_", so fall
            # back to the path as given instead of looping forever.
            target_dir = checkpoint_dir
            break
        target_dir = parent_dir
    converter = Converter(base_model_dir)
    converter.convert(target_dir)


def main() -> None:
"""The main entrypoint."""
parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -367,6 +533,22 @@ def main() -> None:
help="The port for Experience Viewer.",
)

convert_parser = subparsers.add_parser(
"convert", help="Convert checkpoint to huggingface format."
)
convert_parser.add_argument(
"--checkpoint-dir",
type=str,
required=True,
help="The path to the checkpoint directory.",
)
convert_parser.add_argument(
"--base-model-dir",
type=str,
default=None,
help="The path to the base model.",
)

args = parser.parse_args()
if args.command == "run":
# TODO: support parse all args from command line
Expand All @@ -383,6 +565,8 @@ def main() -> None:
args.port,
args.plugin_dir,
)
elif args.command == "convert":
convert(args.checkpoint_dir, args.base_model_dir)


if __name__ == "__main__":
Expand Down
3 changes: 2 additions & 1 deletion trinity/common/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,8 @@ def __init__(self, config: ModelMergerConfig):
self.hf_config = AutoConfig.from_pretrained(
self.config.hf_model_config_path, trust_remote_code=self.config.trust_remote_code
)
print(self.hf_config, flush=True)
self.logger = get_logger(__name__)
self.logger.debug(self.hf_config)

self.params_mapping = {
# megatron core gpt model name, huggingface model name
Expand Down