diff --git a/tests/template/config.yaml b/tests/template/config.yaml
index a83a82655f..c83d938c66 100644
--- a/tests/template/config.yaml
+++ b/tests/template/config.yaml
@@ -8,6 +8,11 @@ algorithm:
   policy_loss_fn: ppo
   policy_loss_fn_args:
     clip_range: 0.2
+  advantage_fn_type: ppo_adv_fn
+  advantage_fn_args:
+    gamma: 1.0
+    lam: 1.0
+
 model:
   model_path: ''
   max_prompt_tokens: 2048
diff --git a/trinity/algorithm/__init__.py b/trinity/algorithm/__init__.py
index 51d3da8317..170507663f 100644
--- a/trinity/algorithm/__init__.py
+++ b/trinity/algorithm/__init__.py
@@ -1,4 +1,4 @@
-from trinity.algorithm.advantage_fn.advantage_fn import ADVANTAGE_FN, AdvantageFn
+from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
 from trinity.algorithm.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
 
 __all__ = [
diff --git a/trinity/algorithm/advantage_fn/__init__.py b/trinity/algorithm/advantage_fn/__init__.py
index e69de29bb2..7bcf682e4b 100644
--- a/trinity/algorithm/advantage_fn/__init__.py
+++ b/trinity/algorithm/advantage_fn/__init__.py
@@ -0,0 +1,20 @@
+from trinity.algorithm.advantage_fn.advantage_fn import ADVANTAGE_FN, AdvantageFn
+from trinity.algorithm.advantage_fn.grpo_advantage import GRPOAdvantageFn
+from trinity.algorithm.advantage_fn.opmd_advantage import OPMDAdvantageFn
+from trinity.algorithm.advantage_fn.ppo_advantage import PPOAdvantageFn
+from trinity.algorithm.advantage_fn.reinforce_plus_plus_advantage import (
+    REINFORCEPLUSPLUSAdvantageFn,
+)
+from trinity.algorithm.advantage_fn.remax_advantage import REMAXAdvantageFn
+from trinity.algorithm.advantage_fn.rloo_advantage import RLOOAdvantageFn
+
+__all__ = [
+    "ADVANTAGE_FN",
+    "AdvantageFn",
+    "PPOAdvantageFn",
+    "GRPOAdvantageFn",
+    "REINFORCEPLUSPLUSAdvantageFn",
+    "REMAXAdvantageFn",
+    "RLOOAdvantageFn",
+    "OPMDAdvantageFn",
+]
diff --git a/trinity/algorithm/advantage_fn/advantage_fn.py b/trinity/algorithm/advantage_fn/advantage_fn.py
index 7e965b017c..21e3668a53 100644
--- a/trinity/algorithm/advantage_fn/advantage_fn.py
+++ b/trinity/algorithm/advantage_fn/advantage_fn.py
@@ -16,6 +16,14 @@ def __call__(self, exps: Any, **kwargs: Dict) -> Tuple[Any, Dict]:
             kwargs (`Dict`): The step-level parameters for calculating advantages.
 
         Returns:
-            `Any`: The experiences with advantages.
+            `DataProto`: The experiences with advantages.
             `Dict`: The metrics for logging.
         """
+
+    @classmethod
+    @abstractmethod
+    def default_args(cls) -> Dict:
+        """
+        Returns:
+            `Dict`: The default init arguments for the advantage function.
+        """
diff --git a/trinity/algorithm/advantage_fn/grpo_advantage.py b/trinity/algorithm/advantage_fn/grpo_advantage.py
new file mode 100644
index 0000000000..89a8282752
--- /dev/null
+++ b/trinity/algorithm/advantage_fn/grpo_advantage.py
@@ -0,0 +1,42 @@
+"""GRPO advantage computation
+
+Adapted from compute_advantage_ppo in original ray_trainer.py
+"""
+
+from typing import Dict, Tuple
+
+from verl import DataProto
+
+from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
+from trinity.trainer.verl import core_algos
+
+
+@ADVANTAGE_FN.register_module("grpo_adv_fn")
+class GRPOAdvantageFn(AdvantageFn):
+    """GRPO advantage computation"""
+
+    def __init__(self) -> None:
+        pass
+
+    def __call__(
+        self,
+        exps: DataProto,
+        **kwargs,
+    ) -> Tuple[DataProto, Dict]:
+        advantages, returns = core_algos.compute_grpo_outcome_advantage(
+            token_level_rewards=exps.batch["token_level_rewards"],
+            eos_mask=exps.batch["response_mask"],
+            index=exps.non_tensor_batch["uid"],
+        )
+        exps.batch["advantages"] = advantages
+        exps.batch["returns"] = returns
+
+        metrics = {
+            # TODO: add meaningful metrics
+        }
+
+        return exps, metrics
+
+    @classmethod
+    def default_args(cls) -> Dict:
+        return {}
diff --git a/trinity/algorithm/advantage_fn/opmd_advantage.py b/trinity/algorithm/advantage_fn/opmd_advantage.py
new file mode 100644
index 0000000000..abf74686d3
--- /dev/null
+++ b/trinity/algorithm/advantage_fn/opmd_advantage.py
@@ -0,0 +1,45 @@
+"""OPMD advantage computation
+
+Adapted from compute_advantage_opmd in original ray_trainer.py
+"""
+
+from typing import Dict, Tuple
+
+from verl import DataProto
+
+from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
+from trinity.trainer.verl import core_algos
+
+
+@ADVANTAGE_FN.register_module("opmd_adv_fn")
+class OPMDAdvantageFn(AdvantageFn):
+    """OPMD advantage computation"""
+
+    def __init__(self) -> None:
+        pass
+
+    def __call__(
+        self,
+        exps: DataProto,
+        **kwargs,
+    ) -> Tuple[DataProto, Dict]:
+        advantages, returns = core_algos.compute_opmd_outcome_advantage(
+            token_level_rewards=exps.batch["token_level_rewards"],
+            eos_mask=exps.batch["response_mask"],
+            # TODO (yanxi): check consistency with exps.batch["attention_mask"][:, -response_length:] in original implementation
+            index=exps.non_tensor_batch["uid"],
+            opmd_baseline="mean",
+            tau=1.0,
+        )
+        exps.batch["advantages"] = advantages
+        exps.batch["returns"] = returns
+
+        metrics = {
+            # TODO: add meaningful metrics
+        }
+
+        return exps, metrics
+
+    @classmethod
+    def default_args(cls) -> Dict:
+        return {}
diff --git a/trinity/algorithm/advantage_fn/ppo_advantage.py b/trinity/algorithm/advantage_fn/ppo_advantage.py
new file mode 100644
index 0000000000..5afd51311c
--- /dev/null
+++ b/trinity/algorithm/advantage_fn/ppo_advantage.py
@@ -0,0 +1,50 @@
+"""PPO's GAE advantage computation
+
+Adapted from compute_advantage_ppo in original ray_trainer.py
+"""
+
+from typing import Dict, Tuple
+
+from verl import DataProto
+
+from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
+from trinity.trainer.verl import core_algos
+
+
+@ADVANTAGE_FN.register_module("ppo_adv_fn")
+class PPOAdvantageFn(AdvantageFn):
+    def __init__(
+        self,
+        gamma: float = 1.0,
+        lam: float = 1.0,
+    ) -> None:
+        self.gamma = gamma
+        self.lam = lam
+
+    def __call__(
+        self,
+        exps: DataProto,
+        **kwargs,
+    ) -> Tuple[DataProto, Dict]:
+        advantages, returns = core_algos.compute_gae_advantage_return(
+            token_level_rewards=exps.batch["token_level_rewards"],
+            values=exps.batch["values"],
+            eos_mask=exps.batch["response_mask"],
+            gamma=self.gamma,
+            lam=self.lam,
+        )
+        exps.batch["advantages"] = advantages
+        exps.batch["returns"] = returns
+
+        metrics = {
+            # TODO: add meaningful metrics
+        }
+
+        return exps, metrics
+
+    @classmethod
+    def default_args(cls) -> Dict:
+        return {
+            "gamma": 1.0,
+            "lam": 1.0,
+        }
diff --git a/trinity/algorithm/advantage_fn/reinforce_plus_plus_advantage.py b/trinity/algorithm/advantage_fn/reinforce_plus_plus_advantage.py
new file mode 100644
index 0000000000..9c668f7640
--- /dev/null
+++ b/trinity/algorithm/advantage_fn/reinforce_plus_plus_advantage.py
@@ -0,0 +1,42 @@
+"""REINFORCE++ advantage computation
+
+Adapted from compute_advantage_ppo in original ray_trainer.py
+"""
+
+from typing import Dict, Tuple
+
+from verl import DataProto
+
+from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
+from trinity.trainer.verl import core_algos
+
+
+@ADVANTAGE_FN.register_module("reinforceplusplus_adv_fn")
+class REINFORCEPLUSPLUSAdvantageFn(AdvantageFn):
+    def __init__(self, gamma: float = 1.0) -> None:
+        self.gamma = gamma
+
+    def __call__(
+        self,
+        exps: DataProto,
+        **kwargs,
+    ) -> Tuple[DataProto, Dict]:
+        advantages, returns = core_algos.compute_reinforce_plus_plus_outcome_advantage(
+            token_level_rewards=exps.batch["token_level_rewards"],
+            eos_mask=exps.batch["response_mask"],
+            gamma=self.gamma,
+        )
+        exps.batch["advantages"] = advantages
+        exps.batch["returns"] = returns
+
+        metrics = {
+            # TODO: add meaningful metrics
+        }
+
+        return exps, metrics
+
+    @classmethod
+    def default_args(cls) -> Dict:
+        return {
+            "gamma": 1.0,
+        }
diff --git a/trinity/algorithm/advantage_fn/remax_advantage.py b/trinity/algorithm/advantage_fn/remax_advantage.py
new file mode 100644
index 0000000000..05a13d7d60
--- /dev/null
+++ b/trinity/algorithm/advantage_fn/remax_advantage.py
@@ -0,0 +1,40 @@
+"""REMAX advantage computation
+
+Adapted from compute_advantage_ppo in original ray_trainer.py
+"""
+
+from typing import Dict, Tuple
+
+from verl import DataProto
+
+from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
+from trinity.trainer.verl import core_algos
+
+
+@ADVANTAGE_FN.register_module("remax_adv_fn")
+class REMAXAdvantageFn(AdvantageFn):
+    def __init__(self) -> None:
+        pass
+
+    def __call__(
+        self,
+        exps: DataProto,
+        **kwargs,
+    ) -> Tuple[DataProto, Dict]:
+        advantages, returns = core_algos.compute_remax_outcome_advantage(
+            token_level_rewards=exps.batch["token_level_rewards"],
+            reward_baselines=exps.batch["reward_baselines"],
+            eos_mask=exps.batch["response_mask"],
+        )
+        exps.batch["advantages"] = advantages
+        exps.batch["returns"] = returns
+
+        metrics = {
+            # TODO: add meaningful metrics
+        }
+
+        return exps, metrics
+
+    @classmethod
+    def default_args(cls) -> Dict:
+        return {}
diff --git a/trinity/algorithm/advantage_fn/rloo_advantage.py b/trinity/algorithm/advantage_fn/rloo_advantage.py
new file mode 100644
index 0000000000..3da61c9da4
--- /dev/null
+++ b/trinity/algorithm/advantage_fn/rloo_advantage.py
@@ -0,0 +1,40 @@
+"""RLOO advantage computation
+
+Adapted from compute_advantage_ppo in original ray_trainer.py
+"""
+
+from typing import Dict, Tuple
+
+from verl import DataProto
+
+from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
+from trinity.trainer.verl import core_algos
+
+
+@ADVANTAGE_FN.register_module("rloo_adv_fn")
+class RLOOAdvantageFn(AdvantageFn):
+    def __init__(self) -> None:
+        pass
+
+    def __call__(
+        self,
+        exps: DataProto,
+        **kwargs,
+    ) -> Tuple[DataProto, Dict]:
+        advantages, returns = core_algos.compute_rloo_outcome_advantage(
+            token_level_rewards=exps.batch["token_level_rewards"],
+            eos_mask=exps.batch["response_mask"],
+            index=exps.non_tensor_batch["uid"],
+        )
+        exps.batch["advantages"] = advantages
+        exps.batch["returns"] = returns
+
+        metrics = {
+            # TODO: add meaningful metrics
+        }
+
+        return exps, metrics
+
+    @classmethod
+    def default_args(cls) -> Dict:
+        return {}
diff --git a/trinity/common/config.py b/trinity/common/config.py
index 9c3b582618..794202bab0 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -173,13 +173,15 @@ class AlgorithmConfig:
     algorithm_type: AlgorithmType = AlgorithmType.PPO
     # for GRPO-like algorithms, repeat each task for `repeat_times` times
     repeat_times: int = 1
-    gamma: Optional[float] = None
-    lam: Optional[float] = None
 
     policy_loss_fn: str = "ppo"
     # If not set, use PolicyLossFn.default_args()
     policy_loss_fn_args: Optional[dict] = None
 
+    advantage_fn_type: str = "ppo_adv_fn"
+    # If not set, use AdvantageFn.default_args()
+    advantage_fn_args: Optional[dict] = None
+
 
 @dataclass
 class ClusterConfig:
@@ -470,7 +472,7 @@ def _check_buffer(self) -> None:  # noqa: C901
             self.buffer.tokenizer_path = self.model.model_path
 
     def _check_algorithm(self) -> None:
-        from trinity.algorithm import POLICY_LOSS_FN
+        from trinity.algorithm import ADVANTAGE_FN, POLICY_LOSS_FN
 
         policy_fn_cls = POLICY_LOSS_FN.get(self.algorithm.policy_loss_fn)
         if policy_fn_cls is None:
@@ -478,6 +480,12 @@ def _check_algorithm(self) -> None:
         if self.algorithm.policy_loss_fn_args is None:
             self.algorithm.policy_loss_fn_args = policy_fn_cls.default_args()
 
+        advantage_fn_cls = ADVANTAGE_FN.get(self.algorithm.advantage_fn_type)
+        if advantage_fn_cls is None:
+            raise ValueError(f"Invalid advantage_fn_type: {self.algorithm.advantage_fn_type}")
+        if self.algorithm.advantage_fn_args is None:
+            self.algorithm.advantage_fn_args = advantage_fn_cls.default_args()
+
     def check_and_update(self) -> None:  # noqa: C901
         """Check and update the config."""
         self._check_deprecated()
diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py
index a4d9a6e8d9..fb9f810dee 100644
--- a/trinity/common/verl_config.py
+++ b/trinity/common/verl_config.py
@@ -182,6 +182,9 @@ class KL_Ctrl:
 
 @dataclass
 class Algorithm:
+    # ! DO NOT SET gamma or lam below; they are kept here merely for compatibility with verl,
+    # and their values will be overwritten by those in AlgorithmConfig.advantage_fn_args
+    # if they are really needed (e.g., for GAE advantage/returns computation)
     gamma: float = 1.0
     lam: float = 1.0
     adv_estimator: str = "gae"
@@ -190,7 +193,7 @@ class Algorithm:
     kl_penalty: str = "kl"
     kl_ctrl: KL_Ctrl = field(default_factory=KL_Ctrl)
 
-    # ! DO NOT SET THE FLOWING PARAMETERS
+    # ! DO NOT SET THE FOLLOWING PARAMETERS
    policy_loss_fn: str = "ppo"
     policy_loss_fn_args: Optional[dict] = None
 
@@ -315,17 +318,22 @@ def synchronize_config(self, config: Config) -> None:  # noqa: C901
             self.actor_rollout_ref.actor.clip_ratio = config.trainer.actor_clip_ratio
 
         # Algorithm related config
-        if config.algorithm.gamma is not None:
-            self.algorithm.gamma = config.algorithm.gamma
-        if config.algorithm.lam is not None:
-            self.algorithm.lam = config.algorithm.lam
+        adv_fn_args = config.algorithm.advantage_fn_args
+        if adv_fn_args is not None and "gamma" in adv_fn_args:
+            self.algorithm.gamma = adv_fn_args["gamma"]
+        if adv_fn_args is not None and "lam" in adv_fn_args:
+            self.algorithm.lam = adv_fn_args["lam"]
         self.actor_rollout_ref.actor.algorithm_type = config.algorithm.algorithm_type
         if config.algorithm.algorithm_type == AlgorithmType.PPO:
-            logger.info("Using GAE `adv_estimator` for PPO")
+            logger.info("Setting `adv_estimator` to 'gae' for PPO")
             self.algorithm.adv_estimator = AdvantageEstimator.GAE.value
-        elif config.algorithm.algorithm_type == AlgorithmType.GRPO:
-            logger.info("Using GRPO `adv_estimator` for GRPO")
+        elif config.algorithm.algorithm_type in (AlgorithmType.GRPO, AlgorithmType.OPMD):
+            logger.info("Setting `adv_estimator` to 'grpo' for GRPO/OPMD")
             self.algorithm.adv_estimator = AdvantageEstimator.GRPO.value
+        # TODO (yanxi): it seems that adv_estimator now only affects whether use_critic is set to
+        # True or False in RayPPOTrainer.__init__() (and hence in VerlPPOTrainerWrapper).
+        # Need to double check whether this is indeed the case,
+        # and see if adv_estimator can be removed completely.
         if self.actor_rollout_ref.actor.algorithm_type.is_dpo():
             # for DPO
             if not self.actor_rollout_ref.actor.use_kl_loss:
diff --git a/trinity/trainer/verl/core_algos.py b/trinity/trainer/verl/core_algos.py
index 20cffc9962..f104e0f4f4 100644
--- a/trinity/trainer/verl/core_algos.py
+++ b/trinity/trainer/verl/core_algos.py
@@ -139,8 +139,8 @@ def compute_gae_advantage_return(
     token_level_rewards: torch.Tensor,
     values: torch.Tensor,
     eos_mask: torch.Tensor,
-    gamma: torch.Tensor,
-    lam: torch.Tensor,
+    gamma: float,
+    lam: float,
 ):
     """Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py
 
@@ -283,7 +283,7 @@ def compute_rloo_outcome_advantage(
 
 
 def compute_reinforce_plus_plus_outcome_advantage(
-    token_level_rewards: torch.Tensor, eos_mask: torch.Tensor, gamma: torch.Tensor
+    token_level_rewards: torch.Tensor, eos_mask: torch.Tensor, gamma: float
 ):
     """
     Compute advantage for REINFORCE++.
diff --git a/trinity/trainer/verl/ray_trainer.py b/trinity/trainer/verl/ray_trainer.py
index 7073319db0..5d883d05bb 100644
--- a/trinity/trainer/verl/ray_trainer.py
+++ b/trinity/trainer/verl/ray_trainer.py
@@ -16,18 +16,14 @@
 """
 
 import os
-import uuid
 from contextlib import contextmanager
-from copy import deepcopy
 from dataclasses import dataclass, field
 from enum import Enum
-from pprint import pprint
 from typing import Dict, Type
 
 import numpy as np
 import ray
 import torch
-import tqdm
 from codetiming import Timer
 from omegaconf import OmegaConf, open_dict
 from torch.utils.data import RandomSampler, SequentialSampler
@@ -41,12 +37,6 @@
     RayWorkerGroup,
 )
 from verl.single_controller.ray.base import create_colocated_worker_cls
-from verl.trainer.ppo.metric_utils import (
-    compute_data_metrics,
-    compute_throughout_metrics,
-    compute_timing_metrics,
-    reduce_metrics,
-)
 from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path
 from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn
 from verl.utils.seqlen_balancing import (
@@ -206,116 +196,6 @@ def compute_response_mask(data: DataProto):
     return attention_mask[:, -response_length:]
 
 
-def compute_advantage(data: DataProto, **kwargs):
-    """Extend verl's original compute_advantage with OPMD"""
-
-    algorithm_type: AlgorithmType = kwargs.get("algorithm_type", AlgorithmType.PPO)
-
-    if algorithm_type == AlgorithmType.OPMD:
-        tau = kwargs.get("tau", 1.0)
-        opmd_baseline = kwargs.get("opmd_baseline", "mean")
-
-        return compute_advantage_opmd(
-            data=data,
-            tau=tau,
-            opmd_baseline=opmd_baseline,
-        )
-
-    elif algorithm_type == AlgorithmType.PAIRWISE_OPMD:
-        data.batch["advantages"] = None
-        data.batch["returns"] = None
-        return data
-
-    elif algorithm_type.is_rft():
-        adv_estimator = kwargs.get("adv_estimator", None)
-        gamma = kwargs.get("gamma", 1.0)
-        lam = kwargs.get("lam", 1.0)
-        num_repeat = kwargs.get("num_repeat", 1)
-
-        return compute_advantage_ppo(
-            data=data,
-            adv_estimator=adv_estimator,
-            gamma=gamma,
-            lam=lam,
-            num_repeat=num_repeat,
-        )
-
-    else:
-        raise ValueError(f"Get invalid algorithm_type '{algorithm_type}'.")
-
-
-def compute_advantage_opmd(data: DataProto, tau=1.0, opmd_baseline="mean"):
-    # Modified from GRPO version
-    token_level_rewards = data.batch["token_level_rewards"]
-    index = data.non_tensor_batch["uid"]
-    responses = data.batch["responses"]
-    response_length = responses.size(-1)
-    attention_mask = data.batch["attention_mask"]
-    response_mask = attention_mask[:, -response_length:]
-    advantages, returns = core_algos.compute_opmd_outcome_advantage(
-        token_level_rewards=token_level_rewards,
-        eos_mask=response_mask,
-        index=index,
-        opmd_baseline=opmd_baseline,
-        tau=tau,
-    )
-    data.batch["advantages"] = advantages
-    data.batch["returns"] = returns
-
-    return data
-
-
-def compute_advantage_ppo(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_repeat=1):
-    # prepare response group
-    # TODO: add other ways to estimate advantages
-    if adv_estimator == AdvantageEstimator.GAE:
-        advantages, returns = core_algos.compute_gae_advantage_return(
-            token_level_rewards=data.batch["token_level_rewards"],
-            values=data.batch["values"],
-            eos_mask=data.batch["response_mask"],
-            gamma=gamma,
-            lam=lam,
-        )
-        data.batch["advantages"] = advantages
-        data.batch["returns"] = returns
-    elif adv_estimator == AdvantageEstimator.GRPO:
-        advantages, returns = core_algos.compute_grpo_outcome_advantage(
-            token_level_rewards=data.batch["token_level_rewards"],
-            eos_mask=data.batch["response_mask"],
-            index=data.non_tensor_batch["uid"],
-        )
-        data.batch["advantages"] = advantages
-        data.batch["returns"] = returns
-    elif adv_estimator == AdvantageEstimator.REINFORCE_PLUS_PLUS:
-        advantages, returns = core_algos.compute_reinforce_plus_plus_outcome_advantage(
-            token_level_rewards=data.batch["token_level_rewards"],
-            eos_mask=data.batch["response_mask"],
-            gamma=gamma,
-        )
-        data.batch["advantages"] = advantages
-        data.batch["returns"] = returns
-    elif adv_estimator == AdvantageEstimator.REMAX:
-        advantages, returns = core_algos.compute_remax_outcome_advantage(
-            token_level_rewards=data.batch["token_level_rewards"],
-            reward_baselines=data.batch["reward_baselines"],
-            eos_mask=data.batch["response_mask"],
-        )
-
-        data.batch["advantages"] = advantages
-        data.batch["returns"] = returns
-    elif adv_estimator == AdvantageEstimator.RLOO:
-        advantages, returns = core_algos.compute_rloo_outcome_advantage(
-            token_level_rewards=data.batch["token_level_rewards"],
-            eos_mask=data.batch["response_mask"],
-            index=data.non_tensor_batch["uid"],
-        )
-        data.batch["advantages"] = advantages
-        data.batch["returns"] = returns
-    else:
-        raise NotImplementedError
-    return data
-
-
 @contextmanager
 def _timer(name: str, timing_raw: Dict[str, float]):
     with Timer(name=name, logger=None) as timer:
@@ -934,227 +814,3 @@ def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqle
             seqlen_list=global_seqlen_lst, partitions=global_partition_lst, prefix=logging_prefix
         )
         metrics.update(global_balance_stats)
-
-    def fit(self):  # noqa: C901
-        """
-        The training loop of PPO.
-        The driver process only need to call the compute functions of the worker group through RPC to construct the PPO dataflow.
-        The light-weight advantage computation is done on the driver process.
-        """
-        from omegaconf import OmegaConf
-        from verl.utils.tracking import Tracking
-
-        logger = Tracking(
-            project_name=self.config.trainer.project_name,
-            experiment_name=self.config.trainer.experiment_name,
-            default_backend=self.config.trainer.logger,
-            config=OmegaConf.to_container(self.config, resolve=True),
-        )
-
-        self.global_steps = 0
-
-        # load checkpoint before doing anything
-        self._load_checkpoint()
-
-        # perform validation before training
-        # currently, we only support validation using the reward_function.
-        if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
-            val_metrics = self._validate()
-            pprint(f"Initial validation metrics: {val_metrics}")
-            logger.log(data=val_metrics, step=self.global_steps)
-            if self.config.trainer.get("val_only", False):
-                return
-
-        # add tqdm
-        progress_bar = tqdm(
-            total=self.total_training_steps, initial=self.global_steps, desc="Training Progress"
-        )
-
-        # we start from step 1
-        self.global_steps += 1
-        last_val_metrics = None
-
-        for epoch in range(self.config.trainer.total_epochs):
-            for batch_dict in self.train_dataloader:
-                metrics = {}
-                timing_raw = {}
-
-                batch: DataProto = DataProto.from_single_dict(batch_dict)
-
-                # pop those keys for generation
-                if "multi_modal_inputs" in batch.non_tensor_batch.keys():
-                    gen_batch = batch.pop(
-                        batch_keys=["input_ids", "attention_mask", "position_ids"],
-                        non_tensor_batch_keys=[
-                            "raw_prompt_ids",
-                            "multi_modal_data",
-                            "multi_modal_inputs",
-                        ],
-                    )
-                else:
-                    gen_batch = batch.pop(
-                        batch_keys=["input_ids", "attention_mask", "position_ids"],
-                        non_tensor_batch_keys=["raw_prompt_ids"],
-                    )
-
-                is_last_step = self.global_steps >= self.total_training_steps
-
-                with _timer("step", timing_raw):
-                    # generate a batch
-                    with _timer("gen", timing_raw):
-                        gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
-
-                    if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
-                        with _timer("gen_max", timing_raw):
-                            gen_baseline_batch = deepcopy(gen_batch)
-                            gen_baseline_batch.meta_info["do_sample"] = False
-                            gen_baseline_output = self.actor_rollout_wg.generate_sequences(
-                                gen_baseline_batch
-                            )
-
-                            batch = batch.union(gen_baseline_output)
-                            reward_baseline_tensor = self.reward_fn(batch)
-                            reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)
-
-                            batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))
-
-                            batch.batch["reward_baselines"] = reward_baseline_tensor
-
-                            del gen_baseline_batch, gen_baseline_output
-
-                    batch.non_tensor_batch["uid"] = np.array(
-                        [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object
-                    )
-                    # repeat to align with repeated responses in rollout
-                    batch = batch.repeat(
-                        repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True
-                    )
-                    batch = batch.union(gen_batch_output)
-
-                    batch.batch["response_mask"] = compute_response_mask(batch)
-
-                    # balance the number of valid tokens on each dp rank.
-                    # Note that this breaks the order of data inside the batch.
-                    # Please take care when you implement group based adv computation such as GRPO and rloo
-                    if self.config.trainer.balance_batch:
-                        self._balance_batch(batch, metrics=metrics)
-
-                    # compute global_valid tokens
-                    batch.meta_info["global_token_num"] = torch.sum(
-                        batch.batch["attention_mask"], dim=-1
-                    ).tolist()
-
-                    # recompute old_log_probs
-                    with _timer("old_log_prob", timing_raw):
-                        old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
-                        batch = batch.union(old_log_prob)
-
-                    if self.use_reference_policy:
-                        # compute reference log_prob
-                        with _timer("ref", timing_raw):
-                            ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
-                            batch = batch.union(ref_log_prob)
-
-                    # compute values
-                    if self.use_critic:
-                        with _timer("values", timing_raw):
-                            values = self.critic_wg.compute_values(batch)
-                            batch = batch.union(values)
-
-                    with _timer("adv", timing_raw):
-                        # compute scores. Support both model and function-based.
-                        # We first compute the scores using reward model. Then, we call reward_fn to combine
-                        # the results from reward model and rule-based results.
-                        if self.use_rm:
-                            # we first compute reward model score
-                            reward_tensor = self.rm_wg.compute_rm_score(batch)
-                            batch = batch.union(reward_tensor)
-
-                        # we combine with rule-based rm
-                        reward_tensor = self.reward_fn(batch)
-                        batch.batch["token_level_scores"] = reward_tensor
-
-                        # compute rewards. apply_kl_penalty if available
-                        if not self.config.actor_rollout_ref.actor.get("use_kl_loss", False):
-                            batch, kl_metrics = apply_kl_penalty(
-                                batch,
-                                kl_ctrl=self.kl_ctrl,
-                                kl_penalty=self.config.algorithm.kl_penalty,
-                            )
-                            metrics.update(kl_metrics)
-                        else:
-                            batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]
-
-                        # compute advantages, executed on the driver process
-                        algorithm_type = self.config.actor_rollout_ref.actor.get(
-                            "algorithm_type", AlgorithmType.PPO
-                        )
-                        tau = self.config.actor_rollout_ref.actor.get("tau", 1.0)
-                        opmd_baseline = self.config.actor_rollout_ref.actor.get(
-                            "opmd_baseline", "mean"
-                        )
-                        batch = compute_advantage(
-                            batch,
-                            algorithm_type=algorithm_type,
-                            adv_estimator=self.config.algorithm.adv_estimator,
-                            gamma=self.config.algorithm.gamma,
-                            lam=self.config.algorithm.lam,
-                            num_repeat=self.config.actor_rollout_ref.rollout.n,
-                            # additional config params for OPMD
-                            tau=tau,
-                            opmd_baseline=opmd_baseline,
-                        )
-
-                    # update critic
-                    if self.use_critic:
-                        with _timer("update_critic", timing_raw):
-                            critic_output = self.critic_wg.update_critic(batch)
-                            critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
-                            metrics.update(critic_output_metrics)
-
-                    # implement critic warmup
-                    if self.config.trainer.critic_warmup <= self.global_steps:
-                        # update actor
-                        with _timer("update_actor", timing_raw):
-                            actor_output = self.actor_rollout_wg.update_actor(batch)
-                            actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
-                            metrics.update(actor_output_metrics)
-
-                    # validate
-                    if (
-                        self.val_reward_fn is not None
-                        and self.config.trainer.test_freq > 0
-                        and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)
-                    ):
-                        with _timer("testing", timing_raw):
-                            val_metrics: dict = self._validate()
-                        if is_last_step:
-                            last_val_metrics = val_metrics
-                        metrics.update(val_metrics)
-
-                    if self.config.trainer.save_freq > 0 and (
-                        is_last_step or self.global_steps % self.config.trainer.save_freq == 0
-                    ):
-                        with _timer("save_checkpoint", timing_raw):
-                            self._save_checkpoint()
-
-                # collect metrics
-                metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
-                metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
-
-                # Implement actual tflpo and theoretical tflpo
-                n_gpus = self.resource_pool_manager.get_n_gpus()
-                metrics.update(
-                    compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus)
-                )
-
-                # TODO: make a canonical logger that supports various backend
-                logger.log(data=metrics, step=self.global_steps)
-
-                if is_last_step:
-                    pprint(f"Final validation metrics: {last_val_metrics}")
-                    progress_bar.close()
-                    return
-
-                progress_bar.update(1)
-                self.global_steps += 1
diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py
index 5324a13f7c..b6397adde7 100644
--- a/trinity/trainer/verl_trainer.py
+++ b/trinity/trainer/verl_trainer.py
@@ -4,15 +4,24 @@
 Modified from verl/trainer/ppo/ray_trainer.py
 """
 import os
+from pprint import pprint
 from typing import Tuple
 
+import numpy as np
 import pandas as pd
 import ray
 import torch
 from omegaconf import OmegaConf
+from verl.trainer.ppo.metric_utils import (
+    compute_data_metrics,
+    compute_throughout_metrics,
+    compute_timing_metrics,
+    reduce_metrics,
+)
 from verl.utils import hf_tokenizer
 from verl.utils.fs import copy_local_path_from_hdfs
 
+from trinity.algorithm import ADVANTAGE_FN
 from trinity.common.config import AlgorithmConfig, Config
 from trinity.common.constants import AlgorithmType
 from trinity.common.experience import Experiences
@@ -25,14 +34,7 @@
     Role,
     _timer,
     apply_kl_penalty,
-    compute_advantage,
-    compute_data_metrics,
-    compute_throughout_metrics,
-    compute_timing_metrics,
     find_latest_ckpt_path,
-    np,
-    pprint,
-    reduce_metrics,
 )
 from trinity.utils.monitor import Monitor
 
@@ -126,6 +128,14 @@ def __init__(
         )
         self.init_workers()
         self.algorithm_type = AlgorithmType.PPO
+
+        # specify advantage function for various rft algorithms
+        algo_config = global_config.algorithm
+        if algo_config.algorithm_type.is_rft():
+            adv_fn_type = algo_config.advantage_fn_type
+            adv_fn_args = algo_config.advantage_fn_args
+            self.advantage_fn = ADVANTAGE_FN.get(adv_fn_type)(**adv_fn_args)
+
         self.logger = Monitor(
             project=config.trainer.project_name,
             name=config.trainer.experiment_name,
@@ -377,26 +387,7 @@ def train_rft_step(self, experiences: Experiences) -> Tuple[bool, int]:
             batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]
 
         # compute advantages, executed on the driver process
-        kwargs = {}
-        algorithm_type = self.config.actor_rollout_ref.actor.get(
-            "algorithm_type", AlgorithmType.PPO
-        )
-        if algorithm_type == AlgorithmType.OPMD:
-            tau = self.config.actor_rollout_ref.actor.get("tau", 0.0)
-            opmd_baseline = self.config.actor_rollout_ref.actor.get("opmd_baseline", "mean")
-            kwargs = {
-                "algorithm_type": algorithm_type,
-                "tau": tau,
-                "opmd_baseline": opmd_baseline,
-            }
-        batch = compute_advantage(
-            batch,
-            adv_estimator=self.config.algorithm.adv_estimator,
-            gamma=self.config.algorithm.gamma,
-            lam=self.config.algorithm.lam,
-            num_repeat=self.config.actor_rollout_ref.rollout.n,
-            **kwargs,
-        )
+        batch, _ = self.advantage_fn(batch)
 
         # update critic
         if self.use_critic:
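Note (not part of the patch): a minimal sketch of how the pieces introduced above are meant to fit together. Config._check_algorithm() resolves algorithm.advantage_fn_type through the ADVANTAGE_FN registry and falls back to the class's default_args() when advantage_fn_args is unset; VerlPPOTrainerWrapper then instantiates that class and calls it in train_rft_step(). The "grpo_adv_fn" value and the standalone variable names below are illustrative only.

    # Illustrative sketch of the registry-based advantage flow added in this diff.
    from trinity.algorithm import ADVANTAGE_FN

    advantage_fn_type = "grpo_adv_fn"  # e.g. algorithm.advantage_fn_type from the YAML config
    advantage_fn_args = None           # e.g. algorithm.advantage_fn_args left unset

    # mirrors Config._check_algorithm(): resolve the class and fill in per-class defaults
    advantage_fn_cls = ADVANTAGE_FN.get(advantage_fn_type)
    if advantage_fn_cls is None:
        raise ValueError(f"Invalid advantage_fn_type: {advantage_fn_type}")
    if advantage_fn_args is None:
        advantage_fn_args = advantage_fn_cls.default_args()

    # mirrors VerlPPOTrainerWrapper.__init__(): instantiate once per run
    advantage_fn = advantage_fn_cls(**advantage_fn_args)

    # `exps` would be a verl DataProto carrying token_level_rewards, response_mask and uid;
    # the call below is what train_rft_step() now performs:
    # exps, metrics = advantage_fn(exps)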