From 0eab23bc764770b5c32dbd036f050ce071ccd513 Mon Sep 17 00:00:00 2001 From: yanxi-chen Date: Wed, 28 May 2025 18:52:16 +0800 Subject: [PATCH 1/6] Refactor advantage calculation. TODO: adjust config accordingly. --- .../algorithm/advantage_fn/advantage_fn.py | 181 ++++++++++ trinity/common/config.py | 4 + trinity/trainer/verl/ray_trainer.py | 334 ------------------ trinity/trainer/verl_trainer.py | 31 +- 4 files changed, 195 insertions(+), 355 deletions(-) diff --git a/trinity/algorithm/advantage_fn/advantage_fn.py b/trinity/algorithm/advantage_fn/advantage_fn.py index 7e965b017c..0fb5ef199c 100644 --- a/trinity/algorithm/advantage_fn/advantage_fn.py +++ b/trinity/algorithm/advantage_fn/advantage_fn.py @@ -1,6 +1,8 @@ from abc import ABC, abstractmethod from typing import Any, Dict, Tuple +from verl import DataProto +from trinity.trainer.verl import core_algos from trinity.utils.registry import Registry ADVANTAGE_FN = Registry("advantage_fn") @@ -19,3 +21,182 @@ def __call__(self, exps: Any, **kwargs: Dict) -> Tuple[Any, Dict]: `Any`: The experiences with advantages. `Dict`: The metrics for logging. """ + + +@ADVANTAGE_FN.register("ppo_adv_fn") +class PPOAdvantageFn(AdvantageFn): + """PPO's GAE advantage computation""" + + def __init__( + self, + gamma, + lam, + ): + self.gamma = gamma + self.lam = lam + + + def __call__(self, exps: DataProto) -> Tuple[DataProto, Dict]: + """Adapted from compute_advantage_ppo in ray_trainer.py""" + + advantages, returns = core_algos.compute_gae_advantage_return( + token_level_rewards=exps.batch["token_level_rewards"], + values=exps.batch["values"], + eos_mask=exps.batch["response_mask"], + gamma=self.gamma, + lam=self.lam, + ) + exps.batch["advantages"] = advantages + exps.batch["returns"] = returns + + metrics = { + "abc": "xyz", # TODO: add meaningful metrics + } + + return exps, metrics + + +@ADVANTAGE_FN.register("grpo_adv_fn") +class GRPOAdvantageFn(AdvantageFn): + """GRPO advantage computation""" + + def __init__( + self, + ): + pass + + + def __call__(self, exps: DataProto) -> Tuple[DataProto, Dict]: + """Adapted from compute_advantage_ppo in ray_trainer.py""" + + advantages, returns = core_algos.compute_grpo_outcome_advantage( + token_level_rewards=exps.batch["token_level_rewards"], + eos_mask=exps.batch["response_mask"], + index=exps.non_tensor_batch["uid"], + ) + exps.batch["advantages"] = advantages + exps.batch["returns"] = returns + + metrics = { + "abc": "xyz", # TODO: add meaningful metrics + } + + return exps, metrics + + +@ADVANTAGE_FN.register("reinforceplusplus_adv_fn") +class REINFORCEPLUSPLUSAdvantageFn(AdvantageFn): + """REINFORCE++ advantage computation""" + + def __init__( + self, + gamma, + ): + self.gamma = gamma + + + def __call__(self, exps: DataProto) -> Tuple[DataProto, Dict]: + """Adapted from compute_advantage_ppo in ray_trainer.py""" + + advantages, returns = core_algos.compute_reinforce_plus_plus_outcome_advantage( + token_level_rewards=exps.batch["token_level_rewards"], + eos_mask=exps.batch["response_mask"], + gamma=self.gamma, + ) + exps.batch["advantages"] = advantages + exps.batch["returns"] = returns + + metrics = { + "abc": "xyz", # TODO: add meaningful metrics + } + + return exps, metrics + + +@ADVANTAGE_FN.register("remax_adv_fn") +class REMAXAdvantageFn(AdvantageFn): + """REMAX advantage computation""" + + def __init__( + self, + ): + pass + + + def __call__(self, exps: DataProto) -> Tuple[DataProto, Dict]: + """Adapted from compute_advantage_ppo in ray_trainer.py""" + + advantages, returns = 
core_algos.compute_remax_outcome_advantage( + token_level_rewards=exps.batch["token_level_rewards"], + reward_baselines=exps.batch["reward_baselines"], + eos_mask=exps.batch["response_mask"], + ) + exps.batch["advantages"] = advantages + exps.batch["returns"] = returns + + metrics = { + "abc": "xyz", # TODO: add meaningful metrics + } + + return exps, metrics + + + +@ADVANTAGE_FN.register("rloo_adv_fn") +class RLOOAdvantageFn(AdvantageFn): + """RLOO advantage computation""" + + def __init__( + self, + ): + pass + + + def __call__(self, exps: DataProto) -> Tuple[DataProto, Dict]: + """Adapted from compute_advantage_ppo in ray_trainer.py""" + + advantages, returns = core_algos.compute_rloo_outcome_advantage( + token_level_rewards=exps.batch["token_level_rewards"], + eos_mask=exps.batch["response_mask"], + index=exps.non_tensor_batch["uid"], + ) + exps.batch["advantages"] = advantages + exps.batch["returns"] = returns + + metrics = { + "abc": "xyz", # TODO: add meaningful metrics + } + + return exps, metrics + + +@ADVANTAGE_FN.register("opmd_adv_fn") +class OPMDAdvantageFn(AdvantageFn): + """OPMD advantage computation""" + + def __init__( + self, + ): + pass + + + def __call__(self, exps: DataProto) -> Tuple[DataProto, Dict]: + """Adapted from compute_advantage_opmd in ray_trainer.py""" + + advantages, returns = core_algos.compute_opmd_outcome_advantage( + token_level_rewards=exps.batch["token_level_rewards"], + eos_mask=exps.batch["response_mask"], + # TODO: check consistency with exps.batch["attention_mask"][:, -response_length:] in original implementation + index=exps.non_tensor_batch["uid"], + opmd_baseline="mean", + tau=1.0, + ) + exps.batch["advantages"] = advantages + exps.batch["returns"] = returns + + metrics = { + "abc": "xyz", # TODO: add meaningful metrics + } + + return exps, metrics + diff --git a/trinity/common/config.py b/trinity/common/config.py index e0660ab03a..f37f7783cf 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -173,10 +173,14 @@ class AlgorithmConfig: algorithm_type: AlgorithmType = AlgorithmType.PPO # for GRPO-like algorithms, repeat each task for `repeat_times` times repeat_times: int = 1 + # configs for advantage calculation + advantage_fn_type: Optional[str] = "ppo_adv_fn" gamma: Optional[float] = None lam: Optional[float] = None # TODO: add more algorithm params here + # TODO (yanxi): add advantage_fn_type, advantage_fn_args; or not? 
(keep it simple) + @dataclass class ClusterConfig: diff --git a/trinity/trainer/verl/ray_trainer.py b/trinity/trainer/verl/ray_trainer.py index 7073319db0..17fd174a01 100644 --- a/trinity/trainer/verl/ray_trainer.py +++ b/trinity/trainer/verl/ray_trainer.py @@ -206,116 +206,6 @@ def compute_response_mask(data: DataProto): return attention_mask[:, -response_length:] -def compute_advantage(data: DataProto, **kwargs): - """Extend verl's original compute_advantage with OPMD""" - - algorithm_type: AlgorithmType = kwargs.get("algorithm_type", AlgorithmType.PPO) - - if algorithm_type == AlgorithmType.OPMD: - tau = kwargs.get("tau", 1.0) - opmd_baseline = kwargs.get("opmd_baseline", "mean") - - return compute_advantage_opmd( - data=data, - tau=tau, - opmd_baseline=opmd_baseline, - ) - - elif algorithm_type == AlgorithmType.PAIRWISE_OPMD: - data.batch["advantages"] = None - data.batch["returns"] = None - return data - - elif algorithm_type.is_rft(): - adv_estimator = kwargs.get("adv_estimator", None) - gamma = kwargs.get("gamma", 1.0) - lam = kwargs.get("lam", 1.0) - num_repeat = kwargs.get("num_repeat", 1) - - return compute_advantage_ppo( - data=data, - adv_estimator=adv_estimator, - gamma=gamma, - lam=lam, - num_repeat=num_repeat, - ) - - else: - raise ValueError(f"Get invalid algorithm_type '{algorithm_type}'.") - - -def compute_advantage_opmd(data: DataProto, tau=1.0, opmd_baseline="mean"): - # Modified from GRPO version - token_level_rewards = data.batch["token_level_rewards"] - index = data.non_tensor_batch["uid"] - responses = data.batch["responses"] - response_length = responses.size(-1) - attention_mask = data.batch["attention_mask"] - response_mask = attention_mask[:, -response_length:] - advantages, returns = core_algos.compute_opmd_outcome_advantage( - token_level_rewards=token_level_rewards, - eos_mask=response_mask, - index=index, - opmd_baseline=opmd_baseline, - tau=tau, - ) - data.batch["advantages"] = advantages - data.batch["returns"] = returns - - return data - - -def compute_advantage_ppo(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_repeat=1): - # prepare response group - # TODO: add other ways to estimate advantages - if adv_estimator == AdvantageEstimator.GAE: - advantages, returns = core_algos.compute_gae_advantage_return( - token_level_rewards=data.batch["token_level_rewards"], - values=data.batch["values"], - eos_mask=data.batch["response_mask"], - gamma=gamma, - lam=lam, - ) - data.batch["advantages"] = advantages - data.batch["returns"] = returns - elif adv_estimator == AdvantageEstimator.GRPO: - advantages, returns = core_algos.compute_grpo_outcome_advantage( - token_level_rewards=data.batch["token_level_rewards"], - eos_mask=data.batch["response_mask"], - index=data.non_tensor_batch["uid"], - ) - data.batch["advantages"] = advantages - data.batch["returns"] = returns - elif adv_estimator == AdvantageEstimator.REINFORCE_PLUS_PLUS: - advantages, returns = core_algos.compute_reinforce_plus_plus_outcome_advantage( - token_level_rewards=data.batch["token_level_rewards"], - eos_mask=data.batch["response_mask"], - gamma=gamma, - ) - data.batch["advantages"] = advantages - data.batch["returns"] = returns - elif adv_estimator == AdvantageEstimator.REMAX: - advantages, returns = core_algos.compute_remax_outcome_advantage( - token_level_rewards=data.batch["token_level_rewards"], - reward_baselines=data.batch["reward_baselines"], - eos_mask=data.batch["response_mask"], - ) - - data.batch["advantages"] = advantages - data.batch["returns"] = returns - elif adv_estimator 
== AdvantageEstimator.RLOO: - advantages, returns = core_algos.compute_rloo_outcome_advantage( - token_level_rewards=data.batch["token_level_rewards"], - eos_mask=data.batch["response_mask"], - index=data.non_tensor_batch["uid"], - ) - data.batch["advantages"] = advantages - data.batch["returns"] = returns - else: - raise NotImplementedError - return data - - @contextmanager def _timer(name: str, timing_raw: Dict[str, float]): with Timer(name=name, logger=None) as timer: @@ -934,227 +824,3 @@ def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqle seqlen_list=global_seqlen_lst, partitions=global_partition_lst, prefix=logging_prefix ) metrics.update(global_balance_stats) - - def fit(self): # noqa: C901 - """ - The training loop of PPO. - The driver process only need to call the compute functions of the worker group through RPC to construct the PPO dataflow. - The light-weight advantage computation is done on the driver process. - """ - from omegaconf import OmegaConf - from verl.utils.tracking import Tracking - - logger = Tracking( - project_name=self.config.trainer.project_name, - experiment_name=self.config.trainer.experiment_name, - default_backend=self.config.trainer.logger, - config=OmegaConf.to_container(self.config, resolve=True), - ) - - self.global_steps = 0 - - # load checkpoint before doing anything - self._load_checkpoint() - - # perform validation before training - # currently, we only support validation using the reward_function. - if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): - val_metrics = self._validate() - pprint(f"Initial validation metrics: {val_metrics}") - logger.log(data=val_metrics, step=self.global_steps) - if self.config.trainer.get("val_only", False): - return - - # add tqdm - progress_bar = tqdm( - total=self.total_training_steps, initial=self.global_steps, desc="Training Progress" - ) - - # we start from step 1 - self.global_steps += 1 - last_val_metrics = None - - for epoch in range(self.config.trainer.total_epochs): - for batch_dict in self.train_dataloader: - metrics = {} - timing_raw = {} - - batch: DataProto = DataProto.from_single_dict(batch_dict) - - # pop those keys for generation - if "multi_modal_inputs" in batch.non_tensor_batch.keys(): - gen_batch = batch.pop( - batch_keys=["input_ids", "attention_mask", "position_ids"], - non_tensor_batch_keys=[ - "raw_prompt_ids", - "multi_modal_data", - "multi_modal_inputs", - ], - ) - else: - gen_batch = batch.pop( - batch_keys=["input_ids", "attention_mask", "position_ids"], - non_tensor_batch_keys=["raw_prompt_ids"], - ) - - is_last_step = self.global_steps >= self.total_training_steps - - with _timer("step", timing_raw): - # generate a batch - with _timer("gen", timing_raw): - gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) - - if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: - with _timer("gen_max", timing_raw): - gen_baseline_batch = deepcopy(gen_batch) - gen_baseline_batch.meta_info["do_sample"] = False - gen_baseline_output = self.actor_rollout_wg.generate_sequences( - gen_baseline_batch - ) - - batch = batch.union(gen_baseline_output) - reward_baseline_tensor = self.reward_fn(batch) - reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1) - - batch.pop(batch_keys=list(gen_baseline_output.batch.keys())) - - batch.batch["reward_baselines"] = reward_baseline_tensor - - del gen_baseline_batch, gen_baseline_output - - batch.non_tensor_batch["uid"] = np.array( - [str(uuid.uuid4()) for _ in 
range(len(batch.batch))], dtype=object - ) - # repeat to align with repeated responses in rollout - batch = batch.repeat( - repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True - ) - batch = batch.union(gen_batch_output) - - batch.batch["response_mask"] = compute_response_mask(batch) - - # balance the number of valid tokens on each dp rank. - # Note that this breaks the order of data inside the batch. - # Please take care when you implement group based adv computation such as GRPO and rloo - if self.config.trainer.balance_batch: - self._balance_batch(batch, metrics=metrics) - - # compute global_valid tokens - batch.meta_info["global_token_num"] = torch.sum( - batch.batch["attention_mask"], dim=-1 - ).tolist() - - # recompute old_log_probs - with _timer("old_log_prob", timing_raw): - old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) - batch = batch.union(old_log_prob) - - if self.use_reference_policy: - # compute reference log_prob - with _timer("ref", timing_raw): - ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) - batch = batch.union(ref_log_prob) - - # compute values - if self.use_critic: - with _timer("values", timing_raw): - values = self.critic_wg.compute_values(batch) - batch = batch.union(values) - - with _timer("adv", timing_raw): - # compute scores. Support both model and function-based. - # We first compute the scores using reward model. Then, we call reward_fn to combine - # the results from reward model and rule-based results. - if self.use_rm: - # we first compute reward model score - reward_tensor = self.rm_wg.compute_rm_score(batch) - batch = batch.union(reward_tensor) - - # we combine with rule-based rm - reward_tensor = self.reward_fn(batch) - batch.batch["token_level_scores"] = reward_tensor - - # compute rewards. 
apply_kl_penalty if available - if not self.config.actor_rollout_ref.actor.get("use_kl_loss", False): - batch, kl_metrics = apply_kl_penalty( - batch, - kl_ctrl=self.kl_ctrl, - kl_penalty=self.config.algorithm.kl_penalty, - ) - metrics.update(kl_metrics) - else: - batch.batch["token_level_rewards"] = batch.batch["token_level_scores"] - - # compute advantages, executed on the driver process - algorithm_type = self.config.actor_rollout_ref.actor.get( - "algorithm_type", AlgorithmType.PPO - ) - tau = self.config.actor_rollout_ref.actor.get("tau", 1.0) - opmd_baseline = self.config.actor_rollout_ref.actor.get( - "opmd_baseline", "mean" - ) - batch = compute_advantage( - batch, - algorithm_type=algorithm_type, - adv_estimator=self.config.algorithm.adv_estimator, - gamma=self.config.algorithm.gamma, - lam=self.config.algorithm.lam, - num_repeat=self.config.actor_rollout_ref.rollout.n, - # additional config params for OPMD - tau=tau, - opmd_baseline=opmd_baseline, - ) - - # update critic - if self.use_critic: - with _timer("update_critic", timing_raw): - critic_output = self.critic_wg.update_critic(batch) - critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"]) - metrics.update(critic_output_metrics) - - # implement critic warmup - if self.config.trainer.critic_warmup <= self.global_steps: - # update actor - with _timer("update_actor", timing_raw): - actor_output = self.actor_rollout_wg.update_actor(batch) - actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) - metrics.update(actor_output_metrics) - - # validate - if ( - self.val_reward_fn is not None - and self.config.trainer.test_freq > 0 - and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0) - ): - with _timer("testing", timing_raw): - val_metrics: dict = self._validate() - if is_last_step: - last_val_metrics = val_metrics - metrics.update(val_metrics) - - if self.config.trainer.save_freq > 0 and ( - is_last_step or self.global_steps % self.config.trainer.save_freq == 0 - ): - with _timer("save_checkpoint", timing_raw): - self._save_checkpoint() - - # collect metrics - metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) - metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) - - # Implement actual tflpo and theoretical tflpo - n_gpus = self.resource_pool_manager.get_n_gpus() - metrics.update( - compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus) - ) - - # TODO: make a canonical logger that supports various backend - logger.log(data=metrics, step=self.global_steps) - - if is_last_step: - pprint(f"Final validation metrics: {last_val_metrics}") - progress_bar.close() - return - - progress_bar.update(1) - self.global_steps += 1 diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py index 7590d6075b..d092e79bbe 100644 --- a/trinity/trainer/verl_trainer.py +++ b/trinity/trainer/verl_trainer.py @@ -25,7 +25,6 @@ Role, _timer, apply_kl_penalty, - compute_advantage, compute_data_metrics, compute_throughout_metrics, compute_timing_metrics, @@ -35,6 +34,7 @@ reduce_metrics, ) from trinity.utils.monitor import Monitor +from trinity.algorithm import ADVANTAGE_FN class _InternalDataLoader: @@ -128,6 +128,14 @@ def __init__( self.algorithm_type = ( AlgorithmType.PPO ) # TODO: initialize algorithm_type according to config + + # specify advantage function for various rft algorithms + algo_config = global_config.algorithm + if algo_config.algorithm_type.is_rft(): + adv_fn_type = 
algo_config.advantage_fn_type + adv_fn_args = algo_config.get("advantage_fn_args", {}) # TODO: does this work properly?? + self.advantage_fn = ADVANTAGE_FN.get(adv_fn_type)(**adv_fn_args) + self.logger = Monitor( project=config.trainer.project_name, name=config.trainer.experiment_name, @@ -379,26 +387,7 @@ def train_rft_step(self, experiences: Experiences) -> Tuple[bool, int]: batch.batch["token_level_rewards"] = batch.batch["token_level_scores"] # compute advantages, executed on the driver process - kwargs = {} - algorithm_type = self.config.actor_rollout_ref.actor.get( - "algorithm_type", AlgorithmType.PPO - ) - if algorithm_type == AlgorithmType.OPMD: - tau = self.config.actor_rollout_ref.actor.get("tau", 0.0) - opmd_baseline = self.config.actor_rollout_ref.actor.get("opmd_baseline", "mean") - kwargs = { - "algorithm_type": algorithm_type, - "tau": tau, - "opmd_baseline": opmd_baseline, - } - batch = compute_advantage( - batch, - adv_estimator=self.config.algorithm.adv_estimator, - gamma=self.config.algorithm.gamma, - lam=self.config.algorithm.lam, - num_repeat=self.config.actor_rollout_ref.rollout.n, - **kwargs, - ) + batch, _ = self.advantage_fn(batch) # update critic if self.use_critic: From 453d0d449fac33e5627458952a529808739a484b Mon Sep 17 00:00:00 2001 From: yanxi-chen Date: Wed, 28 May 2025 20:15:40 +0800 Subject: [PATCH 2/6] Adjust API and config. TODO: update yaml config files and test. --- .../algorithm/advantage_fn/advantage_fn.py | 56 ++++++------------- trinity/common/config.py | 4 +- trinity/trainer/verl/ray_trainer.py | 10 ---- trinity/trainer/verl_trainer.py | 6 +- 4 files changed, 22 insertions(+), 54 deletions(-) diff --git a/trinity/algorithm/advantage_fn/advantage_fn.py b/trinity/algorithm/advantage_fn/advantage_fn.py index 0fb5ef199c..581d2cfce5 100644 --- a/trinity/algorithm/advantage_fn/advantage_fn.py +++ b/trinity/algorithm/advantage_fn/advantage_fn.py @@ -2,6 +2,7 @@ from typing import Any, Dict, Tuple from verl import DataProto + from trinity.trainer.verl import core_algos from trinity.utils.registry import Registry @@ -27,16 +28,11 @@ def __call__(self, exps: Any, **kwargs: Dict) -> Tuple[Any, Dict]: class PPOAdvantageFn(AdvantageFn): """PPO's GAE advantage computation""" - def __init__( - self, - gamma, - lam, - ): - self.gamma = gamma - self.lam = lam - + def __init__(self, **kwargs): + self.gamma = kwargs.get("gamma") + self.lam = kwargs.get("lam") - def __call__(self, exps: DataProto) -> Tuple[DataProto, Dict]: + def __call__(self, exps: DataProto, **kwargs) -> Tuple[DataProto, Dict]: """Adapted from compute_advantage_ppo in ray_trainer.py""" advantages, returns = core_algos.compute_gae_advantage_return( @@ -60,13 +56,10 @@ def __call__(self, exps: DataProto) -> Tuple[DataProto, Dict]: class GRPOAdvantageFn(AdvantageFn): """GRPO advantage computation""" - def __init__( - self, - ): + def __init__(self, **kwargs): pass - - def __call__(self, exps: DataProto) -> Tuple[DataProto, Dict]: + def __call__(self, exps: DataProto, **kwargs) -> Tuple[DataProto, Dict]: """Adapted from compute_advantage_ppo in ray_trainer.py""" advantages, returns = core_algos.compute_grpo_outcome_advantage( @@ -88,14 +81,10 @@ def __call__(self, exps: DataProto) -> Tuple[DataProto, Dict]: class REINFORCEPLUSPLUSAdvantageFn(AdvantageFn): """REINFORCE++ advantage computation""" - def __init__( - self, - gamma, - ): - self.gamma = gamma + def __init__(self, **kwargs): + self.gamma = kwargs.get("gamma") - - def __call__(self, exps: DataProto) -> Tuple[DataProto, Dict]: + def 
__call__(self, exps: DataProto, **kwargs) -> Tuple[DataProto, Dict]: """Adapted from compute_advantage_ppo in ray_trainer.py""" advantages, returns = core_algos.compute_reinforce_plus_plus_outcome_advantage( @@ -117,13 +106,10 @@ def __call__(self, exps: DataProto) -> Tuple[DataProto, Dict]: class REMAXAdvantageFn(AdvantageFn): """REMAX advantage computation""" - def __init__( - self, - ): + def __init__(self, **kwargs): pass - - def __call__(self, exps: DataProto) -> Tuple[DataProto, Dict]: + def __call__(self, exps: DataProto, **kwargs) -> Tuple[DataProto, Dict]: """Adapted from compute_advantage_ppo in ray_trainer.py""" advantages, returns = core_algos.compute_remax_outcome_advantage( @@ -141,18 +127,14 @@ def __call__(self, exps: DataProto) -> Tuple[DataProto, Dict]: return exps, metrics - @ADVANTAGE_FN.register("rloo_adv_fn") class RLOOAdvantageFn(AdvantageFn): """RLOO advantage computation""" - def __init__( - self, - ): + def __init__(self, **kwargs): pass - - def __call__(self, exps: DataProto) -> Tuple[DataProto, Dict]: + def __call__(self, exps: DataProto, **kwargs) -> Tuple[DataProto, Dict]: """Adapted from compute_advantage_ppo in ray_trainer.py""" advantages, returns = core_algos.compute_rloo_outcome_advantage( @@ -174,18 +156,15 @@ def __call__(self, exps: DataProto) -> Tuple[DataProto, Dict]: class OPMDAdvantageFn(AdvantageFn): """OPMD advantage computation""" - def __init__( - self, - ): + def __init__(self, **kwargs): pass - - def __call__(self, exps: DataProto) -> Tuple[DataProto, Dict]: + def __call__(self, exps: DataProto, **kwargs) -> Tuple[DataProto, Dict]: """Adapted from compute_advantage_opmd in ray_trainer.py""" advantages, returns = core_algos.compute_opmd_outcome_advantage( token_level_rewards=exps.batch["token_level_rewards"], - eos_mask=exps.batch["response_mask"], + eos_mask=exps.batch["response_mask"], # TODO: check consistency with exps.batch["attention_mask"][:, -response_length:] in original implementation index=exps.non_tensor_batch["uid"], opmd_baseline="mean", @@ -199,4 +178,3 @@ def __call__(self, exps: DataProto) -> Tuple[DataProto, Dict]: } return exps, metrics - diff --git a/trinity/common/config.py b/trinity/common/config.py index f37f7783cf..74d332712f 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -174,13 +174,11 @@ class AlgorithmConfig: # for GRPO-like algorithms, repeat each task for `repeat_times` times repeat_times: int = 1 # configs for advantage calculation - advantage_fn_type: Optional[str] = "ppo_adv_fn" + advantage_fn_type: str = "ppo_adv_fn" gamma: Optional[float] = None lam: Optional[float] = None # TODO: add more algorithm params here - # TODO (yanxi): add advantage_fn_type, advantage_fn_args; or not? 
(keep it simple) - @dataclass class ClusterConfig: diff --git a/trinity/trainer/verl/ray_trainer.py b/trinity/trainer/verl/ray_trainer.py index 17fd174a01..5d883d05bb 100644 --- a/trinity/trainer/verl/ray_trainer.py +++ b/trinity/trainer/verl/ray_trainer.py @@ -16,18 +16,14 @@ """ import os -import uuid from contextlib import contextmanager -from copy import deepcopy from dataclasses import dataclass, field from enum import Enum -from pprint import pprint from typing import Dict, Type import numpy as np import ray import torch -import tqdm from codetiming import Timer from omegaconf import OmegaConf, open_dict from torch.utils.data import RandomSampler, SequentialSampler @@ -41,12 +37,6 @@ RayWorkerGroup, ) from verl.single_controller.ray.base import create_colocated_worker_cls -from verl.trainer.ppo.metric_utils import ( - compute_data_metrics, - compute_throughout_metrics, - compute_timing_metrics, - reduce_metrics, -) from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn from verl.utils.seqlen_balancing import ( diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py index d092e79bbe..f8dca2b36a 100644 --- a/trinity/trainer/verl_trainer.py +++ b/trinity/trainer/verl_trainer.py @@ -13,6 +13,7 @@ from verl.utils import hf_tokenizer from verl.utils.fs import copy_local_path_from_hdfs +from trinity.algorithm import ADVANTAGE_FN from trinity.common.config import Config from trinity.common.constants import AlgorithmType from trinity.common.experience import Experiences @@ -34,7 +35,6 @@ reduce_metrics, ) from trinity.utils.monitor import Monitor -from trinity.algorithm import ADVANTAGE_FN class _InternalDataLoader: @@ -133,7 +133,9 @@ def __init__( algo_config = global_config.algorithm if algo_config.algorithm_type.is_rft(): adv_fn_type = algo_config.advantage_fn_type - adv_fn_args = algo_config.get("advantage_fn_args", {}) # TODO: does this work properly?? + adv_fn_args = algo_config.get( + "advantage_fn_args", {} + ) # TODO (yanxi): does this work properly?? 
self.advantage_fn = ADVANTAGE_FN.get(adv_fn_type)(**adv_fn_args) self.logger = Monitor( From 0d5bad266625b2b031560ade91056c866f4a7aaa Mon Sep 17 00:00:00 2001 From: yanxi-chen Date: Thu, 29 May 2025 11:54:49 +0800 Subject: [PATCH 3/6] Split adv fn into separate files, and other update (TODO: update yaml configs and config manager) --- tests/template/config.yaml | 5 + trinity/algorithm/__init__.py | 2 +- trinity/algorithm/advantage_fn/__init__.py | 20 +++ .../algorithm/advantage_fn/advantage_fn.py | 167 +----------------- .../algorithm/advantage_fn/grpo_advantage.py | 42 +++++ .../algorithm/advantage_fn/opmd_advantage.py | 45 +++++ .../algorithm/advantage_fn/ppo_advantage.py | 50 ++++++ .../reinforce_plus_plus_advantage.py | 42 +++++ .../algorithm/advantage_fn/remax_advantage.py | 40 +++++ .../algorithm/advantage_fn/rloo_advantage.py | 40 +++++ trinity/common/config.py | 16 +- trinity/common/verl_config.py | 12 +- trinity/trainer/verl_trainer.py | 4 +- 13 files changed, 313 insertions(+), 172 deletions(-) create mode 100644 trinity/algorithm/advantage_fn/grpo_advantage.py create mode 100644 trinity/algorithm/advantage_fn/opmd_advantage.py create mode 100644 trinity/algorithm/advantage_fn/ppo_advantage.py create mode 100644 trinity/algorithm/advantage_fn/reinforce_plus_plus_advantage.py create mode 100644 trinity/algorithm/advantage_fn/remax_advantage.py create mode 100644 trinity/algorithm/advantage_fn/rloo_advantage.py diff --git a/tests/template/config.yaml b/tests/template/config.yaml index a83a82655f..c83d938c66 100644 --- a/tests/template/config.yaml +++ b/tests/template/config.yaml @@ -8,6 +8,11 @@ algorithm: policy_loss_fn: ppo policy_loss_fn_args: clip_range: 0.2 + advantage_fn_type: ppo_adv_fn + advantage_fn_args: + gamma: 1.0 + lam: 1.0 + model: model_path: '' max_prompt_tokens: 2048 diff --git a/trinity/algorithm/__init__.py b/trinity/algorithm/__init__.py index 51d3da8317..170507663f 100644 --- a/trinity/algorithm/__init__.py +++ b/trinity/algorithm/__init__.py @@ -1,4 +1,4 @@ -from trinity.algorithm.advantage_fn.advantage_fn import ADVANTAGE_FN, AdvantageFn +from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn from trinity.algorithm.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn __all__ = [ diff --git a/trinity/algorithm/advantage_fn/__init__.py b/trinity/algorithm/advantage_fn/__init__.py index e69de29bb2..7bcf682e4b 100644 --- a/trinity/algorithm/advantage_fn/__init__.py +++ b/trinity/algorithm/advantage_fn/__init__.py @@ -0,0 +1,20 @@ +from trinity.algorithm.advantage_fn.advantage_fn import ADVANTAGE_FN, AdvantageFn +from trinity.algorithm.advantage_fn.grpo_advantage import GRPOAdvantageFn +from trinity.algorithm.advantage_fn.opmd_advantage import OPMDAdvantageFn +from trinity.algorithm.advantage_fn.ppo_advantage import PPOAdvantageFn +from trinity.algorithm.advantage_fn.reinforce_plus_plus_advantage import ( + REINFORCEPLUSPLUSAdvantageFn, +) +from trinity.algorithm.advantage_fn.remax_advantage import REMAXAdvantageFn +from trinity.algorithm.advantage_fn.rloo_advantage import RLOOAdvantageFn + +__all__ = [ + "ADVANTAGE_FN", + "AdvantageFn", + "PPOAdvantageFn", + "GRPOAdvantageFn", + "REINFORCEPLUSPLUSAdvantageFn", + "REMAXAdvantageFn", + "RLOOAdvantageFn", + "OPMDAdvantageFn", +] diff --git a/trinity/algorithm/advantage_fn/advantage_fn.py b/trinity/algorithm/advantage_fn/advantage_fn.py index 581d2cfce5..21e3668a53 100644 --- a/trinity/algorithm/advantage_fn/advantage_fn.py +++ b/trinity/algorithm/advantage_fn/advantage_fn.py @@ -1,9 +1,6 @@ from abc 
import ABC, abstractmethod from typing import Any, Dict, Tuple -from verl import DataProto - -from trinity.trainer.verl import core_algos from trinity.utils.registry import Registry ADVANTAGE_FN = Registry("advantage_fn") @@ -19,162 +16,14 @@ def __call__(self, exps: Any, **kwargs: Dict) -> Tuple[Any, Dict]: kwargs (`Dict`): The step-level parameters for calculating advantages. Returns: - `Any`: The experiences with advantages. + `DataProto`: The experiences with advantages. `Dict`: The metrics for logging. """ - -@ADVANTAGE_FN.register("ppo_adv_fn") -class PPOAdvantageFn(AdvantageFn): - """PPO's GAE advantage computation""" - - def __init__(self, **kwargs): - self.gamma = kwargs.get("gamma") - self.lam = kwargs.get("lam") - - def __call__(self, exps: DataProto, **kwargs) -> Tuple[DataProto, Dict]: - """Adapted from compute_advantage_ppo in ray_trainer.py""" - - advantages, returns = core_algos.compute_gae_advantage_return( - token_level_rewards=exps.batch["token_level_rewards"], - values=exps.batch["values"], - eos_mask=exps.batch["response_mask"], - gamma=self.gamma, - lam=self.lam, - ) - exps.batch["advantages"] = advantages - exps.batch["returns"] = returns - - metrics = { - "abc": "xyz", # TODO: add meaningful metrics - } - - return exps, metrics - - -@ADVANTAGE_FN.register("grpo_adv_fn") -class GRPOAdvantageFn(AdvantageFn): - """GRPO advantage computation""" - - def __init__(self, **kwargs): - pass - - def __call__(self, exps: DataProto, **kwargs) -> Tuple[DataProto, Dict]: - """Adapted from compute_advantage_ppo in ray_trainer.py""" - - advantages, returns = core_algos.compute_grpo_outcome_advantage( - token_level_rewards=exps.batch["token_level_rewards"], - eos_mask=exps.batch["response_mask"], - index=exps.non_tensor_batch["uid"], - ) - exps.batch["advantages"] = advantages - exps.batch["returns"] = returns - - metrics = { - "abc": "xyz", # TODO: add meaningful metrics - } - - return exps, metrics - - -@ADVANTAGE_FN.register("reinforceplusplus_adv_fn") -class REINFORCEPLUSPLUSAdvantageFn(AdvantageFn): - """REINFORCE++ advantage computation""" - - def __init__(self, **kwargs): - self.gamma = kwargs.get("gamma") - - def __call__(self, exps: DataProto, **kwargs) -> Tuple[DataProto, Dict]: - """Adapted from compute_advantage_ppo in ray_trainer.py""" - - advantages, returns = core_algos.compute_reinforce_plus_plus_outcome_advantage( - token_level_rewards=exps.batch["token_level_rewards"], - eos_mask=exps.batch["response_mask"], - gamma=self.gamma, - ) - exps.batch["advantages"] = advantages - exps.batch["returns"] = returns - - metrics = { - "abc": "xyz", # TODO: add meaningful metrics - } - - return exps, metrics - - -@ADVANTAGE_FN.register("remax_adv_fn") -class REMAXAdvantageFn(AdvantageFn): - """REMAX advantage computation""" - - def __init__(self, **kwargs): - pass - - def __call__(self, exps: DataProto, **kwargs) -> Tuple[DataProto, Dict]: - """Adapted from compute_advantage_ppo in ray_trainer.py""" - - advantages, returns = core_algos.compute_remax_outcome_advantage( - token_level_rewards=exps.batch["token_level_rewards"], - reward_baselines=exps.batch["reward_baselines"], - eos_mask=exps.batch["response_mask"], - ) - exps.batch["advantages"] = advantages - exps.batch["returns"] = returns - - metrics = { - "abc": "xyz", # TODO: add meaningful metrics - } - - return exps, metrics - - -@ADVANTAGE_FN.register("rloo_adv_fn") -class RLOOAdvantageFn(AdvantageFn): - """RLOO advantage computation""" - - def __init__(self, **kwargs): - pass - - def __call__(self, exps: DataProto, 
**kwargs) -> Tuple[DataProto, Dict]: - """Adapted from compute_advantage_ppo in ray_trainer.py""" - - advantages, returns = core_algos.compute_rloo_outcome_advantage( - token_level_rewards=exps.batch["token_level_rewards"], - eos_mask=exps.batch["response_mask"], - index=exps.non_tensor_batch["uid"], - ) - exps.batch["advantages"] = advantages - exps.batch["returns"] = returns - - metrics = { - "abc": "xyz", # TODO: add meaningful metrics - } - - return exps, metrics - - -@ADVANTAGE_FN.register("opmd_adv_fn") -class OPMDAdvantageFn(AdvantageFn): - """OPMD advantage computation""" - - def __init__(self, **kwargs): - pass - - def __call__(self, exps: DataProto, **kwargs) -> Tuple[DataProto, Dict]: - """Adapted from compute_advantage_opmd in ray_trainer.py""" - - advantages, returns = core_algos.compute_opmd_outcome_advantage( - token_level_rewards=exps.batch["token_level_rewards"], - eos_mask=exps.batch["response_mask"], - # TODO: check consistency with exps.batch["attention_mask"][:, -response_length:] in original implementation - index=exps.non_tensor_batch["uid"], - opmd_baseline="mean", - tau=1.0, - ) - exps.batch["advantages"] = advantages - exps.batch["returns"] = returns - - metrics = { - "abc": "xyz", # TODO: add meaningful metrics - } - - return exps, metrics + @classmethod + @abstractmethod + def default_args(cls) -> Dict: + """ + Returns: + `Dict`: The default init arguments for the advantage function. + """ diff --git a/trinity/algorithm/advantage_fn/grpo_advantage.py b/trinity/algorithm/advantage_fn/grpo_advantage.py new file mode 100644 index 0000000000..89a8282752 --- /dev/null +++ b/trinity/algorithm/advantage_fn/grpo_advantage.py @@ -0,0 +1,42 @@ +"""GRPO advantage computation + +Adapted from compute_advantage_ppo in original ray_trainer.py +""" + +from typing import Dict, Tuple + +from verl import DataProto + +from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn +from trinity.trainer.verl import core_algos + + +@ADVANTAGE_FN.register_module("grpo_adv_fn") +class GRPOAdvantageFn(AdvantageFn): + """GRPO advantage computation""" + + def __init__(self) -> None: + pass + + def __call__( + self, + exps: DataProto, + **kwargs, + ) -> Tuple[DataProto, Dict]: + advantages, returns = core_algos.compute_grpo_outcome_advantage( + token_level_rewards=exps.batch["token_level_rewards"], + eos_mask=exps.batch["response_mask"], + index=exps.non_tensor_batch["uid"], + ) + exps.batch["advantages"] = advantages + exps.batch["returns"] = returns + + metrics = { + # TODO: add meaningful metrics + } + + return exps, metrics + + @classmethod + def default_args(cls) -> Dict: + return {} diff --git a/trinity/algorithm/advantage_fn/opmd_advantage.py b/trinity/algorithm/advantage_fn/opmd_advantage.py new file mode 100644 index 0000000000..abf74686d3 --- /dev/null +++ b/trinity/algorithm/advantage_fn/opmd_advantage.py @@ -0,0 +1,45 @@ +"""OPMD advantage computation + +Adapted from compute_advantage_opmd in original ray_trainer.py +""" + +from typing import Dict, Tuple + +from verl import DataProto + +from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn +from trinity.trainer.verl import core_algos + + +@ADVANTAGE_FN.register_module("opmd_adv_fn") +class OPMDAdvantageFn(AdvantageFn): + """OPMD advantage computation""" + + def __init__(self) -> None: + pass + + def __call__( + self, + exps: DataProto, + **kwargs, + ) -> Tuple[DataProto, Dict]: + advantages, returns = core_algos.compute_opmd_outcome_advantage( + token_level_rewards=exps.batch["token_level_rewards"], + 
eos_mask=exps.batch["response_mask"], + # TODO (yanxi): check consistency with exps.batch["attention_mask"][:, -response_length:] in original implementation + index=exps.non_tensor_batch["uid"], + opmd_baseline="mean", + tau=1.0, + ) + exps.batch["advantages"] = advantages + exps.batch["returns"] = returns + + metrics = { + # TODO: add meaningful metrics + } + + return exps, metrics + + @classmethod + def default_args(cls) -> Dict: + return {} diff --git a/trinity/algorithm/advantage_fn/ppo_advantage.py b/trinity/algorithm/advantage_fn/ppo_advantage.py new file mode 100644 index 0000000000..5afd51311c --- /dev/null +++ b/trinity/algorithm/advantage_fn/ppo_advantage.py @@ -0,0 +1,50 @@ +"""PPO's GAE advantage computation + +Adapted from compute_advantage_ppo in original ray_trainer.py +""" + +from typing import Dict, Tuple + +from verl import DataProto + +from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn +from trinity.trainer.verl import core_algos + + +@ADVANTAGE_FN.register_module("ppo_adv_fn") +class PPOAdvantageFn(AdvantageFn): + def __init__( + self, + gamma: float = 1.0, + lam: float = 1.0, + ) -> None: + self.gamma = gamma + self.lam = lam + + def __call__( + self, + exps: DataProto, + **kwargs, + ) -> Tuple[DataProto, Dict]: + advantages, returns = core_algos.compute_gae_advantage_return( + token_level_rewards=exps.batch["token_level_rewards"], + values=exps.batch["values"], + eos_mask=exps.batch["response_mask"], + gamma=self.gamma, + lam=self.lam, + ) + exps.batch["advantages"] = advantages + exps.batch["returns"] = returns + + metrics = { + # TODO: add meaningful metrics + } + + return exps, metrics + + @classmethod + def default_args(cls) -> Dict: + return { + "gamma": 1.0, + "lam": 1.0, + } diff --git a/trinity/algorithm/advantage_fn/reinforce_plus_plus_advantage.py b/trinity/algorithm/advantage_fn/reinforce_plus_plus_advantage.py new file mode 100644 index 0000000000..9c668f7640 --- /dev/null +++ b/trinity/algorithm/advantage_fn/reinforce_plus_plus_advantage.py @@ -0,0 +1,42 @@ +"""REINFORCE++ advantage computation + +Adapted from compute_advantage_ppo in original ray_trainer.py +""" + +from typing import Dict, Tuple + +from verl import DataProto + +from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn +from trinity.trainer.verl import core_algos + + +@ADVANTAGE_FN.register_module("reinforceplusplus_adv_fn") +class REINFORCEPLUSPLUSAdvantageFn(AdvantageFn): + def __init__(self, gamma: float = 1.0) -> None: + self.gamma = gamma + + def __call__( + self, + exps: DataProto, + **kwargs, + ) -> Tuple[DataProto, Dict]: + advantages, returns = core_algos.compute_reinforce_plus_plus_outcome_advantage( + token_level_rewards=exps.batch["token_level_rewards"], + eos_mask=exps.batch["response_mask"], + gamma=self.gamma, + ) + exps.batch["advantages"] = advantages + exps.batch["returns"] = returns + + metrics = { + # TODO: add meaningful metrics + } + + return exps, metrics + + @classmethod + def default_args(cls) -> Dict: + return { + "gamma": 1.0, + } diff --git a/trinity/algorithm/advantage_fn/remax_advantage.py b/trinity/algorithm/advantage_fn/remax_advantage.py new file mode 100644 index 0000000000..05a13d7d60 --- /dev/null +++ b/trinity/algorithm/advantage_fn/remax_advantage.py @@ -0,0 +1,40 @@ +"""REMAX advantage computation + +Adapted from compute_advantage_ppo in original ray_trainer.py +""" + +from typing import Dict, Tuple + +from verl import DataProto + +from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn +from 
trinity.trainer.verl import core_algos + + +@ADVANTAGE_FN.register_module("remax_adv_fn") +class REMAXAdvantageFn(AdvantageFn): + def __init__(self) -> None: + pass + + def __call__( + self, + exps: DataProto, + **kwargs, + ) -> Tuple[DataProto, Dict]: + advantages, returns = core_algos.compute_remax_outcome_advantage( + token_level_rewards=exps.batch["token_level_rewards"], + reward_baselines=exps.batch["reward_baselines"], + eos_mask=exps.batch["response_mask"], + ) + exps.batch["advantages"] = advantages + exps.batch["returns"] = returns + + metrics = { + # TODO: add meaningful metrics + } + + return exps, metrics + + @classmethod + def default_args(cls) -> Dict: + return {} diff --git a/trinity/algorithm/advantage_fn/rloo_advantage.py b/trinity/algorithm/advantage_fn/rloo_advantage.py new file mode 100644 index 0000000000..3da61c9da4 --- /dev/null +++ b/trinity/algorithm/advantage_fn/rloo_advantage.py @@ -0,0 +1,40 @@ +"""RLOO advantage computation + +Adapted from compute_advantage_ppo in original ray_trainer.py +""" + +from typing import Dict, Tuple + +from verl import DataProto + +from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn +from trinity.trainer.verl import core_algos + + +@ADVANTAGE_FN.register_module("rloo_adv_fn") +class RLOOAdvantageFn(AdvantageFn): + def __init__(self) -> None: + pass + + def __call__( + self, + exps: DataProto, + **kwargs, + ) -> Tuple[DataProto, Dict]: + advantages, returns = core_algos.compute_rloo_outcome_advantage( + token_level_rewards=exps.batch["token_level_rewards"], + eos_mask=exps.batch["response_mask"], + index=exps.non_tensor_batch["uid"], + ) + exps.batch["advantages"] = advantages + exps.batch["returns"] = returns + + metrics = { + # TODO: add meaningful metrics + } + + return exps, metrics + + @classmethod + def default_args(cls) -> Dict: + return {} diff --git a/trinity/common/config.py b/trinity/common/config.py index 00fa7a2c58..794202bab0 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -173,15 +173,15 @@ class AlgorithmConfig: algorithm_type: AlgorithmType = AlgorithmType.PPO # for GRPO-like algorithms, repeat each task for `repeat_times` times repeat_times: int = 1 - # configs for advantage calculation - advantage_fn_type: str = "ppo_adv_fn" - gamma: Optional[float] = None - lam: Optional[float] = None policy_loss_fn: str = "ppo" # If not set, use PolicyLossFn.default_args() policy_loss_fn_args: Optional[dict] = None + advantage_fn_type: str = "ppo_adv_fn" + # If not set, use AdvantageFn.default_args() + advantage_fn_args: Optional[dict] = None + @dataclass class ClusterConfig: @@ -472,7 +472,7 @@ def _check_buffer(self) -> None: # noqa: C901 self.buffer.tokenizer_path = self.model.model_path def _check_algorithm(self) -> None: - from trinity.algorithm import POLICY_LOSS_FN + from trinity.algorithm import ADVANTAGE_FN, POLICY_LOSS_FN policy_fn_cls = POLICY_LOSS_FN.get(self.algorithm.policy_loss_fn) if policy_fn_cls is None: @@ -480,6 +480,12 @@ def _check_algorithm(self) -> None: if self.algorithm.policy_loss_fn_args is None: self.algorithm.policy_loss_fn_args = policy_fn_cls.default_args() + advantage_fn_cls = ADVANTAGE_FN.get(self.algorithm.advantage_fn_type) + if advantage_fn_cls is None: + raise ValueError(f"Invalid advantage_fn_type: {self.algorithm.advantage_fn_type}") + if self.algorithm.advantage_fn_args is None: + self.algorithm.advantage_fn_args = advantage_fn_cls.default_args() + def check_and_update(self) -> None: # noqa: C901 """Check and update the config.""" 
self._check_deprecated() diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py index a4d9a6e8d9..b519246b8e 100644 --- a/trinity/common/verl_config.py +++ b/trinity/common/verl_config.py @@ -185,12 +185,13 @@ class Algorithm: gamma: float = 1.0 lam: float = 1.0 adv_estimator: str = "gae" + # TODO (yanxi): remove the above advantage-related parameters? norm_adv_by_std_in_grpo: bool = True use_kl_in_reward: bool = False kl_penalty: str = "kl" kl_ctrl: KL_Ctrl = field(default_factory=KL_Ctrl) - # ! DO NOT SET THE FLOWING PARAMETERS + # ! DO NOT SET THE FOLLOWING PARAMETERS policy_loss_fn: str = "ppo" policy_loss_fn_args: Optional[dict] = None @@ -321,11 +322,14 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901 self.algorithm.lam = config.algorithm.lam self.actor_rollout_ref.actor.algorithm_type = config.algorithm.algorithm_type if config.algorithm.algorithm_type == AlgorithmType.PPO: - logger.info("Using GAE `adv_estimator` for PPO") + logger.info("Setting `adv_estimator` to 'gae' for PPO") self.algorithm.adv_estimator = AdvantageEstimator.GAE.value - elif config.algorithm.algorithm_type == AlgorithmType.GRPO: - logger.info("Using GRPO `adv_estimator` for GRPO") + elif config.algorithm.algorithm_type in (AlgorithmType.GRPO, AlgorithmType.OPMD): + logger.info("Setting `adv_estimator` to 'grpo' for GRPO/OPMD") self.algorithm.adv_estimator = AdvantageEstimator.GRPO.value + # TODO (yanxi): it seems that adv_estimator only affects whether use_critic is set to + # True or False in RayPPOTrainer.__init__() (and hence in VerlPPOTrainerWrapper); + # need to double check whether this is indeed the case. if self.actor_rollout_ref.actor.algorithm_type.is_dpo(): # for DPO if not self.actor_rollout_ref.actor.use_kl_loss: diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py index 1ee8c29887..0ff56e4140 100644 --- a/trinity/trainer/verl_trainer.py +++ b/trinity/trainer/verl_trainer.py @@ -131,9 +131,7 @@ def __init__( algo_config = global_config.algorithm if algo_config.algorithm_type.is_rft(): adv_fn_type = algo_config.advantage_fn_type - adv_fn_args = algo_config.get( - "advantage_fn_args", {} - ) # TODO (yanxi): does this work properly?? + adv_fn_args = algo_config.advantage_fn_args self.advantage_fn = ADVANTAGE_FN.get(adv_fn_type)(**adv_fn_args) self.logger = Monitor( From b9346fdaa9e788fc6003535932453a97c86ec8b8 Mon Sep 17 00:00:00 2001 From: yanxi-chen Date: Thu, 29 May 2025 14:07:34 +0800 Subject: [PATCH 4/6] Fix gamma/lam issue --- trinity/common/verl_config.py | 15 +++++---------- trinity/trainer/verl/core_algos.py | 6 +++--- 2 files changed, 8 insertions(+), 13 deletions(-) diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py index b519246b8e..44b9b88852 100644 --- a/trinity/common/verl_config.py +++ b/trinity/common/verl_config.py @@ -182,10 +182,8 @@ class KL_Ctrl: @dataclass class Algorithm: - gamma: float = 1.0 - lam: float = 1.0 adv_estimator: str = "gae" - # TODO (yanxi): remove the above advantage-related parameters? 
+ # TODO (yanxi): might remove adv_estimator completely, use AlgorithmConfig.advantage_fn_type instead norm_adv_by_std_in_grpo: bool = True use_kl_in_reward: bool = False kl_penalty: str = "kl" @@ -316,10 +314,6 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901 self.actor_rollout_ref.actor.clip_ratio = config.trainer.actor_clip_ratio # Algorithm related config - if config.algorithm.gamma is not None: - self.algorithm.gamma = config.algorithm.gamma - if config.algorithm.lam is not None: - self.algorithm.lam = config.algorithm.lam self.actor_rollout_ref.actor.algorithm_type = config.algorithm.algorithm_type if config.algorithm.algorithm_type == AlgorithmType.PPO: logger.info("Setting `adv_estimator` to 'gae' for PPO") @@ -327,9 +321,10 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901 elif config.algorithm.algorithm_type in (AlgorithmType.GRPO, AlgorithmType.OPMD): logger.info("Setting `adv_estimator` to 'grpo' for GRPO/OPMD") self.algorithm.adv_estimator = AdvantageEstimator.GRPO.value - # TODO (yanxi): it seems that adv_estimator only affects whether use_critic is set to - # True or False in RayPPOTrainer.__init__() (and hence in VerlPPOTrainerWrapper); - # need to double check whether this is indeed the case. + # TODO (yanxi): it seems that adv_estimator now only affects whether use_critic is set to + # True or False in RayPPOTrainer.__init__() (and hence in VerlPPOTrainerWrapper). + # Need to double check whether this is indeed the case, + # and see if adv_estimator can be removed completely. if self.actor_rollout_ref.actor.algorithm_type.is_dpo(): # for DPO if not self.actor_rollout_ref.actor.use_kl_loss: diff --git a/trinity/trainer/verl/core_algos.py b/trinity/trainer/verl/core_algos.py index 20cffc9962..f104e0f4f4 100644 --- a/trinity/trainer/verl/core_algos.py +++ b/trinity/trainer/verl/core_algos.py @@ -139,8 +139,8 @@ def compute_gae_advantage_return( token_level_rewards: torch.Tensor, values: torch.Tensor, eos_mask: torch.Tensor, - gamma: torch.Tensor, - lam: torch.Tensor, + gamma: float, + lam: float, ): """Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py @@ -283,7 +283,7 @@ def compute_rloo_outcome_advantage( def compute_reinforce_plus_plus_outcome_advantage( - token_level_rewards: torch.Tensor, eos_mask: torch.Tensor, gamma: torch.Tensor + token_level_rewards: torch.Tensor, eos_mask: torch.Tensor, gamma: float ): """ Compute advantage for REINFORCE++. From 6d0caa145f258b25504ccbd6da1747d110745d64 Mon Sep 17 00:00:00 2001 From: yanxi-chen Date: Thu, 29 May 2025 15:21:46 +0800 Subject: [PATCH 5/6] Add gamma and lam back to verl_config to ensure compatibility --- trinity/common/verl_config.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py index 44b9b88852..fb9f810dee 100644 --- a/trinity/common/verl_config.py +++ b/trinity/common/verl_config.py @@ -182,8 +182,12 @@ class KL_Ctrl: @dataclass class Algorithm: + # ! 
DO NOT SET gamma or lam below; they are kept here merely for compatibility with verl, + # and their values will be overwritten by those in AlgorithmConfig.advantage_fn_args + # if they are really needed (e.g., for GAE advantage/returns computation) + gamma: float = 1.0 + lam: float = 1.0 adv_estimator: str = "gae" - # TODO (yanxi): might remove adv_estimator completely, use AlgorithmConfig.advantage_fn_type instead norm_adv_by_std_in_grpo: bool = True use_kl_in_reward: bool = False kl_penalty: str = "kl" @@ -314,6 +318,11 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901 self.actor_rollout_ref.actor.clip_ratio = config.trainer.actor_clip_ratio # Algorithm related config + adv_fn_args = config.algorithm.advantage_fn_args + if adv_fn_args is not None and "gamma" in adv_fn_args: + self.algorithm.gamma = adv_fn_args["gamma"] + if adv_fn_args is not None and "lam" in adv_fn_args: + self.algorithm.lam = adv_fn_args["lam"] self.actor_rollout_ref.actor.algorithm_type = config.algorithm.algorithm_type if config.algorithm.algorithm_type == AlgorithmType.PPO: logger.info("Setting `adv_estimator` to 'gae' for PPO") From db96689c2b5f808b9b6479bb0c93f7c27a416565 Mon Sep 17 00:00:00 2001 From: yanxi-chen Date: Thu, 29 May 2025 16:02:00 +0800 Subject: [PATCH 6/6] Fix import error --- trinity/trainer/verl_trainer.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py index 0ff56e4140..b6397adde7 100644 --- a/trinity/trainer/verl_trainer.py +++ b/trinity/trainer/verl_trainer.py @@ -4,12 +4,20 @@ Modified from verl/trainer/ppo/ray_trainer.py """ import os +from pprint import pprint from typing import Tuple +import numpy as np import pandas as pd import ray import torch from omegaconf import OmegaConf +from verl.trainer.ppo.metric_utils import ( + compute_data_metrics, + compute_throughout_metrics, + compute_timing_metrics, + reduce_metrics, +) from verl.utils import hf_tokenizer from verl.utils.fs import copy_local_path_from_hdfs @@ -26,13 +34,7 @@ Role, _timer, apply_kl_penalty, - compute_data_metrics, - compute_throughout_metrics, - compute_timing_metrics, find_latest_ckpt_path, - np, - pprint, - reduce_metrics, ) from trinity.utils.monitor import Monitor
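
Usage sketch for the interface introduced above. After PATCH 3/6, an advantage estimator is an `AdvantageFn` subclass registered under a string key, and the trainer builds it via `ADVANTAGE_FN.get(algorithm.advantage_fn_type)(**algorithm.advantage_fn_args)`, with `_check_algorithm` falling back to `default_args()` when `advantage_fn_args` is left unset. The minimal sketch below shows how a custom estimator would plug into that interface; the class name, the registry key "mean_baseline_adv_fn", the `scale` parameter, and the tensor math are illustrative assumptions, while the registry calls (`register_module`, `get`), the `__call__(exps, **kwargs) -> (exps, metrics)` contract, `default_args()`, and the `DataProto` fields (`token_level_rewards`, `response_mask`, `advantages`, `returns`) are taken from the code in these patches.

    # Hypothetical custom advantage function, sketched against the AdvantageFn API from this patch series.
    from typing import Dict, Tuple

    import torch
    from verl import DataProto

    from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn


    @ADVANTAGE_FN.register_module("mean_baseline_adv_fn")  # illustrative key, not part of the patches above
    class MeanBaselineAdvantageFn(AdvantageFn):
        """Sequence-level reward minus the batch-mean baseline, broadcast over response tokens."""

        def __init__(self, scale: float = 1.0) -> None:
            self.scale = scale

        def __call__(self, exps: DataProto, **kwargs) -> Tuple[DataProto, Dict]:
            rewards = exps.batch["token_level_rewards"]  # (batch_size, response_len)
            mask = exps.batch["response_mask"]           # (batch_size, response_len)
            outcome = rewards.sum(dim=-1)                # sequence-level reward per sample
            baseline = outcome.mean()                    # batch-mean baseline
            advantages = self.scale * (outcome - baseline).unsqueeze(-1) * mask
            exps.batch["advantages"] = advantages
            exps.batch["returns"] = advantages           # no discounting in this toy estimator
            metrics = {"advantage/mean": advantages.mean().item()}
            return exps, metrics

        @classmethod
        def default_args(cls) -> Dict:
            return {"scale": 1.0}

Selecting it would then only require `advantage_fn_type: mean_baseline_adv_fn` (and optionally `advantage_fn_args: {scale: 1.0}`) under the `algorithm` section of the YAML config, mirroring the `ppo_adv_fn` entry added to tests/template/config.yaml in PATCH 3/6.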