diff --git a/tests/buffer/queue_test.py b/tests/buffer/queue_test.py index e06b133256..ce141909bc 100644 --- a/tests/buffer/queue_test.py +++ b/tests/buffer/queue_test.py @@ -4,7 +4,7 @@ from trinity.buffer.reader.queue_reader import QueueReader from trinity.buffer.writer.queue_writer import QueueWriter from trinity.common.config import BufferConfig, StorageConfig -from trinity.common.constants import AlgorithmType, StorageType +from trinity.common.constants import StorageType from trinity.common.experience import Experience @@ -15,7 +15,7 @@ def test_queue_buffer(self): read_batch_size = 4 meta = StorageConfig( name="test_buffer", - algorithm_type=AlgorithmType.PPO, + algorithm_type="ppo", storage_type=StorageType.QUEUE, ) config = BufferConfig( diff --git a/tests/buffer/sql_test.py b/tests/buffer/sql_test.py index 61ebc46315..5620c38f8e 100644 --- a/tests/buffer/sql_test.py +++ b/tests/buffer/sql_test.py @@ -6,7 +6,7 @@ from trinity.buffer.reader.sql_reader import SQLReader from trinity.buffer.writer.sql_writer import SQLWriter from trinity.common.config import BufferConfig, StorageConfig -from trinity.common.constants import AlgorithmType, StorageType +from trinity.common.constants import StorageType from trinity.common.experience import Experience db_path = os.path.join(os.path.dirname(__file__), "test.db") @@ -19,7 +19,7 @@ def test_create_sql_buffer(self) -> None: read_batch_size = 4 meta = StorageConfig( name="test_buffer", - algorithm_type=AlgorithmType.PPO, + algorithm_type="ppo", path=f"sqlite:///{db_path}", storage_type=StorageType.SQL, ) diff --git a/tests/explorer/runner_pool_test.py b/tests/explorer/runner_pool_test.py index 036339e747..8a6e262a90 100644 --- a/tests/explorer/runner_pool_test.py +++ b/tests/explorer/runner_pool_test.py @@ -10,7 +10,7 @@ from tests.tools import get_unittest_dataset_config from trinity.buffer.reader.queue_reader import QueueReader from trinity.common.config import InferenceModelConfig, StorageConfig, load_config -from trinity.common.constants import AlgorithmType, StorageType +from trinity.common.constants import StorageType from trinity.common.experience import Experience from trinity.common.models.model import InferenceModel from trinity.common.workflows import Task @@ -105,7 +105,7 @@ def setUp(self): ) = StorageConfig( name="test", storage_type=StorageType.QUEUE, - algorithm_type=AlgorithmType.PPO, + algorithm_type="ppo", ) self.queue = QueueReader( self.config.buffer.trainer_input.experience_buffer, self.config.buffer diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index e83b443c4b..5b2795d952 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -15,7 +15,7 @@ get_unittest_dataset_config, ) from trinity.cli.launcher import bench, both, train -from trinity.common.constants import AlgorithmType, MonitorType, SyncMethod +from trinity.common.constants import MonitorType, SyncMethod class BaseTrainerCase(RayUnittestBase): @@ -119,7 +119,7 @@ class TestTrainerGSM8K(BaseTrainerCase): def test_trainer(self): """Test GSM8K.""" # test both mode - self.config.algorithm.algorithm_type = AlgorithmType.GRPO + self.config.algorithm.algorithm_type = "grpo" self.config.algorithm.repeat_times = 4 # self.config.algorithm.repeat_times = 8 # TODO: used for real testing self.config.algorithm.advantage_fn = "grpo" @@ -157,7 +157,7 @@ class TestTrainerGSM8KWithSFT(BaseTrainerCase): def test_trainer(self): """Test GSM8K With SFT.""" # test both mode - self.config.algorithm.algorithm_type = AlgorithmType.GRPO + 
self.config.algorithm.algorithm_type = "grpo" self.config.algorithm.repeat_times = 4 self.config.algorithm.advantage_fn = "grpo" self.config.algorithm.advantage_fn_args = {} @@ -174,7 +174,7 @@ def test_trainer(self): parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard")) rollout_metrics = parser.metric_list("rollout") self.assertTrue(len(rollout_metrics) > 0) - self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 2) + self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 4) actor_metrics = parser.metric_list("actor") self.assertTrue(len(actor_metrics) > 0) self.assertEqual(parser.metric_max_step(actor_metrics[0]), 2) # SFT @@ -193,7 +193,7 @@ def test_trainer(self): """Test DPO.""" # test both mode self.config.mode = "train" - self.config.algorithm.algorithm_type = AlgorithmType.DPO + self.config.algorithm.algorithm_type = "dpo" self.config.algorithm.policy_loss_fn = "dpo" self.config.algorithm.policy_loss_fn_args = {} # self.config.buffer.batch_size = 32 diff --git a/trinity/algorithm/algorithm.py b/trinity/algorithm/algorithm.py new file mode 100644 index 0000000000..f94798fe85 --- /dev/null +++ b/trinity/algorithm/algorithm.py @@ -0,0 +1,186 @@ +# -*- coding: utf-8 -*- +"""Algorithm classes.""" + +from abc import ABC, ABCMeta +from typing import Dict + +from trinity.buffer.schema.sql_schema import DPODataModel, ExperienceModel, SFTDataModel +from trinity.common.config import Config +from trinity.common.constants import SyncMethod +from trinity.common.experience import Experience, Experiences +from trinity.utils.log import get_logger +from trinity.utils.registry import Registry + +logger = get_logger(__name__) + +ALGORITHM_TYPE = Registry("algorithm") + + +class ConstantMeta(ABCMeta): + def __setattr__(cls, name, value): + if name in cls.__dict__: + raise AttributeError(f"{name} is already defined in {cls.__name__}") + return super().__setattr__(name, value) + + +class AlgorithmType(ABC, metaclass=ConstantMeta): + use_critic: bool + use_reference: bool + use_advantage: bool + use_rollout: bool + can_balance_batch: bool + schema: type + + @classmethod + def gather_experience(cls, exps: list[Experience], pad_token_id: int = 0) -> Experiences: + return Experiences.gather_experiences(exps, pad_token_id) + + @classmethod + def get_default_config(cls) -> Dict: + raise NotImplementedError + + @classmethod + def name(cls) -> str: + return cls._name + + @classmethod + def check_config(cls, config: Config) -> None: + pass + + +@ALGORITHM_TYPE.register_module("sft") +class SFTAlgorithm(AlgorithmType): + """SFT Algorithm.""" + + use_critic: bool = False + use_reference: bool = False + use_advantage: bool = False + use_rollout: bool = False + can_balance_batch: bool = True + schema: type = SFTDataModel + + @classmethod + def get_default_config(cls) -> Dict: + return { + "policy_loss_fn": "sft", + "kl_loss_fn": "none", + "entropy_loss_fn": "none", + } + + +@ALGORITHM_TYPE.register_module("ppo") +class PPOAlgorithm(AlgorithmType): + """PPO Algorithm.""" + + use_critic: bool = True + use_reference: bool = True + use_advantage: bool = True + use_rollout: bool = True + can_balance_batch: bool = True + schema: type = ExperienceModel + + @classmethod + def get_default_config(cls) -> Dict: + return { + "repeat_times": 1, + "policy_loss_fn": "ppo", + "advantage_fn": "ppo", + "kl_penalty_fn": "none", + "kl_loss_fn": "k2", + "entropy_loss_fn": "basic", + } + + +@ALGORITHM_TYPE.register_module("grpo") +class GRPOAlgorithm(AlgorithmType): + """GRPO algorithm.""" + + 
use_critic: bool = False + use_reference: bool = True + use_advantage: bool = True + use_rollout: bool = True + can_balance_batch: bool = True + schema: type = ExperienceModel + + @classmethod + def get_default_config(cls) -> Dict: + return { + "repeat_times": 2, + "policy_loss_fn": "ppo", + "advantage_fn": "grpo", + "kl_penalty_fn": "none", + "kl_loss_fn": "k2", + "entropy_loss_fn": "basic", + } + + +@ALGORITHM_TYPE.register_module("opmd") +class OPMDAlgorithm(AlgorithmType): + """OPMD algorithm.""" + + use_critic: bool = False + use_reference: bool = True + use_advantage: bool = True + use_rollout: bool = True + can_balance_batch: bool = True + schema: type = ExperienceModel + + @classmethod + def get_default_config(cls) -> Dict: + return { + "repeat_times": 2, + "policy_loss_fn": "opmd", + "advantage_fn": "opmd", + "kl_penalty_fn": "none", + "kl_loss_fn": "k2", + "entropy_loss_fn": "basic", + } + + +@ALGORITHM_TYPE.register_module("dpo") +class DPOAlgorithm(AlgorithmType): + """DPO algorithm.""" + + use_critic: bool = False + use_reference: bool = True + use_advantage: bool = False + use_rollout: bool = False + can_balance_batch: bool = False + schema: type = DPODataModel + + @classmethod + def gather_experience(cls, exps: list[Experience], pad_token_id: int = 0) -> Experiences: + return Experiences.gather_dpo_experiences(exps, pad_token_id) + + @classmethod + def get_default_config(cls) -> Dict: + return { + "repeat_times": 2, # fake repeat times + "policy_loss_fn": "dpo", + "kl_loss_fn": "k2", + "entropy_loss_fn": "basic", + } + + @classmethod + def check_config(cls, config: Config) -> None: + if config.mode == "train": + if ( + config.buffer.trainer_input.experience_buffer is None + or not config.buffer.trainer_input.experience_buffer.path + ): + raise ValueError( + "`buffer.trainer_input.experience_buffer.path` is required when `algorithm.algorithm_type == dpo`" + ) + elif config.mode in ["both", "explore"]: + raise ValueError(f"DPO does not support `{config.mode}` mode") + + if config.synchronizer.sync_method != SyncMethod.CHECKPOINT: + config.synchronizer.sync_method = SyncMethod.CHECKPOINT + logger.warning( + "DPO only supports checkpoint synchronization, set `synchronizer.sync_method` to `checkpoint`." + ) + if config.algorithm.repeat_times != 2: + config.algorithm.repeat_times = 2 + logger.warning( + "DPO only supports 2 repeat times, set `algorithm.repeat_times` to 2." 
+ ) # no need to warn diff --git a/trinity/algorithm/algorithm_manager.py b/trinity/algorithm/algorithm_manager.py new file mode 100644 index 0000000000..3c2983c80b --- /dev/null +++ b/trinity/algorithm/algorithm_manager.py @@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- +"""AlgorithmManager for switching between SFT and RFT.""" + +from trinity.algorithm.algorithm import ALGORITHM_TYPE +from trinity.algorithm.entropy_loss_fn.entropy_loss_fn import ENTROPY_LOSS_FN +from trinity.algorithm.kl_fn.kl_fn import KL_FN +from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN +from trinity.common.config import AlgorithmConfig, Config + + +class AlgorithmManager: + def __init__(self, config: Config): + self.config = config + sft_type = ALGORITHM_TYPE.get("sft") + sft_default_config = sft_type.get_default_config() + self.sft_algorithm_config = AlgorithmConfig( + algorithm_type="sft", + **sft_default_config, + ) + policy_fn_cls = POLICY_LOSS_FN.get(self.sft_algorithm_config.policy_loss_fn) + self.sft_algorithm_config.policy_loss_fn_args = policy_fn_cls.default_args() + kl_loss_fn_cls = KL_FN.get(self.sft_algorithm_config.kl_loss_fn) + self.sft_algorithm_config.kl_loss_fn_args = kl_loss_fn_cls.default_args() + entropy_loss_fn_cls = ENTROPY_LOSS_FN.get(self.sft_algorithm_config.entropy_loss_fn) + self.sft_algorithm_config.entropy_loss_fn_args = entropy_loss_fn_cls.default_args() + + def get_current_algorithm_config(self, global_steps: int): + if global_steps <= self.config.buffer.trainer_input.sft_warmup_steps: + return self.sft_algorithm_config + else: + return self.config.algorithm + + def need_save(self, global_steps: int): + return global_steps == self.config.buffer.trainer_input.sft_warmup_steps diff --git a/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py b/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py index 4df9272ca0..cf102dd6b7 100644 --- a/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py +++ b/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py @@ -61,3 +61,25 @@ def __call__( @classmethod def default_args(cls) -> Dict: return {"entropy_coef": 0.0} + + +@ENTROPY_LOSS_FN.register_module("none") +class DummyEntropyLossFn(EntropyLossFn): + """ + Dummy entropy loss function. + """ + + def __init__(self): + pass + + def __call__( + self, + entropy: torch.Tensor, + action_mask: torch.Tensor, + **kwargs, + ) -> Tuple[torch.Tensor, Dict]: + return torch.tensor(0.0), {} + + @classmethod + def default_args(cls) -> Dict: + return {} diff --git a/trinity/algorithm/kl_fn/kl_fn.py b/trinity/algorithm/kl_fn/kl_fn.py index 95d2915a84..62ed48cd49 100644 --- a/trinity/algorithm/kl_fn/kl_fn.py +++ b/trinity/algorithm/kl_fn/kl_fn.py @@ -102,7 +102,7 @@ def default_args(cls): @KL_FN.register_module("none") -class DummyFn(KLFn): +class DummyKLFn(KLFn): """ Dummy KL function. 
""" diff --git a/trinity/buffer/buffer.py b/trinity/buffer/buffer.py index 09ff663c47..9d77dbb379 100644 --- a/trinity/buffer/buffer.py +++ b/trinity/buffer/buffer.py @@ -41,9 +41,9 @@ def get_buffer_reader(storage_config: StorageConfig, buffer_config: BufferConfig elif storage_config.storage_type == StorageType.FILE: from trinity.buffer.reader.file_reader import FILE_READERS - file_read_type = storage_config.algorithm_type - if file_read_type is not None: - file_read_type = file_read_type.value + algorithm_type = storage_config.algorithm_type + if algorithm_type is not None: + file_read_type = algorithm_type else: file_read_type = "rollout" return FILE_READERS.get(file_read_type)(storage_config, buffer_config) diff --git a/trinity/buffer/reader/file_reader.py b/trinity/buffer/reader/file_reader.py index 316d3ae297..58b762d3f2 100644 --- a/trinity/buffer/reader/file_reader.py +++ b/trinity/buffer/reader/file_reader.py @@ -6,9 +6,10 @@ import transformers from datasets import load_dataset +from trinity.algorithm.algorithm import DPOAlgorithm, SFTAlgorithm from trinity.buffer.buffer_reader import BufferReader from trinity.common.config import BufferConfig, StorageConfig -from trinity.common.constants import AlgorithmType, PromptType, ReadStrategy, TaskType +from trinity.common.constants import PromptType, ReadStrategy, TaskType from trinity.common.experience import Experience from trinity.common.rewards import REWARD_FUNCTIONS from trinity.common.workflows import WORKFLOWS, Task @@ -17,7 +18,7 @@ FILE_READERS = Registry("file_readers") -@FILE_READERS.register_module(AlgorithmType.SFT.value) +@FILE_READERS.register_module(SFTAlgorithm.name()) class SFTDataReader(BufferReader): """Reader for SFT file data.""" @@ -96,7 +97,7 @@ def read(self, strategy: Optional[ReadStrategy] = None) -> List: return exp_list -@FILE_READERS.register_module(AlgorithmType.DPO.value) +@FILE_READERS.register_module(DPOAlgorithm.name()) class DPODataReader(BufferReader): def __init__(self, meta: StorageConfig, config: BufferConfig): self.split = meta.split diff --git a/trinity/buffer/schema/sql_schema.py b/trinity/buffer/schema/sql_schema.py index db2e4ca137..21289c7768 100644 --- a/trinity/buffer/schema/sql_schema.py +++ b/trinity/buffer/schema/sql_schema.py @@ -5,9 +5,7 @@ from sqlalchemy import Column, Float, Integer, LargeBinary, String from sqlalchemy.ext.declarative import declarative_base -from trinity.common.constants import AlgorithmType from trinity.common.experience import Experience -from trinity.common.models.utils import tokenize_and_mask_messages_hf Base = declarative_base() @@ -85,6 +83,8 @@ def from_messages( chat_template: Optional[str] = None, ) -> "SFTDataModel": """Convert a list of messages into a single instance of SFT data.""" + from trinity.common.models.utils import tokenize_and_mask_messages_hf + token_ids, action_mask = tokenize_and_mask_messages_hf( tokenizer=tokenizer, messages=messages, @@ -125,22 +125,15 @@ def to_experience(self) -> Experience: return exp -SCHEMA_MAPPING = { - None: TaskModel, - AlgorithmType.SFT: SFTDataModel, - AlgorithmType.PPO: ExperienceModel, - AlgorithmType.GRPO: ExperienceModel, - AlgorithmType.OPMD: ExperienceModel, - AlgorithmType.DPO: DPODataModel, -} - - -def create_dynamic_table(algorithm_type: Union[AlgorithmType | None], table_name: str) -> Any: +def create_dynamic_table(algorithm_type: Union[str | None], table_name: str) -> Any: """Create a dynamic table based on the provided algorithm type and table name.""" - if algorithm_type not in SCHEMA_MAPPING: - 
raise ValueError(f"Unknown schema: {algorithm_type}") + if algorithm_type is None: + base_class = TaskModel + else: + from trinity.algorithm.algorithm import ALGORITHM_TYPE - base_class = SCHEMA_MAPPING[algorithm_type] + algorithm = ALGORITHM_TYPE.get(algorithm_type) + base_class = algorithm.schema table_attrs = { "__tablename__": table_name, diff --git a/trinity/buffer/writer/sql_writer.py b/trinity/buffer/writer/sql_writer.py index 7464064037..e0b0bdf640 100644 --- a/trinity/buffer/writer/sql_writer.py +++ b/trinity/buffer/writer/sql_writer.py @@ -5,6 +5,7 @@ from sqlalchemy.orm import sessionmaker from sqlalchemy.pool import NullPool +from trinity.algorithm.algorithm import ALGORITHM_TYPE from trinity.buffer.buffer_writer import BufferWriter from trinity.buffer.schema import Base, create_dynamic_table from trinity.buffer.utils import retry_session @@ -22,7 +23,8 @@ def __init__(self, meta: StorageConfig, config: BufferConfig) -> None: assert meta.storage_type == StorageType.SQL # we only support write RFT algorithm buffer for now # TODO: support other algorithms - assert meta.algorithm_type.is_rft, "Only RFT buffer is supported for writing." + algorithm = ALGORITHM_TYPE.get(meta.algorithm_type) + assert algorithm.use_rollout, "Only RFT buffer is supported for writing." self.engine = create_engine(meta.path, poolclass=NullPool) self.table_model_cls = create_dynamic_table(meta.algorithm_type, meta.name) diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py index 9dfe4df8ee..6a01bfb688 100644 --- a/trinity/cli/launcher.py +++ b/trinity/cli/launcher.py @@ -8,7 +8,6 @@ import ray from trinity.common.config import Config, load_config -from trinity.common.constants import AlgorithmType from trinity.explorer.explorer import Explorer from trinity.trainer.trainer import Trainer from trinity.utils.log import get_logger @@ -49,20 +48,8 @@ def train(config: Config) -> None: trainer = Trainer.remote(config) ray.get(trainer.prepare.remote()) - if config.buffer.trainer_input.sft_warmup_steps > 0: - while True: - train_continue, train_step_num = ray.get( - trainer.train_one_period.remote(AlgorithmType.SFT) - ) - if train_step_num <= config.buffer.trainer_input.sft_warmup_steps: - logger.info(f"SFT warmup step {train_step_num} finished.") - if not train_continue: - logger.info("SFT warmup finished.") - break - - algo_type = config.algorithm.algorithm_type try: - ray.get(trainer.train.remote(algo_type)) + ray.get(trainer.train.remote()) logger.info("Train finished.") ray.get(trainer.shutdown.remote()) except Exception as e: @@ -93,23 +80,10 @@ def both(config: Config) -> None: # sync weight before training start ray.get([explorer.sync_weight.remote(), trainer.sync_weight.remote()]) - if config.buffer.trainer_input.sft_warmup_steps > 0: - while True: - train_continue, train_step_num = ray.get( - trainer.train_one_period.remote(AlgorithmType.SFT) - ) - if train_step_num <= config.buffer.trainer_input.sft_warmup_steps: - logger.info(f"SFT warmup step {train_step_num} finished.") - if not train_continue: - logger.info("SFT warmup finished.") - break - ray.get([explorer.sync_weight.remote(), trainer.sync_weight.remote()]) - - algo_type = config.algorithm.algorithm_type while True: try: ref_explore = explorer.explore_one_period.remote() - ref_train = trainer.train_one_period.remote(algo_type) + ref_train = trainer.train_one_period.remote() explore_continue, explore_step_num = ray.get(ref_explore) train_continue, train_step_num = ray.get(ref_train) if not explore_continue: diff --git 
a/trinity/common/config.py b/trinity/common/config.py index 91c7790571..dd863edbd3 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -7,7 +7,6 @@ from omegaconf import OmegaConf from trinity.common.constants import ( - AlgorithmType, MonitorType, PromptType, ReadStrategy, @@ -84,7 +83,7 @@ class StorageConfig: rollout_args: GenerationConfig = field(default_factory=GenerationConfig) # ! DO NOT SET, automatically set from algorithm.algorithm_type - algorithm_type: Optional[AlgorithmType] = None + algorithm_type: Optional[str] = None # ! DO NOT SET, automatically set from buffer.total_epochs total_epochs: int = 1 # automatically set @@ -170,27 +169,27 @@ class InferenceModelConfig: class AlgorithmConfig: """Config for algorithm.""" - algorithm_type: AlgorithmType = AlgorithmType.PPO + algorithm_type: str = "ppo" # for GRPO-like algorithms, repeat each task for `repeat_times` times repeat_times: int = 1 - policy_loss_fn: str = "ppo" + policy_loss_fn: Optional[str] = None # "ppo" # If not set, use PolicyLossFn.default_args() policy_loss_fn_args: Optional[dict] = None - advantage_fn: str = "ppo" + advantage_fn: Optional[str] = None # "ppo" # If not set, use AdvantageFn.default_args() advantage_fn_args: Optional[dict] = None - kl_penalty_fn: str = "none" # set to "none" to disable kl penalty in reward + kl_penalty_fn: Optional[str] = None # "none" # set to "none" to disable kl penalty in reward # If not set, use kl_penalty_fn.default_args() kl_penalty_fn_args: Optional[dict] = None - kl_loss_fn: str = "k2" # set to "none" to disable kl loss + kl_loss_fn: Optional[str] = None # "k2" # set to "none" to disable kl loss # If not set, use kl_loss_fn.default_args() kl_loss_fn_args: Optional[dict] = None - entropy_loss_fn: str = "basic" + entropy_loss_fn: Optional[str] = None # "basic" # If not set, use entropy_loss_fn.default_args() entropy_loss_fn_args: Optional[dict] = None @@ -198,6 +197,15 @@ class AlgorithmConfig: # TODO: move this to SFT warmup use_token_level_loss: bool = True + # do not set + algorithm_manager: Optional[Any] = None + + def get_current_algorithm_config(self, global_steps: int): + return self.algorithm_manager.get_current_algorithm_config(global_steps) + + def need_save(self, global_steps: int): + return self.algorithm_manager.need_save(global_steps) + @dataclass class ClusterConfig: @@ -351,32 +359,25 @@ def _check_deprecated(self) -> None: def _check_interval(self) -> None: assert self.synchronizer.sync_interval > 0 - # check eval_interval - if ( - self.mode != "bench" - and self.algorithm.algorithm_type != AlgorithmType.DPO - and self.explorer.eval_interval % self.synchronizer.sync_interval != 0 - ): - self.explorer.eval_interval = ( - max(self.explorer.eval_interval // self.synchronizer.sync_interval, 1) - ) * self.synchronizer.sync_interval - logger.warning( - f"`eval_interval` is not a multiple of `sync_interval`; adjusted to the nearest integer={self.explorer.eval_interval}." 
- ) - - # check save_interval - if ( - self.mode != "bench" - and self.algorithm.algorithm_type != AlgorithmType.DPO - and self.synchronizer.sync_method == SyncMethod.CHECKPOINT - ): - if self.trainer.save_interval != self.synchronizer.sync_interval: + if self.mode != "bench" and self.algorithm.algorithm_type != "dpo": # TODO + # check eval_interval + if self.explorer.eval_interval % self.synchronizer.sync_interval != 0: + self.explorer.eval_interval = ( + max(self.explorer.eval_interval // self.synchronizer.sync_interval, 1) + ) * self.synchronizer.sync_interval logger.warning( - f"When `algorithm.algorithm_type` != `DPO` and `synchronizer.sync_method` == `checkpoint`, " - f"`trainer.save_interval` will be set to " - f"`synchronizer.sync_interval = {self.synchronizer.sync_interval}`." + f"`eval_interval` is not a multiple of `sync_interval`; adjusted to the nearest integer={self.explorer.eval_interval}." ) - self.trainer.save_interval = self.synchronizer.sync_interval + + # check save_interval + if self.synchronizer.sync_method == SyncMethod.CHECKPOINT: + if self.trainer.save_interval != self.synchronizer.sync_interval: + logger.warning( + f"When `algorithm.algorithm_type` != `dpo` and `synchronizer.sync_method` == `checkpoint`, " + f"`trainer.save_interval` will be set to " + f"`synchronizer.sync_interval = {self.synchronizer.sync_interval}`." + ) + self.trainer.save_interval = self.synchronizer.sync_interval def _check_buffer(self) -> None: # noqa: C901 # check explorer_input @@ -440,14 +441,7 @@ def _check_buffer(self) -> None: # noqa: C901 f"Auto set `buffer.trainer_input.experience_buffer` to {self.buffer.trainer_input.experience_buffer}" ) elif self.mode == "train": # TODO: to be check - if self.algorithm.algorithm_type.is_dpo(): - if ( - self.buffer.trainer_input.experience_buffer is None - or not self.buffer.trainer_input.experience_buffer.path - ): - raise ValueError( - "`buffer.trainer_input.experience_buffer.path` is required when `algorithm.algorithm_type == AlgorithmType.DPO`" - ) + pass if self.buffer.trainer_input.experience_buffer is not None: self.buffer.trainer_input.experience_buffer.algorithm_type = ( self.algorithm.algorithm_type @@ -468,7 +462,7 @@ def _check_buffer(self) -> None: # noqa: C901 "`buffer.trainer_input.sft_warmup_dataset` is required when `buffer.trainer_input.sft_warmup_steps` > 0" ) if self.buffer.trainer_input.sft_warmup_dataset is not None: - self.buffer.trainer_input.sft_warmup_dataset.algorithm_type = AlgorithmType.SFT + self.buffer.trainer_input.sft_warmup_dataset.algorithm_type = "sft" # TODO # set read_batch_size / pad_token_id / tokenizer_path self.buffer.read_batch_size = self.buffer.batch_size * self.algorithm.repeat_times @@ -491,6 +485,21 @@ def _check_algorithm(self) -> None: KL_FN, POLICY_LOSS_FN, ) + from trinity.algorithm.algorithm import ALGORITHM_TYPE + + algorithm = ALGORITHM_TYPE.get(self.algorithm.algorithm_type) + algorithm.check_config(self) + default_config = { + "policy_loss_fn": "ppo", + "advantage_fn": "ppo", + "kl_penalty_fn": "none", + "kl_loss_fn": "k2", + "entropy_loss_fn": "basic", + } + default_config.update(algorithm.get_default_config()) + for key, value in default_config.items(): + if getattr(self.algorithm, key, None) is None: + setattr(self.algorithm, key, value) policy_fn_cls = POLICY_LOSS_FN.get(self.algorithm.policy_loss_fn) if policy_fn_cls is None: @@ -526,11 +535,12 @@ def check_and_update(self) -> None: # noqa: C901 """Check and update the config.""" self._check_deprecated() + # check algorithm + 
self._check_algorithm() + # check mode if self.mode not in ["explore", "train", "both", "bench"]: raise ValueError(f"Invalid mode: {self.mode}") - if self.algorithm.algorithm_type == AlgorithmType.DPO and self.mode == "both": - raise ValueError("DPO does not support `both` mode") # prepare for the checkpoint directory if not os.path.isabs(self.checkpoint_root_dir): @@ -545,9 +555,6 @@ def check_and_update(self) -> None: # noqa: C901 if not self.model.critic_model_path: self.model.critic_model_path = self.model.model_path - # check algorithm - self._check_algorithm() - # check explorer if ( self.explorer.rollout_model.engine_type != "vllm_async" @@ -572,17 +579,6 @@ def check_and_update(self) -> None: # noqa: C901 logger.warning( f"`{self.mode}` mode only supports checkpoint synchronization, set `synchronizer.sync_method` to `checkpoint`." ) - if ( - self.algorithm.algorithm_type == AlgorithmType.DPO - and self.synchronizer.sync_method != SyncMethod.CHECKPOINT - ): - self.synchronizer.sync_method = SyncMethod.CHECKPOINT - logger.warning( - "DPO only supports checkpoint synchronization, set `synchronizer.sync_method` to `checkpoint`." - ) - if self.algorithm.algorithm_type == AlgorithmType.DPO and self.algorithm.repeat_times != 2: - self.algorithm.repeat_times = 2 - logger.warning("DPO only supports 2 repeat times, set `algorithm.repeat_times` to 2.") self._check_interval() diff --git a/trinity/common/constants.py b/trinity/common/constants.py index 860cd39027..47b04f853b 100644 --- a/trinity/common/constants.py +++ b/trinity/common/constants.py @@ -62,34 +62,6 @@ class StorageType(CaseInsensitiveEnum): FILE = "file" -class AlgorithmType(CaseInsensitiveEnum): - """Algorithm Type.""" - - SFT = "sft" - PPO = "ppo" - GRPO = "grpo" - OPMD = "opmd" - PAIRWISE_OPMD = "pairwise_opmd" - DPO = "dpo" - - def is_rft(self) -> bool: - """Check if the algorithm is RFT.""" - return self in [ - AlgorithmType.PPO, - AlgorithmType.GRPO, - AlgorithmType.OPMD, - AlgorithmType.PAIRWISE_OPMD, - ] - - def is_sft(self) -> bool: - """Check if the algorithm is SFT.""" - return self == AlgorithmType.SFT - - def is_dpo(self) -> bool: - """Check if the algorithm is DPO.""" - return self == AlgorithmType.DPO - - class MonitorType(CaseInsensitiveEnum): """Monitor Type.""" diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py index e8180f4718..644fe9a8f5 100644 --- a/trinity/common/verl_config.py +++ b/trinity/common/verl_config.py @@ -4,9 +4,8 @@ from omegaconf import OmegaConf +from trinity.algorithm.algorithm import DPOAlgorithm from trinity.common.config import BufferConfig, Config, SynchronizerConfig -from trinity.common.constants import AlgorithmType -from trinity.trainer.verl.ray_trainer import AdvantageEstimator from trinity.utils.log import get_logger logger = get_logger(__name__) @@ -79,7 +78,7 @@ class Actor: checkpoint: Checkpoint = field(default_factory=Checkpoint) optim: Optim = field(default_factory=Optim) fsdp_config: FSDPConfig = field(default_factory=FSDPConfig) - algorithm_type: AlgorithmType = AlgorithmType.PPO + algorithm_type: str = "ppo" # TODO tau: float = 0.001 # strength of regularization w.r.t. 
old / ref policy opmd_baseline: str = "mean" # mean / logavgexp, applicable to opmd use_uid: bool = False # True / False, applicable to pairwise_opmd @@ -95,8 +94,15 @@ class Ref: ulysses_sequence_parallel_size: int = 1 +@dataclass +class _ValKwargs: + do_sample: bool = False + + @dataclass class Rollout: + # do not set + val_kwargs: _ValKwargs = field(default_factory=_ValKwargs) temperature: float = 1.0 n: int = 1 # > 1 for grpo @@ -318,12 +324,6 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901 if adv_fn_args is not None and "lam" in adv_fn_args: self.algorithm.lam = adv_fn_args["lam"] self.actor_rollout_ref.actor.algorithm_type = config.algorithm.algorithm_type - if config.algorithm.algorithm_type == AlgorithmType.PPO: - logger.info("Setting `adv_estimator` to 'gae' for PPO") - self.algorithm.adv_estimator = AdvantageEstimator.GAE.value - elif config.algorithm.algorithm_type in (AlgorithmType.GRPO, AlgorithmType.OPMD): - logger.info("Setting `adv_estimator` to 'grpo' for GRPO/OPMD") - self.algorithm.adv_estimator = AdvantageEstimator.GRPO.value self.actor_rollout_ref.actor.use_kl_loss = config.algorithm.kl_loss_fn != "none" self.actor_rollout_ref.actor.kl_loss_coef = config.algorithm.kl_loss_fn_args["kl_coef"] # type: ignore self.actor_rollout_ref.actor.entropy_coeff = config.algorithm.entropy_loss_fn_args[ # type: ignore @@ -334,7 +334,7 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901 # Need to double check whether this is indeed the case, # and see if adv_estimator can be removed completely. - if self.actor_rollout_ref.actor.algorithm_type.is_dpo(): # for DPO + if self.actor_rollout_ref.actor.algorithm_type == DPOAlgorithm.name(): # for DPO (algorithm_type is now a string identifier) if not self.actor_rollout_ref.actor.use_kl_loss: self.actor_rollout_ref.actor.use_kl_loss = True logger.warning("DPO must use KL loss.") diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index 9c3cc414c7..37257f71ce 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -8,6 +8,7 @@ import ray import torch +from trinity.algorithm.algorithm_manager import AlgorithmManager from trinity.buffer import get_buffer_writer from trinity.buffer.buffer import get_buffer_reader from trinity.common.config import Config @@ -33,6 +34,7 @@ def __init__(self, config: Config): explorer_meta = self.cache.load_explorer() self.step_num = explorer_meta.get("latest_iteration", 0) self.config = config + self.algorithm_manager = AlgorithmManager(config) self.models, self.auxiliary_models = create_inference_models(config) if self.config.mode != "bench": self.experience_buffer = get_buffer_writer( @@ -177,6 +179,15 @@ def explore_one_period(self) -> Tuple[bool, int]: explore_status: whether there are more tasks to explore. 
explore_step_num: the number of explore steps """ + # skip for sft + algo_config = self.algorithm_manager.get_current_algorithm_config(self.step_num + 1) + if algo_config.algorithm_type == "sft": + for _ in range(self.config.synchronizer.sync_interval): + self.step_num += 1 + if self.algorithm_manager.need_save(self.step_num): + break + return True, self.step_num + task_num_per_period = self.config.synchronizer.sync_interval * self.config.buffer.batch_size st = time.time() diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py index c2ad5fec96..95859685ee 100644 --- a/trinity/trainer/trainer.py +++ b/trinity/trainer/trainer.py @@ -12,10 +12,11 @@ import ray +from trinity.algorithm.algorithm import ALGORITHM_TYPE, SFTAlgorithm +from trinity.algorithm.algorithm_manager import AlgorithmManager from trinity.buffer import get_buffer_reader -from trinity.common.config import AlgorithmConfig, Config -from trinity.common.constants import AlgorithmType, SyncMethod -from trinity.common.experience import Experiences +from trinity.common.config import Config +from trinity.common.constants import SyncMethod from trinity.utils.log import get_logger @@ -26,6 +27,7 @@ class Trainer: def __init__(self, config: Config) -> None: self.config = config self.logger = get_logger(__name__) + self.algorithm_manager = AlgorithmManager(config) self.train_buffer = get_buffer_reader( self.config.buffer.trainer_input.experience_buffer, # type: ignore self.config.buffer, @@ -44,86 +46,54 @@ def prepare(self) -> None: """Prepare the trainer.""" self.engine.prepare() - def train(self, algo_type: AlgorithmType = AlgorithmType.PPO): + def train(self): """Train the model.""" while True: - train_status, _ = self.train_step(algo_type) + train_status, _ = self.train_step() if not train_status: break - def train_one_period(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> Tuple[bool, int]: + def train_one_period(self) -> Tuple[bool, int]: """Train for one period. Each period contains `sync_interval` steps. Returns: train_status: Whether to continue training. train_step_num: The number of training steps""" for _ in range(self.config.synchronizer.sync_interval): - train_status, train_step_num = self.train_step(algo_type) + train_status, train_step_num = self.train_step() if not train_status: return False, train_step_num self.logger.info(f"Train step {train_step_num} finished.") return True, train_step_num - def train_step(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> Tuple[bool, int]: + def train_step(self) -> Tuple[bool, int]: """Train one step. - Args: - algo_type (AlgorithmType): The type of data to be used for training. - Defaults to AlgorithmType.PPO. - Returns: bool: Whether to continue training. 
""" - if algo_type.is_sft(): - algorithm_config = AlgorithmConfig( - algorithm_type=AlgorithmType.SFT, - policy_loss_fn="sft", - policy_loss_fn_args={ - "use_token_level_loss": self.config.algorithm.use_token_level_loss - }, - kl_loss_fn="none", - kl_loss_fn_args={}, - entropy_loss_fn="basic", - entropy_loss_fn_args=self.config.algorithm.entropy_loss_fn_args, - ) - self.engine.set_algorithm(algorithm_config) - else: - self.engine.set_algorithm(self.config.algorithm) - if algo_type.is_rft() and self.config.buffer.trainer_input.read_experience_strategy: + algo_config = self.algorithm_manager.get_current_algorithm_config( + self.engine.train_step_num + 1 + ) + algo_type = algo_config.algorithm_type + algorithm = ALGORITHM_TYPE.get(algo_type) + if algorithm.use_rollout: strategy = self.config.buffer.trainer_input.read_experience_strategy else: strategy = None try: - if algo_type.is_sft(): + if algorithm == SFTAlgorithm: exps = self.sft_warmup_buffer.read() else: exps = self.train_buffer.read(strategy=strategy) except StopIteration: self.logger.warning("No more data to train. Stop training.") - return False, 0 # TODO: get the actual step number - - if algo_type.is_sft(): - return self.engine.train_sft_step( - Experiences.gather_experiences( - exps, - pad_token_id=self.config.buffer.pad_token_id, # type: ignore - ) - ) - elif algo_type.is_rft(): - return self.engine.train_rft_step( - Experiences.gather_experiences( - exps, - pad_token_id=self.config.buffer.pad_token_id, # type: ignore - ) - ) - elif algo_type.is_dpo(): - return self.engine.train_dpo_step( - Experiences.gather_dpo_experiences( - exps, - pad_token_id=self.config.buffer.pad_token_id, # type: ignore - ) - ) - else: - raise ValueError(f"Unsupported algorithm type: {algo_type}") + return False, self.engine.train_step_num + + experiences = algorithm.gather_experience( + exps, + pad_token_id=self.config.buffer.pad_token_id, # type: ignore + ) + return self.engine.train_step(experiences) def sync_weight(self) -> None: """Sync the model weight.""" @@ -136,7 +106,7 @@ def flush_log(self, step: int) -> None: def shutdown(self) -> None: # if checkpoint not saved, save the last checkpoint - step_num = self.engine.global_steps - 1 + step_num = self.engine.train_step_num path = os.path.join(self.config.checkpoint_job_dir, f"global_step_{step_num}") if not os.path.isdir(path) or len(os.listdir(path)) == 0: self.engine.save_checkpoint() @@ -150,17 +120,14 @@ class TrainEngineWrapper(ABC): def prepare(self) -> None: """Do some preparation before training started.""" + @property @abstractmethod - def train_rft_step(self, experiences) -> Tuple[bool, int]: - """Train on the RFT data.""" + def train_step_num(self) -> int: + """Get the current training step number.""" @abstractmethod - def train_sft_step(self, experiences) -> Tuple[bool, int]: - """Train on the SFT data.""" - - @abstractmethod - def train_dpo_step(self, experiences) -> Tuple[bool, int]: - """Train on the DPO data.""" + def train_step(self, experiences) -> Tuple[bool, int]: + """Training.""" @abstractmethod def save_checkpoint(self) -> None: @@ -170,10 +137,6 @@ def save_checkpoint(self) -> None: def sync_weight(self) -> None: """Sync the model weight.""" - @abstractmethod - def set_algorithm(self, algorithm_config: AlgorithmConfig) -> None: - """Set training algorithm config.""" - @abstractmethod def shutdown(self) -> None: """Shutdown the engine.""" diff --git a/trinity/trainer/verl/core_algos.py b/trinity/trainer/verl/core_algos.py deleted file mode 100644 index 
f104e0f4f4..0000000000 --- a/trinity/trainer/verl/core_algos.py +++ /dev/null @@ -1,717 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2022 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Modified from core_algos.py -""" - -from abc import ABC, abstractmethod -from collections import defaultdict - -import numpy as np -import torch -import torch.nn.functional as F -import verl.utils.torch_functional as verl_F - -from trinity.common.constants import AlgorithmType - - -class KLController(ABC): - @abstractmethod - def update(self, current_kl, n_steps): - """update value""" - - -class AdaptiveKLController(KLController): - """ - Adaptive KL controller described in the paper: - https://arxiv.org/pdf/1909.08593.pdf - """ - - def __init__(self, init_kl_coef, target_kl, horizon): - self.value = init_kl_coef - self.target = target_kl - self.horizon = horizon - - def update(self, current_kl, n_steps): - target = self.target - proportional_error = np.clip(current_kl / target - 1, -0.2, 0.2) - mult = 1 + proportional_error * n_steps / self.horizon - self.value *= mult - - -class FixedKLController(KLController): - """Fixed KL controller.""" - - def __init__(self, kl_coef): - self.value = kl_coef - - def update(self, current_kl, n_steps): - pass - - -def get_kl_controller(kl_config): - if kl_config.type == "fixed": - return FixedKLController(kl_coef=kl_config.kl_coef) - elif kl_config.type == "adaptive": - assert kl_config.horizon > 0, f"horizon must be larger than 0. Got {kl_config.horizon}" - return AdaptiveKLController( - init_kl_coef=kl_config.kl_coef, - target_kl=kl_config.target_kl, - horizon=kl_config.horizon, - ) - else: - raise ValueError("Unknown kl_ctrl type") - - -def compute_opmd_outcome_advantage( - token_level_rewards: torch.Tensor, - eos_mask: torch.Tensor, - index: torch.Tensor, - opmd_baseline: str = "mean", - tau: float = 1.0, -): - """Modified from compute_grpo_outcome_advantage - - Compute advantage for OPMD, operating only on Outcome reward - (with only one scalar reward for each response). - Args: - token_level_rewards: `(torch.Tensor)` - shape: (bs, response_length) - eos_mask: `(torch.Tensor)` - shape: (bs, response_length) - - Returns: - advantages: `(torch.Tensor)` - shape: (bs, response_length) - Returns: `(torch.Tensor)` - shape: (bs, response_length) - """ - response_length = token_level_rewards.shape[-1] - scores = token_level_rewards.sum(dim=-1) - - id2score = defaultdict(list) - id2baseline = {} - - with torch.no_grad(): - bsz = scores.shape[0] - for i in range(bsz): - id2score[index[i]].append(scores[i]) - for idx in id2score: - if len(id2score[idx]) == 1: - id2baseline[idx] = torch.tensor(0.0) - # TODO: consider id2baseline[idx] = id2score[idx] (so that this sample won't take effect?) 
- elif len(id2score[idx]) > 1: - if opmd_baseline == "mean": - id2baseline[idx] = torch.mean(torch.tensor(id2score[idx])) - elif opmd_baseline == "logavgexp": - rewards_tensor = torch.tensor(id2score[idx]) - # NOTE: we use the fact that logavgexp(x) = logsumexp(x) - log(len(x)). - # Hopefully the logsumexp calculation is numerically stable (as claimed by PyTorch's doc) - # in cases where tau is small... - id2baseline[idx] = tau * ( - torch.logsumexp(rewards_tensor / tau, dim=-1) - - torch.log(torch.tensor(len(id2score[idx]))) - ) - else: - raise NotImplementedError - else: - raise ValueError(f"no score in prompt index: {idx}") - for i in range(bsz): - scores[i] = scores[i] - id2baseline[index[i]] - scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask - - return scores, scores - - -def compute_gae_advantage_return( - token_level_rewards: torch.Tensor, - values: torch.Tensor, - eos_mask: torch.Tensor, - gamma: float, - lam: float, -): - """Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py - - Args: - token_level_rewards: `(torch.Tensor)` - shape: (bs, response_length) - values: `(torch.Tensor)` - shape: (bs, response_length) - eos_mask: `(torch.Tensor)` - shape: (bs, response_length). [EOS] mask. The token after [EOS] have mask zero. - gamma: `(float)` - discounted factor used in RL - lam: `(float)` - lambda value when computing Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438) - - Returns: - advantages: `(torch.Tensor)` - shape: (bs, response_length) - Returns: `(torch.Tensor)` - shape: (bs, response_length) - - """ - with torch.no_grad(): - lastgaelam = 0 - advantages_reversed = [] - gen_len = token_level_rewards.shape[-1] - - # values = values * eos_mask TODO: may use in multi-turn - for t in reversed(range(gen_len)): - nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0 - delta = token_level_rewards[:, t] + gamma * nextvalues - values[:, t] - - lastgaelam = delta + gamma * lam * lastgaelam - # lastgaelam = torch.where( # TODO: may use in multi-turn - # eos_mask[:, t] == 1, delta + gamma * lam * lastgaelam, lastgaelam - # ) - advantages_reversed.append(lastgaelam) - advantages = torch.stack(advantages_reversed[::-1], dim=1) - - returns = advantages + values - advantages = verl_F.masked_whiten(advantages, eos_mask) - return advantages, returns - - -# NOTE(sgm): this implementation only consider outcome supervision, where the reward is a scalar. -def compute_grpo_outcome_advantage( - token_level_rewards: torch.Tensor, - eos_mask: torch.Tensor, - index: torch.Tensor, - epsilon: float = 1e-6, -): - """ - Compute advantage for GRPO, operating only on Outcome reward - (with only one scalar reward for each response). 
- Args: - token_level_rewards: `(torch.Tensor)` - shape: (bs, response_length) - eos_mask: `(torch.Tensor)` - shape: (bs, response_length) - - Returns: - advantages: `(torch.Tensor)` - shape: (bs, response_length) - Returns: `(torch.Tensor)` - shape: (bs, response_length) - """ - response_length = token_level_rewards.shape[-1] - scores = token_level_rewards.sum(dim=-1) - - id2score = defaultdict(list) - id2mean = {} - id2std = {} - - with torch.no_grad(): - bsz = scores.shape[0] - for i in range(bsz): - id2score[index[i]].append(scores[i]) - for idx in id2score: - if len(id2score[idx]) == 1: - id2mean[idx] = torch.tensor(0.0) - id2std[idx] = torch.tensor(1.0) - elif len(id2score[idx]) > 1: - id2mean[idx] = torch.mean(torch.tensor(id2score[idx])) - id2std[idx] = torch.std(torch.tensor([id2score[idx]])) - else: - raise ValueError(f"no score in prompt index: {idx}") - for i in range(bsz): - scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon) - scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask - - return scores, scores - - -def compute_rloo_outcome_advantage( - token_level_rewards: torch.Tensor, - eos_mask: torch.Tensor, - index: torch.Tensor, - epsilon: float = 1e-6, -): - """ - Compute advantage for RLOO based on https://arxiv.org/abs/2402.14740 - Args: - token_level_rewards: `(torch.Tensor)` - shape: (bs, response_length) - eos_mask: `(torch.Tensor)` - shape: (bs, response_length) - - Returns: - advantages: `(torch.Tensor)` - shape: (bs, response_length) - Returns: `(torch.Tensor)` - shape: (bs, response_length) - """ - response_length = token_level_rewards.shape[-1] - scores = token_level_rewards.sum(dim=-1) - - id2score = defaultdict(list) - id2mean = {} - - with torch.no_grad(): - bsz = scores.shape[0] - for i in range(bsz): - id2score[index[i]].append(scores[i]) - for idx in id2score: - if len(id2score[idx]) == 1: - id2mean[idx] = torch.tensor(0.0) - elif len(id2score[idx]) > 1: - id2mean[idx] = torch.mean(torch.tensor(id2score[idx])) - else: - raise ValueError(f"no score in prompt index: {idx}") - for i in range(bsz): - response_num = len(id2score[index[i]]) - if response_num > 1: - scores[i] = scores[i] * response_num / (response_num - 1) - id2mean[ - index[i] - ] * response_num / (response_num - 1) - scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask - - return scores, scores - - -def compute_reinforce_plus_plus_outcome_advantage( - token_level_rewards: torch.Tensor, eos_mask: torch.Tensor, gamma: float -): - """ - Compute advantage for REINFORCE++. 
- This implementation is based on the paper: https://arxiv.org/abs/2501.03262 - Args: - token_level_rewards: `(torch.Tensor)` - shape: (bs, response_length) - eos_mask: `(torch.Tensor)` - shape: (bs, response_length) - - Returns: - advantages: `(torch.Tensor)` - shape: (bs, response_length) - Returns: `(torch.Tensor)` - shape: (bs, response_length) - """ - - with torch.no_grad(): - returns = torch.zeros_like(token_level_rewards) - running_return = 0 - - for t in reversed(range(token_level_rewards.shape[1])): - running_return = token_level_rewards[:, t] + gamma * running_return - returns[:, t] = running_return - - advantages = verl_F.masked_whiten(returns, eos_mask) - advantages = advantages * eos_mask - - return advantages, returns - - -def compute_remax_outcome_advantage( - token_level_rewards: torch.Tensor, reward_baselines: torch.Tensor, eos_mask: torch.Tensor -): - """ - Compute advantage for ReMax, operating only on Outcome reward - This implementation is based on the paper: https://arxiv.org/abs/2310.10505 - - (with only one scalar reward for each response). - Args: - token_level_rewards: `(torch.Tensor)` - shape: (bs, response_length) - reward_baselines: `(torch.Tensor)` - shape: (bs,) - eos_mask: `(torch.Tensor)` - shape: (bs, response_length) - - Returns: - advantages: `(torch.Tensor)` - shape: (bs, response_length) - Returns: `(torch.Tensor)` - shape: (bs, response_length) - """ - response_length = token_level_rewards.shape[-1] - token_level_rewards.sum(dim=-1) - - with torch.no_grad(): - returns = (token_level_rewards * eos_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1]) - advantages = returns - reward_baselines.unsqueeze(-1).tile([1, response_length]) * eos_mask - - return advantages, returns - - -def compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio): - kl = old_log_prob - ref_log_prob - return token_level_scores - kl * kl_ratio - - -def compute_policy_loss(old_log_prob, log_prob, eos_mask, **kwargs): - """Compute policy loss for PPO / OPMD / pairwise OPMD""" - - algorithm_type: AlgorithmType = kwargs.get("algorithm_type", AlgorithmType.PPO) - - if algorithm_type == AlgorithmType.OPMD: - advantages = kwargs.get("advantages") - tau = kwargs.get("tau") - return compute_policy_loss_opmd( - old_log_prob=old_log_prob, - log_prob=log_prob, - advantages=advantages, - eos_mask=eos_mask, - tau=tau, - ) - - elif algorithm_type == AlgorithmType.PAIRWISE_OPMD: - token_level_scores = kwargs.get("token_level_scores") - index = kwargs.get("index") - tau = kwargs.get("tau") - return compute_policy_loss_pairwise_opmd( - old_log_prob=old_log_prob, - log_prob=log_prob, - token_level_scores=token_level_scores, - eos_mask=eos_mask, - index=index, - tau=tau, - ) - - elif algorithm_type.is_rft(): - advantages = kwargs.get("advantages") - cliprange = kwargs.get("cliprange") - return compute_policy_loss_ppo( - old_log_prob=old_log_prob, - log_prob=log_prob, - advantages=advantages, - eos_mask=eos_mask, - cliprange=cliprange, - ) - - else: - raise NotImplementedError(f"Get invalid algorithm_type '{algorithm_type}'.") - - -def compute_policy_loss_dpo( - log_prob, ref_log_prob, eos_mask, loss_type="sigmoid", beta=0.1, label_smoothing=0.0 -): - """Compute policy loss for DPO (Direct Preference Optimization) - - Ref: https://github.com/huggingface/trl/blob/main/trl/trainer/dpo_trainer.py#L918 - - Args: - log_prob: `(torch.Tensor)` - The log probabilities of the chosen responses from the policy model. 
- ref_log_prob: `(torch.Tensor)` - The log probabilities of the chosen responses from the reference model. - loss_type: `(str)` - Default: "sigmoid" - The type of loss function to use. - beta: `(float)` - Default: 0.1 - A temperature parameter that controls the sharpness of the preference signal. - Higher values make the loss more sensitive to small differences in log probabilities. - label_smoothing: `(float)` - Default: 0.0 - A parameter to encode uncertainty about the labels. Adds a small amount of smoothing to the loss - to avoid overconfident predictions. - - Returns: - dpo_loss: `a scalar torch.Tensor` - chosen_diff: `(torch.Tensor)` - rejected_diff: `(torch.Tensor)` - """ - # log_prob: chosen, rejected, chosen, rejected, ... - chosen_log_prob, rejected_log_prob = log_prob[::2], log_prob[1::2] - chosen_mask, rejected_mask = eos_mask[::2], eos_mask[1::2] - chosen_log_prob_sum = (chosen_log_prob * chosen_mask).sum(-1) - rejected_log_prob_sum = (rejected_log_prob * rejected_mask).sum(-1) - - if ref_log_prob is None: - raise NotImplementedError("DPO requires valid ref_log_prob") - chosen_ref_log_prob, rejected_ref_log_prob = ref_log_prob[::2], ref_log_prob[1::2] - chosen_ref_log_prob_sum = (chosen_ref_log_prob * chosen_mask).sum(-1) - rejected_ref_log_prob_sum = (rejected_ref_log_prob * rejected_mask).sum(-1) - - # compute logits - chosen_ratios = chosen_log_prob_sum - chosen_ref_log_prob_sum - rejected_ratios = rejected_log_prob_sum - rejected_ref_log_prob_sum - logits = chosen_ratios - rejected_ratios - - if loss_type == "sigmoid": - losses = ( - -F.logsigmoid(beta * logits) * (1 - label_smoothing) - - F.logsigmoid(-beta * logits) * label_smoothing - ) - loss = losses.mean() - - else: - raise NotImplementedError(f"loss_type {loss_type} is not supported in DPO") - - chosen_reward = beta * chosen_ratios.detach() - rejected_reward = beta * rejected_ratios.detach() - return loss, chosen_reward, rejected_reward - - -def compute_policy_loss_pairwise_opmd( - old_log_prob, log_prob, token_level_scores, eos_mask, index, tau -): - """Compute policy loss for pairwise_opmd - - NOTE: NOT TESTED YET - - TODO: allow using old_log_prob; for now we just discard it. - - NOTE: use token_level_scores rather than token_level_rewards, because we're not sure yet - whether this algorithm is compatible with kl penalty as negative reward - - Args: - old_log_prob: `(torch.Tensor)` - shape: (bs, response_length) - log_prob: `(torch.Tensor)` - shape: (bs, response_length) - token_level_scores: `(torch.Tensor)` - shape: (bs, response_length) - eos_mask: `(torch.Tensor)` - shape: (bs, response_length) - index: `(torch.Tensor)` or None (when use_uid is False) - tau: `float` - - Returns: - opmd_loss: `a scalar torch.Tensor` - pairwise_opmd loss - pg_clipfrac: (float) - a float number indicating the fraction of policy gradient loss being clipped - ppo_kl: (float) ... 
(TODO, confirm that this is only used for logging stats) - - """ - - # dummy computation - log_prob_diff = log_prob - log_prob - pg_clipfrac = verl_F.masked_mean(torch.gt(log_prob_diff, log_prob_diff).float(), eos_mask) - ppo_kl = verl_F.masked_mean(-log_prob_diff, eos_mask) - - # loss for pairwise_opmd - scores = token_level_scores.sum(dim=-1) - action_level_log_prob = (log_prob * eos_mask).sum(dim=-1) - diffs = scores - tau * (action_level_log_prob - action_level_log_prob.detach()) - - if index is None: - normalizer = eos_mask.sum() * max(1.0, tau) - opmd_loss = (diffs - diffs.mean()).square().sum() / normalizer - else: - opmd_loss = None - unique_index = list(set(index.tolist())) - for idx in unique_index: - subdiff = diffs[index == idx] - if subdiff.shape[0] == 1: - continue - # subloss = len(subdiff) * subdiff.square().sum() - subdiff.sum().square() - subloss = (subdiff - subdiff.mean()).square().sum() - if opmd_loss is None: - opmd_loss = subloss - else: - opmd_loss = opmd_loss + subloss - normalizer = eos_mask.sum() * max(1.0, tau) - opmd_loss = opmd_loss / normalizer - - # NOTE: return pg_clipfrac and ppo_kl merely for compatibility with original compute_policy_loss - return opmd_loss, pg_clipfrac, ppo_kl - - -def compute_policy_loss_opmd(old_log_prob, log_prob, advantages, eos_mask, tau): - """The OPMD counterpart of verl's original compute_policy_loss (now renamed as compute_policy_loss_ppo) - - Args: - old_log_prob: `(torch.Tensor)` - shape: (bs, response_length) - log_prob: `(torch.Tensor)` - shape: (bs, response_length) - advantages: `(torch.Tensor)` - shape: (bs, response_length) - eos_mask: `(torch.Tensor)` - shape: (bs, response_length) - tau: `float` - - Returns: - opmd_loss: `a scalar torch.Tensor` - opmd loss - pg_clipfrac: (float) - a float number indicating the fraction of policy gradient loss being clipped - ppo_kl: (float) ... (TODO, confirm that this is only used for logging stats) - - """ - log_prob_diff = log_prob - old_log_prob - pg_clipfrac = verl_F.masked_mean( - torch.gt(log_prob_diff, log_prob_diff).float(), eos_mask - ) # meaningless - ppo_kl = verl_F.masked_mean(-log_prob_diff, eos_mask) - - # --- version 0: kimi-opmd --- - - # # the original quadratic loss in OPMD can be reformulated as follows - # pg_losses = -advantages * log_prob - # pg_loss = verl_F.masked_sum(pg_losses, eos_mask) - - # reg_losses = (log_prob_diff * eos_mask).sum(dim=-1).square() - # reg_loss = reg_losses.sum() - - # opmd_loss = (pg_loss + 0.5 * tau * reg_loss) / eos_mask.sum() - # # NOTE: this implementation uses batch-wise normalization; - # # would it be beneficial to use trajectory-wise or group-wise normalization? - - # opmd_loss = opmd_loss / max(1.0, tau) # for stability when tau is large - - # --- version 1: min-opmd (minimalistic, but theoretically grounded) --- - - pg_losses = -advantages * log_prob - opmd_loss = verl_F.masked_mean(pg_losses, eos_mask) - opmd_loss = opmd_loss / (1.0 + tau) # for regularization (w.r.t. 
current pi_theta) - - # NOTE: return pg_clipfrac and ppo_kl merely for compatibility with original compute_policy_loss - return opmd_loss, pg_clipfrac, ppo_kl - - -def compute_policy_loss_ppo(old_log_prob, log_prob, advantages, eos_mask, cliprange): - """Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122 - - Args: - old_log_prob: `(torch.Tensor)` - shape: (bs, response_length) - log_prob: `(torch.Tensor)` - shape: (bs, response_length) - advantages: `(torch.Tensor)` - shape: (bs, response_length) - eos_mask: `(torch.Tensor)` - shape: (bs, response_length) - cliprange: (float) - The clip range used in PPO. See https://arxiv.org/abs/1707.06347 - - Returns: - pg_loss: `a scalar torch.Tensor` - policy gradient loss computed via PPO - pg_clipfrac: (float) - a float number indicating the fraction of policy gradient loss being clipped - - """ - negative_approx_kl = log_prob - old_log_prob - ratio = torch.exp(negative_approx_kl) - ppo_kl = verl_F.masked_mean(-negative_approx_kl, eos_mask) - - pg_losses = -advantages * ratio - pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange) - - pg_loss = verl_F.masked_mean(torch.max(pg_losses, pg_losses2), eos_mask) - pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses).float(), eos_mask) - return pg_loss, pg_clipfrac, ppo_kl - - -def compute_policy_loss_sft(log_prob, eos_mask): - """Simple way to compute SFT loss, unified with PG loss - - Args: - log_prob: `(torch.Tensor)` - shape: (bs, response_length) - eos_mask: `(torch.Tensor)` - shape: (bs, response_length) - - Returns: - sft_loss: `a scalar torch.Tensor` - pg_clipfrac: dummy value, merely for compatibility - ppo_kl: dummy value, merely for compatibility - - """ - log_prob_diff = log_prob - log_prob.detach() - pg_clipfrac = verl_F.masked_mean(torch.gt(log_prob_diff, log_prob_diff).float(), eos_mask) - ppo_kl = verl_F.masked_mean(-log_prob_diff, eos_mask) - - sft_loss = verl_F.masked_mean(-log_prob, eos_mask) - - # Return pg_clipfrac and ppo_kl merely for compatibility with original compute_policy_loss - return sft_loss, pg_clipfrac, ppo_kl - - -def compute_entropy_loss(logits, eos_mask): - """Compute Categorical entropy loss - - Args: - logits: `(torch.Tensor)` - shape: (bs, response_length, vocab_size) - eos_mask: `(torch.Tensor)` - shape: (bs, response_length) - - Returns: - entropy: a scalar torch.Tensor - - """ - # compute entropy - entropy = verl_F.entropy_from_logits(logits) # (bs, response_len) - entropy_loss = verl_F.masked_mean(entropy, mask=eos_mask) - return entropy_loss - - -def compute_value_loss(vpreds, returns, values, eos_mask, cliprange_value): - """Compute the value loss. 
Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1151 - - Args: - vpreds (`torch.FloatTensor`): - Predicted values of the value head, shape (`batch_size`, `response_length`) - values (`torch.FloatTensor`): - Old values of value head, shape (`batch_size`, `response_length`) - returns: (`torch.FloatTensor`): - Ground truth returns, shape (`batch_size`, `response_length`) - - Returns: - vf_loss: a scalar (`torch.FloatTensor`): - value function loss - vf_clipfrac: a float - The ratio of vf being clipped - - """ - vpredclipped = verl_F.clip_by_value(vpreds, values - cliprange_value, values + cliprange_value) - vf_losses1 = (vpreds - returns) ** 2 - vf_losses2 = (vpredclipped - returns) ** 2 - vf_loss = 0.5 * verl_F.masked_mean(torch.max(vf_losses1, vf_losses2), eos_mask) - vf_clipfrac = verl_F.masked_mean(torch.gt(vf_losses2, vf_losses1).float(), eos_mask) - return vf_loss, vf_clipfrac - - -def kl_penalty( - logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty -) -> torch.FloatTensor: - """Compute KL divergence given logprob and ref_logprob. - Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1104 - - Args: - logprob: - ref_logprob: - - Returns: - - """ - if kl_penalty == "kl": - return logprob - ref_logprob - - if kl_penalty == "abs": - return (logprob - ref_logprob).abs() - - if kl_penalty == "mse": - return 0.5 * (logprob - ref_logprob).square() - - # J. Schulman. Approximating kl divergence, 2020. - # # URL http://joschu.net/blog/kl-approx.html. - if kl_penalty == "low_var_kl": - kl = ref_logprob - logprob - ratio = torch.exp(kl) - kld = (ratio - kl - 1).contiguous() - return torch.clamp(kld, min=-10, max=10) - - if kl_penalty == "full": - # so, here logprob and ref_logprob should contain the logits for every token in vocabulary - raise NotImplementedError - - raise NotImplementedError diff --git a/trinity/trainer/verl/dp_actor.py b/trinity/trainer/verl/dp_actor.py index a7705fc6a0..595084ac02 100644 --- a/trinity/trainer/verl/dp_actor.py +++ b/trinity/trainer/verl/dp_actor.py @@ -31,9 +31,9 @@ from verl.workers.actor import BasePPOActor from trinity.algorithm import ENTROPY_LOSS_FN, KL_FN, POLICY_LOSS_FN +from trinity.algorithm.kl_fn.kl_fn import DummyKLFn from trinity.algorithm.utils import prefix_metrics from trinity.common.config import AlgorithmConfig -from trinity.common.constants import AlgorithmType __all__ = ["DataParallelPPOActor"] @@ -55,11 +55,11 @@ def __init__( self.use_ulysses_sp = self.ulysses_sequence_parallel_size > 1 self.compute_entropy_from_logits = torch.compile(verl_F.entropy_from_logits, dynamic=True) - self.algorithm_type = AlgorithmType.PPO self.policy_loss_fn = None + self.kl_loss_fn = None + self.entropy_loss_fn = None def set_algorithm(self, algorithm_config: AlgorithmConfig): - self.algorithm_type = algorithm_config.algorithm_type self.policy_loss_fn = POLICY_LOSS_FN.get(algorithm_config.policy_loss_fn)( **algorithm_config.policy_loss_fn_args ) @@ -299,7 +299,7 @@ def update_policy(self, data: DataProto): # noqa: C901 for trinity_key in self.policy_loss_fn.select_keys: verl_key = select_keys_trinity2verl[trinity_key] select_keys.append(verl_key) - if self.config.use_kl_loss: + if not isinstance(self.kl_loss_fn, DummyKLFn): select_keys.append("ref_log_prob") select_keys = list(set(select_keys)) batch = data.select(batch_keys=select_keys).batch @@ -388,7 +388,7 @@ def update_policy(self, data: DataProto): # noqa: C901 ) # compute entropy loss from entropy - entropy_loss, 
entropy_loss_metrics = self.entropy_loss_fn( + entropy_loss, entropy_loss_metrics = self.entropy_loss_fn( # type: ignore entropy=entropy, action_mask=response_mask, ) @@ -403,11 +403,13 @@ def update_policy(self, data: DataProto): # noqa: C901 kl_loss, kl_loss_metrics = self.kl_loss_fn.calculate_kl_loss( logprob=log_prob, - ref_logprob=data["ref_log_prob"], + ref_logprob=data.get("ref_log_prob", None), response_mask=response_mask, ) prefix_metrics( - src_metrics=kl_loss_metrics, prefix="actor", dst_metrics=micro_batch_metrics + src_metrics=kl_loss_metrics, + prefix="actor", + dst_metrics=micro_batch_metrics, ) policy_loss = policy_loss + kl_loss diff --git a/trinity/trainer/verl/ray_trainer.py b/trinity/trainer/verl/ray_trainer.py deleted file mode 100644 index 5d883d05bb..0000000000 --- a/trinity/trainer/verl/ray_trainer.py +++ /dev/null @@ -1,816 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Modified from ray_trainer.py -""" - -import os -from contextlib import contextmanager -from dataclasses import dataclass, field -from enum import Enum -from typing import Dict, Type - -import numpy as np -import ray -import torch -from codetiming import Timer -from omegaconf import OmegaConf, open_dict -from torch.utils.data import RandomSampler, SequentialSampler -from torchdata.stateful_dataloader import StatefulDataLoader -from verl import DataProto -from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto -from verl.single_controller.base import Worker -from verl.single_controller.ray import ( - RayClassWithInitArgs, - RayResourcePool, - RayWorkerGroup, -) -from verl.single_controller.ray.base import create_colocated_worker_cls -from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path -from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn -from verl.utils.seqlen_balancing import ( - get_seqlen_balanced_partitions, - log_seqlen_unbalance, -) -from verl.utils.torch_functional import masked_mean -from verl.utils.tracking import ValidationGenerationsLogger - -from trinity.common.constants import AlgorithmType -from trinity.trainer.verl import core_algos - -WorkerType = Type[Worker] - - -class Role(Enum): - """ - To create more roles dynamically, you can subclass Role and add new members - """ - - Actor = 0 - Rollout = 1 - ActorRollout = 2 - Critic = 3 - RefPolicy = 4 - RewardModel = 5 - ActorRolloutRef = 6 - - -class AdvantageEstimator(str, Enum): - """ - Using an enumeration class to avoid spelling errors in adv_estimator - """ - - GAE = "gae" - GRPO = "grpo" - REINFORCE_PLUS_PLUS = "reinforce_plus_plus" - REMAX = "remax" - RLOO = "rloo" - - -@dataclass -class ResourcePoolManager: - """ - Define a resource pool specification. Resource pool will be initialized first. 
- Mapping - """ - - resource_pool_spec: dict[str, list[int]] - mapping: dict[Role, str] - resource_pool_dict: dict[str, RayResourcePool] = field(default_factory=dict) - - def create_resource_pool(self): - for resource_pool_name, process_on_nodes in self.resource_pool_spec.items(): - # max_colocate_count means the number of WorkerGroups (i.e. processes) in each RayResourcePool - # For FSDP backend, we recommend using max_colocate_count=1 that merge all WorkerGroups into one. - # For Megatron backend, we recommend using max_colocate_count>1 that can utilize different WorkerGroup for differnt models - resource_pool = RayResourcePool( - process_on_nodes=process_on_nodes, - use_gpu=True, - max_colocate_count=1, - name_prefix=resource_pool_name, - ) - self.resource_pool_dict[resource_pool_name] = resource_pool - - self._check_resource_available() - - def get_resource_pool(self, role: Role) -> RayResourcePool: - """Get the resource pool of the worker_cls""" - return self.resource_pool_dict[self.mapping[role]] - - def get_n_gpus(self) -> int: - """Get the number of gpus in this cluster.""" - return sum( - [ - n_gpus - for process_on_nodes in self.resource_pool_spec.values() - for n_gpus in process_on_nodes - ] - ) - - def _check_resource_available(self): - """Check if the resource pool can be satisfied in this ray cluster.""" - node_available_resources = ray.state.available_resources_per_node() - node_available_gpus = { - node: node_info.get("GPU", 0) for node, node_info in node_available_resources.items() - } - - # check total required gpus can be satisfied - total_available_gpus = sum(node_available_gpus.values()) - total_required_gpus = sum( - [ - n_gpus - for process_on_nodes in self.resource_pool_spec.values() - for n_gpus in process_on_nodes - ] - ) - if total_available_gpus < total_required_gpus: - raise ValueError( - f"Total available GPUs {total_available_gpus} is less than total desired GPUs {total_required_gpus}" - ) - - # check each resource pool can be satisfied, O(#resource_pools * #nodes) - for resource_pool_name, process_on_nodes in self.resource_pool_spec.items(): - num_gpus, num_nodes = process_on_nodes[0], len(process_on_nodes) - for node, available_gpus in node_available_gpus.items(): - if available_gpus >= num_gpus: - node_available_gpus[node] -= num_gpus - num_nodes -= 1 - if num_nodes == 0: - break - if num_nodes > 0: - raise ValueError( - f"Resource pool {resource_pool_name}: {num_gpus}*{num_nodes} cannot be satisfied in this ray cluster" - ) - - -def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty="kl"): - responses = data.batch["responses"] - response_length = responses.size(1) - token_level_scores = data.batch["token_level_scores"] - batch_size = data.batch.batch_size[0] - attention_mask = data.batch["attention_mask"] - # response_mask = attention_mask[:, -response_length:] - response_mask = data.batch["response_mask"] - assert response_mask.shape == attention_mask[:, -response_length:].shape - - # compute kl between ref_policy and current policy - if "ref_log_prob" in data.batch.keys(): - kld = core_algos.kl_penalty( - data.batch["old_log_probs"], data.batch["ref_log_prob"], kl_penalty=kl_penalty - ) # (batch_size, response_length) - kld = kld * response_mask - beta = kl_ctrl.value - else: - beta = 0 - kld = torch.zeros_like(response_mask, dtype=torch.float32) - - token_level_rewards = token_level_scores - beta * kld - - current_kl = masked_mean(kld, mask=response_mask, axis=-1) # average over sequence - current_kl = 
torch.mean(current_kl, dim=0).item() - - # according to https://github.com/huggingface/trl/blob/951ca1841f29114b969b57b26c7d3e80a39f75a0/trl/trainer/ppo_trainer.py#L837 - kl_ctrl.update(current_kl=current_kl, n_steps=batch_size) - data.batch["token_level_rewards"] = token_level_rewards - - metrics = {"critic/kl": current_kl, "critic/kl_coeff": beta} - - return data, metrics - - -def compute_response_mask(data: DataProto): - responses = data.batch["responses"] - response_length = responses.size(1) - attention_mask = data.batch["attention_mask"] - return attention_mask[:, -response_length:] - - -@contextmanager -def _timer(name: str, timing_raw: Dict[str, float]): - with Timer(name=name, logger=None) as timer: - yield - timing_raw[name] = timer.last - - -class RayPPOTrainer(object): - """ - Note that this trainer runs on the driver process on a single CPU/GPU node. - """ - - # TODO: support each role have individual ray_worker_group_cls, - # i.e., support different backend of different role - def __init__( - self, - config, - tokenizer, - role_worker_mapping: dict[Role, WorkerType], - resource_pool_manager: ResourcePoolManager, - ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup, - processor=None, - reward_fn=None, - val_reward_fn=None, - ): - # assert torch.cuda.is_available(), 'cuda must be available on driver' - - self.tokenizer = tokenizer - self.processor = processor - self.config = config - self.reward_fn = reward_fn - self.val_reward_fn = val_reward_fn - - self.hybrid_engine = config.actor_rollout_ref.hybrid_engine - assert self.hybrid_engine, "Currently, only support hybrid engine" - - if self.hybrid_engine: - assert Role.ActorRollout in role_worker_mapping, f"{role_worker_mapping.keys()=}" - - self.role_worker_mapping = role_worker_mapping - self.resource_pool_manager = resource_pool_manager - self.use_reference_policy = Role.RefPolicy in role_worker_mapping - self.use_rm = Role.RewardModel in role_worker_mapping - self.ray_worker_group_cls = ray_worker_group_cls - self.validation_generations_logger = ValidationGenerationsLogger() - - # define KL control - if self.use_reference_policy: - self.kl_ctrl = core_algos.get_kl_controller(config.algorithm.kl_ctrl) - else: - self.kl_ctrl = core_algos.FixedKLController(kl_coef=0.0) - - if ( - self.config.actor_rollout_ref.actor.get("algorithm_type", AlgorithmType.PPO) - != AlgorithmType.PPO - ): - self.use_critic = False - elif self.config.algorithm.adv_estimator == AdvantageEstimator.GAE: - self.use_critic = True - elif self.config.algorithm.adv_estimator in [ - AdvantageEstimator.GRPO, - AdvantageEstimator.REINFORCE_PLUS_PLUS, - AdvantageEstimator.REMAX, - AdvantageEstimator.RLOO, - ]: - self.use_critic = False - else: - raise NotImplementedError - - self._validate_config() - self._create_dataloader() - - def _validate_config(self): # noqa: C901 - config = self.config - # number of GPUs total - n_gpus = config.trainer.n_gpus_per_node * config.trainer.nnodes - - # 1. Check total batch size for data correctness - real_train_batch_size = config.data.train_batch_size * config.actor_rollout_ref.rollout.n - assert ( - real_train_batch_size % n_gpus == 0 - ), f"real_train_batch_size ({real_train_batch_size}) must be divisible by total n_gpus ({n_gpus})." - - # A helper function to check "micro_batch_size" vs "micro_batch_size_per_gpu" - # We throw an error if the user sets both. The new convention is "..._micro_batch_size_per_gpu". 
- def check_mutually_exclusive(mbs, mbs_per_gpu, name: str): - if mbs is None and mbs_per_gpu is None: - raise ValueError( - f"[{name}] Please set at least one of '{name}.micro_batch_size' or " - f"'{name}.micro_batch_size_per_gpu'." - ) - - if mbs is not None and mbs_per_gpu is not None: - raise ValueError( - f"[{name}] You have set both '{name}.micro_batch_size' AND " - f"'{name}.micro_batch_size_per_gpu'. Please remove '{name}.micro_batch_size' " - f"because only '*_micro_batch_size_per_gpu' is supported (the former is deprecated)." - ) - - if not config.actor_rollout_ref.actor.use_dynamic_bsz: - # actor: ppo_micro_batch_size vs. ppo_micro_batch_size_per_gpu - check_mutually_exclusive( - config.actor_rollout_ref.actor.ppo_micro_batch_size, - config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu, - "actor_rollout_ref.actor", - ) - - # reference: log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu - check_mutually_exclusive( - config.actor_rollout_ref.ref.log_prob_micro_batch_size, - config.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu, - "actor_rollout_ref.ref", - ) - - # The rollout section also has log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu - check_mutually_exclusive( - config.actor_rollout_ref.rollout.log_prob_micro_batch_size, - config.actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu, - "actor_rollout_ref.rollout", - ) - - if self.use_critic and not config.critic.use_dynamic_bsz: - # Check for critic micro-batch size conflicts - check_mutually_exclusive( - config.critic.ppo_micro_batch_size, - config.critic.ppo_micro_batch_size_per_gpu, - "critic", - ) - - # Check for reward model micro-batch size conflicts - if config.reward_model.enable and not config.reward_model.use_dynamic_bsz: - check_mutually_exclusive( - config.reward_model.micro_batch_size, - config.reward_model.micro_batch_size_per_gpu, - "reward_model", - ) - - # Actor - # if NOT dynamic_bsz, we must ensure: - # ppo_mini_batch_size is divisible by ppo_micro_batch_size - # ppo_micro_batch_size * sequence_parallel_size >= n_gpus - if not config.actor_rollout_ref.actor.use_dynamic_bsz: - assert ( - config.data.train_batch_size >= config.actor_rollout_ref.actor.ppo_mini_batch_size - ) - sp_size = config.actor_rollout_ref.actor.get("ulysses_sequence_parallel_size", 1) - if config.actor_rollout_ref.actor.ppo_micro_batch_size is not None: - assert ( - config.actor_rollout_ref.actor.ppo_mini_batch_size - % config.actor_rollout_ref.actor.ppo_micro_batch_size - == 0 - ) - assert config.actor_rollout_ref.actor.ppo_micro_batch_size * sp_size >= n_gpus - - # critic - if self.use_critic and not config.critic.use_dynamic_bsz: - assert config.data.train_batch_size >= config.critic.ppo_mini_batch_size - sp_size = config.critic.get("ulysses_sequence_parallel_size", 1) - if config.critic.ppo_micro_batch_size is not None: - assert config.critic.ppo_mini_batch_size % config.critic.ppo_micro_batch_size == 0 - assert config.critic.ppo_micro_batch_size * sp_size >= n_gpus - - # Check if use_remove_padding is enabled when using sequence parallelism for fsdp - if config.actor_rollout_ref.actor.strategy == "fsdp": - if ( - config.actor_rollout_ref.actor.get("ulysses_sequence_parallel_size", 1) > 1 - or config.actor_rollout_ref.ref.get("ulysses_sequence_parallel_size", 1) > 1 - ): - assert ( - config.actor_rollout_ref.model.use_remove_padding - ), "When using sequence parallelism for actor/ref policy, you must enable `use_remove_padding`." 
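For reference, the batch-size invariants enforced by the removed `_validate_config` above can be written out with hypothetical numbers. This is an illustrative sketch only; none of the values come from a real config in this patch.

# Sketch of the invariants checked above, with assumed example values.
train_batch_size = 32        # config.data.train_batch_size
rollout_n = 4                # config.actor_rollout_ref.rollout.n
n_gpus = 8                   # trainer.n_gpus_per_node * trainer.nnodes
ppo_mini_batch_size = 16     # actor_rollout_ref.actor.ppo_mini_batch_size
ppo_micro_batch_size = 4     # actor_rollout_ref.actor.ppo_micro_batch_size
sp_size = 2                  # actor_rollout_ref.actor.ulysses_sequence_parallel_size

real_train_batch_size = train_batch_size * rollout_n    # 128 sequences per training step
assert real_train_batch_size % n_gpus == 0              # must split evenly across all GPUs
assert train_batch_size >= ppo_mini_batch_size          # each step yields at least one mini-batch
assert ppo_mini_batch_size % ppo_micro_batch_size == 0  # micro-batches tile each mini-batch exactly
assert ppo_micro_batch_size * sp_size >= n_gpus         # at least one micro sample per data-parallel rank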
- - if self.use_critic and config.critic.strategy == "fsdp": - if config.critic.get("ulysses_sequence_parallel_size", 1) > 1: - assert ( - config.critic.model.use_remove_padding - ), "When using sequence parallelism for critic, you must enable `use_remove_padding`." - - if config.data.get("val_batch_size", None) is not None: - print( - "WARNING: val_batch_size is deprecated. Validation datasets are sent to inference engines as a whole batch, which will schedule the memory themselves." - ) - - print("[validate_config] All configuration checks passed successfully!") - - def _create_dataloader(self): - # TODO: we have to make sure the batch size is divisible by the dp size - self.train_dataset = RLHFDataset( - parquet_files=self.config.data.train_files, - tokenizer=self.tokenizer, - processor=self.processor, - prompt_key=self.config.data.prompt_key, - image_key=self.config.data.get("image_key", "images"), - max_prompt_length=self.config.data.max_prompt_length, - filter_prompts=True, - return_raw_chat=self.config.data.get("return_raw_chat", False), - truncation="error", - ) - # use sampler for better ckpt resume - if self.config.data.shuffle: - train_dataloader_generator = torch.Generator() - train_dataloader_generator.manual_seed(self.config.data.get("seed", 1)) - sampler = RandomSampler( - data_source=self.train_dataset, generator=train_dataloader_generator - ) - else: - sampler = SequentialSampler(data_source=self.train_dataset) - - self.train_dataloader = StatefulDataLoader( - dataset=self.train_dataset, - batch_size=self.config.data.train_batch_size, - num_workers=8, - drop_last=True, - collate_fn=collate_fn, - sampler=sampler, - ) - - self.val_dataset = RLHFDataset( - parquet_files=self.config.data.val_files, - tokenizer=self.tokenizer, - processor=self.processor, - prompt_key=self.config.data.prompt_key, - image_key=self.config.data.get("image_key", "images"), - max_prompt_length=self.config.data.max_prompt_length, - filter_prompts=True, - return_raw_chat=self.config.data.get("return_raw_chat", False), - truncation="error", - ) - self.val_dataloader = StatefulDataLoader( - dataset=self.val_dataset, - # Validation datasets are sent to inference engines as a whole batch, - # which will schedule the memory themselves. - batch_size=len(self.val_dataset), - num_workers=8, - shuffle=False, - drop_last=False, - collate_fn=collate_fn, - ) - - assert len(self.train_dataloader) >= 1 - assert ( - len(self.val_dataloader) == 1 - ), "Validation dataloader must have a single batch, which inference engines will schedule the memory themselves." - - print(f"Size of train dataloader: {len(self.train_dataloader)}") - - # inject total_training_steps to actor/critic optim_config. This is hacky. 
- total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs - - if self.config.trainer.total_training_steps is not None: - total_training_steps = self.config.trainer.total_training_steps - - self.total_training_steps = total_training_steps - print(f"Total training steps: {self.total_training_steps}") - - OmegaConf.set_struct(self.config, True) - with open_dict(self.config): - self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps - self.config.critic.optim.total_training_steps = total_training_steps - - def _maybe_log_val_generations(self, inputs, outputs, scores): - """Log a table of validation samples to the configured logger (wandb or swanlab)""" - - generations_to_log = self.config.trainer.val_generations_to_log_to_wandb - - if generations_to_log == 0: - return - - import numpy as np - - # Create tuples of (input, output, score) and sort by input text - samples = list(zip(inputs, outputs, scores)) - samples.sort(key=lambda x: x[0]) # Sort by input text - - # Use fixed random seed for deterministic shuffling - rng = np.random.RandomState(42) - rng.shuffle(samples) - - # Take first N samples after shuffling - samples = samples[:generations_to_log] - - # Log to each configured logger - self.validation_generations_logger.log( - self.config.trainer.logger, samples, self.global_steps - ) - - def _validate(self): - reward_tensor_lst = [] - data_source_lst = [] - - # Lists to collect samples for the table - sample_inputs = [] - sample_outputs = [] - sample_scores = [] - - for test_data in self.val_dataloader: - test_batch = DataProto.from_single_dict(test_data) - - # we only do validation on rule-based rm - if ( - self.config.reward_model.enable - and test_batch[0].non_tensor_batch["reward_model"]["style"] == "model" - ): - return {} - - # Store original inputs - input_ids = test_batch.batch["input_ids"] - input_texts = [ - self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids - ] - sample_inputs.extend(input_texts) - - if "multi_modal_inputs" in test_batch.non_tensor_batch.keys(): - test_gen_batch = test_batch.pop( - batch_keys=["input_ids", "attention_mask", "position_ids"], - non_tensor_batch_keys=[ - "raw_prompt_ids", - "multi_modal_data", - "multi_modal_inputs", - ], - ) - else: - test_gen_batch = test_batch.pop( - batch_keys=["input_ids", "attention_mask", "position_ids"], - non_tensor_batch_keys=["raw_prompt_ids"], - ) - - test_gen_batch.meta_info = { - "eos_token_id": self.tokenizer.eos_token_id, - "pad_token_id": self.tokenizer.pad_token_id, - "recompute_log_prob": False, - "do_sample": False, - "validate": True, - } - - # pad to be divisible by dp_size - test_gen_batch_padded, pad_size = pad_dataproto_to_divisor( - test_gen_batch, self.actor_rollout_wg.world_size - ) - test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences( - test_gen_batch_padded - ) - # unpad - test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size) - print("validation generation end") - - # Store generated outputs - output_ids = test_output_gen_batch.batch["responses"] - output_texts = [ - self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids - ] - sample_outputs.extend(output_texts) - - test_batch = test_batch.union(test_output_gen_batch) - - # evaluate using reward_function - reward_tensor = self.val_reward_fn(test_batch) - - # Store scores - scores = reward_tensor.sum(-1).cpu().tolist() - sample_scores.extend(scores) - - reward_tensor_lst.append(reward_tensor) - 
data_source_lst.append( - test_batch.non_tensor_batch.get("data_source", ["unknown"] * reward_tensor.shape[0]) - ) - - self._maybe_log_val_generations( - inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores - ) - - reward_tensor = torch.cat(reward_tensor_lst, dim=0).sum(-1).cpu() # (batch_size,) - data_sources = np.concatenate(data_source_lst, axis=0) - - # evaluate test_score based on data source - data_source_reward = {} - for i in range(reward_tensor.shape[0]): - data_source = data_sources[i] - if data_source not in data_source_reward: - data_source_reward[data_source] = [] - data_source_reward[data_source].append(reward_tensor[i].item()) - - metric_dict = {} - for data_source, rewards in data_source_reward.items(): - metric_dict[f"val/test_score/{data_source}"] = np.mean(rewards) - - return metric_dict - - def init_workers(self): - """Init resource pool and worker group""" - self.resource_pool_manager.create_resource_pool() - - self.resource_pool_to_cls = { - pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values() - } - - # create actor and rollout - if self.hybrid_engine: - resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout) - actor_rollout_cls = RayClassWithInitArgs( - cls=self.role_worker_mapping[Role.ActorRollout], - config=self.config.actor_rollout_ref, - role="actor", - ) - self.resource_pool_to_cls[resource_pool]["actor"] = actor_rollout_cls - else: - raise NotImplementedError - - # create critic - if self.use_critic: - resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic) - critic_cls = RayClassWithInitArgs( - cls=self.role_worker_mapping[Role.Critic], config=self.config.critic - ) - self.resource_pool_to_cls[resource_pool]["critic"] = critic_cls - - # create reference policy if needed - if self.use_reference_policy: - resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy) - ref_policy_cls = RayClassWithInitArgs( - self.role_worker_mapping[Role.RefPolicy], - config=self.config.actor_rollout_ref, - role="ref", - ) - self.resource_pool_to_cls[resource_pool]["ref"] = ref_policy_cls - - # create a reward model if reward_fn is None - if self.use_rm: - # we create a RM here - resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel) - rm_cls = RayClassWithInitArgs( - self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model - ) - self.resource_pool_to_cls[resource_pool]["rm"] = rm_cls - - # initialize WorkerGroup - # NOTE: if you want to use a different resource pool for each role, which can support different parallel size, - # you should not use `create_colocated_worker_cls`. Instead, directly pass different resource pool to different worker groups. - # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information. - all_wg = {} - self.wg_dicts = [] - for resource_pool, class_dict in self.resource_pool_to_cls.items(): - worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) - wg_dict = self.ray_worker_group_cls( - resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls - ) - spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) - all_wg.update(spawn_wg) - # keep the referece of WorkerDict to support ray >= 2.31. 
Ref: https://github.com/ray-project/ray/pull/45699 - self.wg_dicts.append(wg_dict) - - if self.use_critic: - self.critic_wg = all_wg["critic"] - self.critic_wg.init_model() - - if self.use_reference_policy: - self.ref_policy_wg = all_wg["ref"] - self.ref_policy_wg.init_model() - - if self.use_rm: - self.rm_wg = all_wg["rm"] - self.rm_wg.init_model() - - # we should create rollout at the end so that vllm can have a better estimation of kv cache memory - self.actor_rollout_wg = all_wg["actor"] - self.actor_rollout_wg.init_model() - - def _save_checkpoint(self): - # path: given_path + `/global_step_{global_steps}` + `/actor` - local_global_step_folder = os.path.join( - self.config.trainer.default_local_dir, f"global_step_{self.global_steps}" - ) - - print(f"local_global_step_folder: {local_global_step_folder}") - actor_local_path = os.path.join(local_global_step_folder, "actor") - - actor_remote_path = ( - None - if self.config.trainer.default_hdfs_dir is None - else os.path.join( - self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "actor" - ) - ) - - remove_previous_ckpt_in_save = self.config.trainer.get( - "remove_previous_ckpt_in_save", False - ) - if remove_previous_ckpt_in_save: - print( - "Warning: remove_previous_ckpt_in_save is deprecated, set max_actor_ckpt_to_keep=1 and max_critic_ckpt_to_keep=1 instead" - ) - max_actor_ckpt_to_keep = ( - self.config.trainer.get("max_actor_ckpt_to_keep", None) - if not remove_previous_ckpt_in_save - else 1 - ) - max_critic_ckpt_to_keep = ( - self.config.trainer.get("max_critic_ckpt_to_keep", None) - if not remove_previous_ckpt_in_save - else 1 - ) - - self.actor_rollout_wg.save_checkpoint( - actor_local_path, - actor_remote_path, - self.global_steps, - max_ckpt_to_keep=max_actor_ckpt_to_keep, - ) - - if self.use_critic: - critic_local_path = os.path.join(local_global_step_folder, "critic") - critic_remote_path = ( - None - if self.config.trainer.default_hdfs_dir is None - else os.path.join( - self.config.trainer.default_hdfs_dir, - f"global_step_{self.global_steps}", - "critic", - ) - ) - self.critic_wg.save_checkpoint( - critic_local_path, - critic_remote_path, - self.global_steps, - max_ckpt_to_keep=max_critic_ckpt_to_keep, - ) - - # save dataloader - dataloader_local_path = os.path.join(local_global_step_folder, "data.pt") - dataloader_state_dict = self.train_dataloader.state_dict() - torch.save(dataloader_state_dict, dataloader_local_path) - - # latest checkpointed iteration tracker (for atomic usage) - local_latest_checkpointed_iteration = os.path.join( - self.config.trainer.default_local_dir, "latest_checkpointed_iteration.txt" - ) - with open(local_latest_checkpointed_iteration, "w") as f: - f.write(str(self.global_steps)) - - def _load_checkpoint(self): - if self.config.trainer.resume_mode == "disable": - return 0 - - # load from hdfs - if self.config.trainer.default_hdfs_dir is not None: - raise NotImplementedError("load from hdfs is not implemented yet") - else: - checkpoint_folder = self.config.trainer.default_local_dir # TODO: check path - if not os.path.isabs(checkpoint_folder): - working_dir = os.getcwd() - checkpoint_folder = os.path.join(working_dir, checkpoint_folder) - global_step_folder = find_latest_ckpt_path(checkpoint_folder) # None if no latest - - # find global_step_folder - if self.config.trainer.resume_mode == "auto": - if global_step_folder is None: - print("Training from scratch") - return 0 - else: - if self.config.trainer.resume_mode == "resume_path": - assert isinstance( - 
self.config.trainer.resume_from_path, str - ), "resume ckpt must be str type" - assert ( - "global_step_" in self.config.trainer.resume_from_path - ), "resume ckpt must specify the global_steps" - global_step_folder = self.config.trainer.resume_from_path - if not os.path.isabs(global_step_folder): - working_dir = os.getcwd() - global_step_folder = os.path.join(working_dir, global_step_folder) - print(f"Load from checkpoint folder: {global_step_folder}") - # set global step - self.global_steps = int(global_step_folder.split("global_step_")[-1]) - - print(f"Setting global step to {self.global_steps}") - print(f"Resuming from {global_step_folder}") - - actor_path = os.path.join(global_step_folder, "actor") - critic_path = os.path.join(global_step_folder, "critic") - # load actor - self.actor_rollout_wg.load_checkpoint( - actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load - ) - # load critic - if self.use_critic: - self.critic_wg.load_checkpoint( - critic_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load - ) - - # load dataloader, - # TODO: from remote not implemented yet - dataloader_local_path = os.path.join(global_step_folder, "data.pt") - if os.path.exists(dataloader_local_path): - dataloader_state_dict = torch.load(dataloader_local_path, weights_only=False) - self.train_dataloader.load_state_dict(dataloader_state_dict) - else: - print( - f"Warning: No dataloader state found at {dataloader_local_path}, will start from scratch" - ) - - def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqlen"): - """Reorder the data on single controller such that each dp rank gets similar total tokens""" - attention_mask = batch.batch["attention_mask"] - batch_size = attention_mask.shape[0] - global_seqlen_lst = ( - batch.batch["attention_mask"].view(batch_size, -1).sum(-1).tolist() - ) # (train_batch_size,) - world_size = self.actor_rollout_wg.world_size - global_partition_lst = get_seqlen_balanced_partitions( - global_seqlen_lst, k_partitions=world_size, equal_size=True - ) - # reorder based on index. 
The data will be automatically equally partitioned by dispatch function - global_idx = torch.tensor([j for partition in global_partition_lst for j in partition]) - batch.reorder(global_idx) - global_balance_stats = log_seqlen_unbalance( - seqlen_list=global_seqlen_lst, partitions=global_partition_lst, prefix=logging_prefix - ) - metrics.update(global_balance_stats) diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py index 83e3480dc3..d040c329dd 100644 --- a/trinity/trainer/verl_trainer.py +++ b/trinity/trainer/verl_trainer.py @@ -19,24 +19,27 @@ compute_timing_metrics, reduce_metrics, ) -from verl.utils import hf_tokenizer -from verl.utils.fs import copy_local_path_from_hdfs - -from trinity.algorithm import ADVANTAGE_FN, KL_FN -from trinity.algorithm.utils import prefix_metrics -from trinity.common.config import AlgorithmConfig, Config -from trinity.common.constants import AlgorithmType -from trinity.common.experience import Experiences -from trinity.trainer.trainer import TrainEngineWrapper -from trinity.trainer.verl.ray_trainer import ( +from verl.trainer.ppo.ray_trainer import ( DataProto, + RayClassWithInitArgs, RayPPOTrainer, RayWorkerGroup, ResourcePoolManager, Role, _timer, + create_colocated_worker_cls, find_latest_ckpt_path, ) +from verl.utils import hf_tokenizer +from verl.utils.fs import copy_local_path_from_hdfs + +from trinity.algorithm import ADVANTAGE_FN, KL_FN +from trinity.algorithm.algorithm import ALGORITHM_TYPE, SFTAlgorithm +from trinity.algorithm.algorithm_manager import AlgorithmManager +from trinity.algorithm.utils import prefix_metrics +from trinity.common.config import Config +from trinity.common.experience import Experiences +from trinity.trainer.trainer import TrainEngineWrapper from trinity.utils.monitor import Monitor @@ -119,6 +122,19 @@ def __init__( resource_pool_manager = ResourcePoolManager( resource_pool_spec=resource_pool_spec, mapping=mapping ) + self.algorithm_config = global_config.algorithm + self.algorithm = None + self.algorithm_manager = AlgorithmManager(global_config) + + # specify advantage function for various rft algorithms + algorithm = ALGORITHM_TYPE.get(self.algorithm_config.algorithm_type) + if algorithm.use_advantage: + self.advantage_fn = ADVANTAGE_FN.get(self.algorithm_config.advantage_fn)( + **self.algorithm_config.advantage_fn_args + ) + self.kl_fn = KL_FN.get(self.algorithm_config.kl_penalty_fn)( + **self.algorithm_config.kl_penalty_fn_args + ) super().__init__( config, @@ -128,15 +144,6 @@ def __init__( ray_worker_group_cls, ) self.init_workers() - self.algorithm_type = AlgorithmType.PPO - - # specify advantage function for various rft algorithms - algo_config = global_config.algorithm - if algo_config.algorithm_type.is_rft(): - self.advantage_fn = ADVANTAGE_FN.get(algo_config.advantage_fn)( - **algo_config.advantage_fn_args - ) - self.kl_fn = KL_FN.get(algo_config.kl_penalty_fn)(**algo_config.kl_penalty_fn_args) self.logger = Monitor( project=config.trainer.project_name, @@ -146,20 +153,109 @@ def __init__( ) self.reset_experiences_example_table() + def _validate_config(self): # TODO + algorithm = ALGORITHM_TYPE.get(self.algorithm_config.algorithm_type) + self.use_critic = algorithm.use_critic + super()._validate_config() + + def init_workers(self): + """Init resource pool and worker group""" + self.resource_pool_manager.create_resource_pool() + + self.resource_pool_to_cls = { + pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values() + } + + # create actor and rollout + if 
self.hybrid_engine: + resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout) + actor_rollout_cls = RayClassWithInitArgs( + cls=self.role_worker_mapping[Role.ActorRollout], + config=self.config.actor_rollout_ref, + role="actor", + ) + self.resource_pool_to_cls[resource_pool]["actor"] = actor_rollout_cls + else: + raise NotImplementedError + + # create critic + if self.use_critic: + resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic) + critic_cls = RayClassWithInitArgs( + cls=self.role_worker_mapping[Role.Critic], config=self.config.critic + ) + self.resource_pool_to_cls[resource_pool]["critic"] = critic_cls + + # create reference policy if needed + if self.use_reference_policy: + resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy) + ref_policy_cls = RayClassWithInitArgs( + self.role_worker_mapping[Role.RefPolicy], + config=self.config.actor_rollout_ref, + role="ref", + ) + self.resource_pool_to_cls[resource_pool]["ref"] = ref_policy_cls + + # create a reward model if reward_fn is None + if self.use_rm: + # we create a RM here + resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel) + rm_cls = RayClassWithInitArgs( + self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model + ) + self.resource_pool_to_cls[resource_pool]["rm"] = rm_cls + + # initialize WorkerGroup + # NOTE: if you want to use a different resource pool for each role, which can support different parallel size, + # you should not use `create_colocated_worker_cls`. Instead, directly pass different resource pool to different worker groups. + # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information. + all_wg = {} + self.wg_dicts = [] + for resource_pool, class_dict in self.resource_pool_to_cls.items(): + worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) + wg_dict = self.ray_worker_group_cls( + resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls + ) + spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) + all_wg.update(spawn_wg) + # keep the referece of WorkerDict to support ray >= 2.31. Ref: https://github.com/ray-project/ray/pull/45699 + self.wg_dicts.append(wg_dict) + + if self.use_critic: + self.critic_wg = all_wg["critic"] + self.critic_wg.init_model() + + if self.use_reference_policy: + self.ref_policy_wg = all_wg["ref"] + self.ref_policy_wg.init_model() + + if self.use_rm: + self.rm_wg = all_wg["rm"] + self.rm_wg.init_model() + + # we should create rollout at the end so that vllm can have a better estimation of kv cache memory + self.actor_rollout_wg = all_wg["actor"] + self.actor_rollout_wg.init_model() + def reset_experiences_example_table(self): self.experiences_example_table = pd.DataFrame( columns=["step", "reward", "prompt", "response"] ) + @property + def train_step_num(self) -> int: + return self.global_steps + def prepare(self): self.actor_rollout_wg.setup_weight_sync_group() + # The global step counter, initialized to 0 + # It represents the total number of training steps completed so far + # We increment this counter at the beginning of each training step self.global_steps = 0 - self.sft_warmup_step_num = 0 # load checkpoint before doing anything self._load_checkpoint() - self.sft_warmup_step_num = min(self.global_steps, self.config.trainer.sft_warmup_steps) # perform validation before training # currently, we only support validation using the reward_function. 
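The trainer-side changes above replace the removed AlgorithmType enum with string keys resolved through the ALGORITHM_TYPE registry; the resolved algorithm class carries capability flags (use_critic, use_reference, use_advantage, can_balance_batch) that the trainer branches on instead of comparing enum values. A minimal sketch of that lookup pattern follows; the "grpo" key is assumed to be registered and is used only for illustration.

# Illustrative sketch, not part of the patch: string-keyed registry dispatch.
from trinity.algorithm import ADVANTAGE_FN
from trinity.algorithm.algorithm import ALGORITHM_TYPE

algorithm = ALGORITHM_TYPE.get("grpo")            # key assumed to be registered
if algorithm.use_advantage:                       # capability flag, not an enum comparison
    advantage_fn_cls = ADVANTAGE_FN.get("grpo")   # advantage estimator resolved by name
    # in __init__ this is instantiated as advantage_fn_cls(**algorithm_config.advantage_fn_args)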
@@ -170,190 +266,60 @@ def prepare(self): if self.config.trainer.get("val_only", False): return - # we start from step 1 - def _create_dataloader(self): self.train_dataloader = _InternalDataLoader(self.config) # TODO: compute total training steps self.total_training_steps = self.config.trainer.total_training_steps or sys.maxsize - def train_dpo_step(self, experiences: Experiences) -> Tuple[bool, int]: - self.global_steps += 1 - metrics = {} - timing_raw = {} - - with _timer("step", timing_raw): - # generate a batch - attention_mask = experiences.attention_masks - cumsum = torch.cumsum(attention_mask, dim=-1) - position_ids = torch.clip(cumsum - 1, 0, None).long() - - batch = DataProto.from_single_dict( - { - "uid": np.array(experiences.run_ids), # useless - "position_ids": position_ids, - "input_ids": experiences.tokens.long(), - "responses": experiences.tokens[:, experiences.prompt_length :].long(), - "attention_mask": attention_mask.long(), - "response_mask": ( - experiences.action_masks[:, experiences.prompt_length :].long() - if hasattr(experiences, "action_masks") - and experiences.action_masks is not None - else attention_mask[:, experiences.prompt_length :].long() - ), - } - ) - batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.temperature - - # self._balance_batch(batch, metrics=metrics) # _balance_batch will shuffle the batch, which will break DPO - # TODO: implement a new _balance_batch for DPO - - # compute global_valid tokens - batch.meta_info["global_token_num"] = torch.sum( - batch.batch["attention_mask"], dim=-1 - ).tolist() - - if self.use_reference_policy: - # compute reference log_prob - with _timer("ref", timing_raw): - ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) - batch = batch.union(ref_log_prob) - - # update actor - with _timer("update_actor", timing_raw): - actor_output = self.actor_rollout_wg.update_actor(batch) - actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) - metrics.update(actor_output_metrics) - - # collect metrics - metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) - - self.logger.log(data=metrics, step=self.global_steps) - - # save checkpoint - if ( - self.config.trainer.save_freq > 0 - and self.global_steps % self.config.trainer.save_freq == 0 - ): - with _timer("save_checkpoint", timing_raw): - self._save_checkpoint() - - if self.global_steps >= self.total_training_steps: - if ( - self.config.trainer.save_freq > 0 - and self.global_steps % self.config.trainer.save_freq != 0 - ): - with _timer("save_checkpoint", timing_raw): - self._save_checkpoint() - # stop training - return False, self.global_steps - else: - # continue - return True, self.global_steps - - def train_sft_step(self, experiences: Experiences) -> Tuple[bool, int]: - if self.sft_warmup_step_num >= self.config.trainer.sft_warmup_steps: - return False, self.global_steps - self.global_steps += 1 - metrics = {} - timing_raw = {} - - with _timer("step", timing_raw): - # generate a batch - attention_mask = experiences.attention_masks - cumsum = torch.cumsum(attention_mask, dim=-1) - position_ids = torch.clip(cumsum - 1, 0, None).long() - - batch = DataProto.from_single_dict( - { - "uid": np.array(experiences.run_ids), - "position_ids": position_ids, - "input_ids": experiences.tokens.long(), - "responses": experiences.tokens[:, experiences.prompt_length :].long(), - "attention_mask": attention_mask.long(), - "response_mask": ( - experiences.action_masks[:, experiences.prompt_length :].long() - if 
hasattr(experiences, "action_masks") - and experiences.action_masks is not None - else attention_mask[:, experiences.prompt_length :].long() - ), - } - ) - batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.temperature - - self._balance_batch(batch, metrics=metrics) # TODO this may affect multi-turn - - # compute global_valid tokens - batch.meta_info["global_token_num"] = torch.sum( - batch.batch["attention_mask"], dim=-1 - ).tolist() - - if self.use_reference_policy: - # compute reference log_prob - with _timer("ref", timing_raw): - ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) - batch = batch.union(ref_log_prob) - - # update actor - with _timer("update_actor", timing_raw): - actor_output = self.actor_rollout_wg.update_actor(batch) - actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) - metrics.update(actor_output_metrics) - - # collect metrics - metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) - - # TODO: log as sft metrics - self.logger.log(data=metrics, step=self.global_steps) - self.sft_warmup_step_num += 1 - train_status = True - if self.sft_warmup_step_num == self.config.trainer.sft_warmup_steps: - self.logger.log( - data={"sft_warmup_steps": self.sft_warmup_step_num}, - step=self.global_steps, - ) - with _timer("save_checkpoint", timing_raw): - self._save_checkpoint() - train_status = False - return train_status, self.global_steps - - def train_rft_step(self, experiences: Experiences) -> Tuple[bool, int]: + def train_step(self, experiences: Experiences) -> Tuple[bool, int]: self.global_steps += 1 metrics = {} timing_raw = {} + algorithm_config = self.algorithm_manager.get_current_algorithm_config(self.global_steps) + algorithm = ALGORITHM_TYPE.get(algorithm_config.algorithm_type) + if self.algorithm != algorithm: + self.actor_rollout_wg.set_algorithm(algorithm_config) + if self.algorithm == SFTAlgorithm: + self.sft_to_rft() + self.algorithm = algorithm with _timer("step", timing_raw): # Convert rewards to token_level_rewards attention_mask = experiences.attention_masks - token_level_rewards = torch.zeros(attention_mask.shape, dtype=experiences.rewards.dtype) cumsum = torch.cumsum(attention_mask, dim=-1) - eos_mask_idx = cumsum.argmax(dim=-1) position_ids = torch.clip(cumsum - 1, 0, None).long() - token_level_rewards[ - torch.arange(experiences.batch_size), eos_mask_idx - ] = experiences.rewards - token_level_rewards = token_level_rewards[:, experiences.prompt_length :] - - batch = DataProto.from_single_dict( - { - "uid": np.array(experiences.run_ids), - "position_ids": position_ids, - "input_ids": experiences.tokens.long(), - "responses": experiences.tokens[:, experiences.prompt_length :].long(), - "attention_mask": attention_mask.long(), - "response_mask": ( - experiences.action_masks[:, experiences.prompt_length :].long() - if hasattr(experiences, "action_masks") - and experiences.action_masks is not None - else attention_mask[:, experiences.prompt_length :].long() - ), - "token_level_scores": token_level_rewards, - "old_log_probs": experiences.logprobs[:, experiences.prompt_length :], # type: ignore - } - ) + batch_dict = { + "uid": np.array(experiences.run_ids), + "position_ids": position_ids, + "input_ids": experiences.tokens.long(), + "responses": experiences.tokens[:, experiences.prompt_length :].long(), + "attention_mask": attention_mask.long(), + "response_mask": ( + experiences.action_masks[:, experiences.prompt_length :].long() + if hasattr(experiences, "action_masks") and 
experiences.action_masks is not None + else attention_mask[:, experiences.prompt_length :].long() + ), + } + if self.algorithm.use_advantage: + token_level_rewards = torch.zeros( + attention_mask.shape, dtype=experiences.rewards.dtype + ) + eos_mask_idx = cumsum.argmax(dim=-1) + token_level_rewards[ + torch.arange(experiences.batch_size), eos_mask_idx + ] = experiences.rewards + token_level_rewards = token_level_rewards[:, experiences.prompt_length :] + batch_dict.update( + { + "token_level_scores": token_level_rewards, + "old_log_probs": experiences.logprobs[:, experiences.prompt_length :], # type: ignore + } + ) + + batch = DataProto.from_single_dict(batch_dict) batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.temperature - if self.config.trainer.balance_batch: + if self.algorithm.can_balance_batch and self.config.trainer.balance_batch: self._balance_batch(batch, metrics=metrics) # TODO this may affect multi-turn # compute global_valid tokens @@ -361,34 +327,37 @@ def train_rft_step(self, experiences: Experiences) -> Tuple[bool, int]: batch.batch["attention_mask"], dim=-1 ).tolist() - if self.use_reference_policy: + if self.algorithm.use_reference: # ref_logprob may not be used # compute reference log_prob with _timer("ref", timing_raw): ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) batch = batch.union(ref_log_prob) - # compute values - if self.use_critic: + if self.algorithm.use_critic: with _timer("values", timing_raw): values = self.critic_wg.compute_values(batch) batch = batch.union(values) - with _timer("adv", timing_raw): - # compute kl penalty - batch, kl_metrics = self.kl_fn.apply_kl_penalty_to_reward(batch) - metrics.update(prefix_metrics(kl_metrics, prefix="critic")) - # compute advantages, executed on the driver process - batch, _ = self.advantage_fn(batch) + if self.algorithm.use_advantage: + with _timer("adv", timing_raw): + # compute kl penalty + batch, kl_metrics = self.kl_fn.apply_kl_penalty_to_reward(batch) + metrics.update(prefix_metrics(kl_metrics, prefix="critic")) + # compute advantages, executed on the driver process + batch, _ = self.advantage_fn(batch) - # update critic - if self.use_critic: + # update critic + if self.algorithm.use_critic: with _timer("update_critic", timing_raw): critic_output = self.critic_wg.update_critic(batch) critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"]) metrics.update(critic_output_metrics) # implement critic warmup - if self.config.trainer.critic_warmup <= self.global_steps: + if ( + not self.algorithm.use_critic + or self.config.trainer.critic_warmup <= self.global_steps + ): # update actor with _timer("update_actor", timing_raw): actor_output = self.actor_rollout_wg.update_actor(batch) @@ -404,31 +373,29 @@ def train_rft_step(self, experiences: Experiences) -> Tuple[bool, int]: self._save_checkpoint() # collect metrics - metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) + if self.algorithm.use_advantage: # TODO + metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) n_gpus = self.resource_pool_manager.get_n_gpus() metrics.update( compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus) ) - if self.config.enable_preview: + if self.algorithm.use_advantage and self.config.enable_preview: # TODO self._log_experiences(experiences) # TODO: make a canonical logger that supports various backend self.logger.log(data=metrics, 
step=self.global_steps) - if self.global_steps >= self.total_training_steps: + train_status = self.global_steps < self.total_training_steps + if not train_status or self.algorithm_manager.need_save(self.global_steps): if ( - self.config.trainer.save_freq > 0 - and self.global_steps % self.config.trainer.save_freq != 0 + self.config.trainer.save_freq == 0 + or self.global_steps % self.config.trainer.save_freq != 0 ): with _timer("save_checkpoint", timing_raw): self._save_checkpoint() - # stop training - return False, self.global_steps - else: - # continue - return True, self.global_steps + return train_status, self.global_steps def _log_single_experience( self, experiences: Experiences, idx: int, skip_special_tokens: bool @@ -477,12 +444,6 @@ def save_checkpoint(self) -> None: def sync_weight(self) -> None: self.actor_rollout_wg.sync_weight() - def set_algorithm(self, algorithm_config: AlgorithmConfig) -> None: - self.actor_rollout_wg.set_algorithm(algorithm_config) - if self.algorithm_type.is_sft() and (not algorithm_config.algorithm_type.is_sft()): - self.sft_to_rft() - self.algorithm_type = algorithm_config.algorithm_type - def sft_to_rft(self) -> None: # load from hdfs if self.config.trainer.default_hdfs_dir is not None: @@ -513,9 +474,9 @@ def sft_to_rft(self) -> None: global_step_folder = os.path.join(working_dir, global_step_folder) print(f"Load from checkpoint folder: {global_step_folder}") # set global step - self.global_steps = int(global_step_folder.split("global_step_")[-1]) + global_steps = int(global_step_folder.split("global_step_")[-1]) + assert self.global_steps == global_steps + 1 - print(f"Setting global step to {self.global_steps}") print(f"Resuming from {global_step_folder}") actor_path = os.path.join(global_step_folder, "actor")
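One behavioral note on the sft_to_rft change at the end: instead of rewinding self.global_steps to the checkpointed step, it now parses the step out of the checkpoint folder name and asserts that the trainer is exactly one step past it, i.e. the SFT warmup checkpoint was written at the end of the immediately preceding step. A small sketch of the parsing, with a hypothetical path:

# Illustrative sketch only; the path is hypothetical.
global_step_folder = "/tmp/checkpoints/global_step_42"
global_steps = int(global_step_folder.split("global_step_")[-1])
assert global_steps == 42
# train_step increments self.global_steps before switching algorithms, so at the
# SFT -> RFT transition the expected relation is self.global_steps == global_steps + 1.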