diff --git a/docs/sphinx_doc/source/tutorial/example_dpo.md b/docs/sphinx_doc/source/tutorial/example_dpo.md
index a6f70f5e62..44543ff2bc 100644
--- a/docs/sphinx_doc/source/tutorial/example_dpo.md
+++ b/docs/sphinx_doc/source/tutorial/example_dpo.md
@@ -48,6 +48,9 @@ name:
 mode: train
 algorithm:
   algorithm_type: dpo
+  kl_loss_fn: k1
+  kl_loss_fn_args:
+    kl_coef: 0.1 # value of beta in DPO
 checkpoint_root_dir: /PATH/TO/CHECKPOINT/
 model:
   model_path: /PATH/TO/MODEL/
@@ -70,8 +73,6 @@ buffer:
 trainer:
   trainer_config_path: 'examples/dpo_humanlike/train_dpo.yaml'
   save_interval: 30
-  actor_use_kl_loss: True
-  actor_kl_loss_coef: 0.1 # value of beta in DPO
 ```

 ### Run the Experiment
diff --git a/examples/dpo_humanlike/dpo.yaml b/examples/dpo_humanlike/dpo.yaml
index 8cd3dbe0c8..0a0864b8ef 100644
--- a/examples/dpo_humanlike/dpo.yaml
+++ b/examples/dpo_humanlike/dpo.yaml
@@ -3,6 +3,9 @@ name: "trinity_dpo"
 mode: train
 algorithm:
   algorithm_type: dpo
+  kl_loss_fn: k1
+  kl_loss_fn_args:
+    kl_coef: 0.1
 checkpoint_root_dir: /PATH/TO/CHECKPOINT/
 model:
   model_path: /PATH/TO/MODEL
@@ -34,5 +37,3 @@ trainer:
   trainer_type: 'verl'
   trainer_config_path: 'examples/dpo_humanlike/train_dpo.yaml'
   save_interval: 30
-  actor_use_kl_loss: True
-  actor_kl_loss_coef: 0.1
diff --git a/tests/template/config.yaml b/tests/template/config.yaml
index 3a767df243..98180fff48 100644
--- a/tests/template/config.yaml
+++ b/tests/template/config.yaml
@@ -8,10 +8,12 @@ algorithm:
   policy_loss_fn: ppo
   policy_loss_fn_args:
     clip_range: 0.2
-  advantage_fn_type: ppo_adv_fn
+  advantage_fn: ppo
  advantage_fn_args:
     gamma: 1.0
     lam: 1.0
+  kl_penalty_fn: k3
+  kl_loss_fn: k2

 model:
   model_path: ''
diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py
index 55f63ae856..e83b443c4b 100644
--- a/tests/trainer/trainer_test.py
+++ b/tests/trainer/trainer_test.py
@@ -67,6 +67,10 @@ def test_trainer(self):
         actor_metrics = parser.metric_list("actor")
         self.assertTrue(len(actor_metrics) > 0)
         self.assertEqual(parser.metric_max_step(actor_metrics[0]), 8)
+        actor_kl_metrics = parser.metric_list("actor/kl")
+        self.assertTrue(len(actor_kl_metrics) > 0)
+        critic_kl_metrics = parser.metric_list("critic/kl")
+        self.assertTrue(len(critic_kl_metrics) > 0)
         response_metrics = parser.metric_list("response_length")
         self.assertTrue(len(response_metrics) > 0)
         self.assertEqual(parser.metric_max_step(response_metrics[0]), 8)
@@ -86,7 +90,7 @@ def test_trainer(self):
         )
         self.assertTrue(os.path.exists(checkpoint_step_4))
         self.assertTrue(os.path.exists(checkpoint_step_8))
-
+        # TODO: Reinit will fail when using v1 engine, find a way to fix it
         ray.init(ignore_reinit_error=True)
         # test bench mode
         self.config.mode = "bench"
@@ -118,7 +122,7 @@ def test_trainer(self):
         self.config.algorithm.algorithm_type = AlgorithmType.GRPO
         self.config.algorithm.repeat_times = 4
         # self.config.algorithm.repeat_times = 8  # TODO: used for real testing
-        self.config.algorithm.advantage_fn_type = "grpo_adv_fn"
+        self.config.algorithm.advantage_fn = "grpo"
         self.config.algorithm.advantage_fn_args = {}
         # self.config.buffer.batch_size = 96  # TODO: used for real testing
         self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
@@ -143,8 +147,6 @@ def test_trainer(self):
         # self.assertTrue(0.4 < rewards[1] < 0.55)
         # self.assertTrue(0.6 < rewards[2] < 0.7)
         # self.assertTrue(0.6 < rewards[3] < 0.7)
-        ray.shutdown(_exiting_interpreter=True)
-        # check checkpoint

     def tearDown(self):
         # remove dir only when the test passed
@@ -157,7 +159,7 @@ def test_trainer(self):
         # test both mode
         self.config.algorithm.algorithm_type = AlgorithmType.GRPO
         self.config.algorithm.repeat_times = 4
-        self.config.algorithm.advantage_fn_type = "grpo_adv_fn"
+        self.config.algorithm.advantage_fn = "grpo"
         self.config.algorithm.advantage_fn_args = {}
         self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
         self.config.buffer.trainer_input.sft_warmup_steps = 2
@@ -180,8 +182,6 @@ def test_trainer(self):
         response_metrics = parser.metric_list("response_length")
         self.assertTrue(len(response_metrics) > 0)
         self.assertEqual(parser.metric_max_step(response_metrics[0]), 4)
-        ray.shutdown(_exiting_interpreter=True)
-        # check checkpoint

     def tearDown(self):
         # remove dir only when the test passed
@@ -207,8 +207,6 @@ def test_trainer(self):
         actor_metrics = parser.metric_list("actor")
         self.assertTrue(len(actor_metrics) > 0)
         self.assertEqual(parser.metric_max_step(actor_metrics[0]), 4)
-        ray.shutdown(_exiting_interpreter=True)
-        # check checkpoint

     def tearDown(self):
         # remove dir only when the test passed
diff --git a/trinity/algorithm/__init__.py b/trinity/algorithm/__init__.py
index 170507663f..101364c57c 100644
--- a/trinity/algorithm/__init__.py
+++ b/trinity/algorithm/__init__.py
@@ -1,4 +1,6 @@
 from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
+from trinity.algorithm.entropy_loss_fn import ENTROPY_LOSS_FN, EntropyLossFn
+from trinity.algorithm.kl_fn import KL_FN, KLFn
 from trinity.algorithm.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn

 __all__ = [
@@ -6,4 +8,8 @@
     "ADVANTAGE_FN",
     "PolicyLossFn",
     "POLICY_LOSS_FN",
+    "KLFn",
+    "KL_FN",
+    "EntropyLossFn",
+    "ENTROPY_LOSS_FN",
 ]
diff --git a/trinity/algorithm/advantage_fn/grpo_advantage.py b/trinity/algorithm/advantage_fn/grpo_advantage.py
index 89a8282752..37f824de4f 100644
--- a/trinity/algorithm/advantage_fn/grpo_advantage.py
+++ b/trinity/algorithm/advantage_fn/grpo_advantage.py
@@ -11,7 +11,7 @@
 from trinity.trainer.verl import core_algos


-@ADVANTAGE_FN.register_module("grpo_adv_fn")
+@ADVANTAGE_FN.register_module("grpo")
 class GRPOAdvantageFn(AdvantageFn):
     """GRPO advantage computation"""

diff --git a/trinity/algorithm/advantage_fn/opmd_advantage.py b/trinity/algorithm/advantage_fn/opmd_advantage.py
index abf74686d3..e9e0eb090f 100644
--- a/trinity/algorithm/advantage_fn/opmd_advantage.py
+++ b/trinity/algorithm/advantage_fn/opmd_advantage.py
@@ -11,7 +11,7 @@
 from trinity.trainer.verl import core_algos


-@ADVANTAGE_FN.register_module("opmd_adv_fn")
+@ADVANTAGE_FN.register_module("opmd")
 class OPMDAdvantageFn(AdvantageFn):
     """OPMD advantage computation"""

diff --git a/trinity/algorithm/advantage_fn/ppo_advantage.py b/trinity/algorithm/advantage_fn/ppo_advantage.py
index 5afd51311c..896deca116 100644
--- a/trinity/algorithm/advantage_fn/ppo_advantage.py
+++ b/trinity/algorithm/advantage_fn/ppo_advantage.py
@@ -11,7 +11,7 @@
 from trinity.trainer.verl import core_algos


-@ADVANTAGE_FN.register_module("ppo_adv_fn")
+@ADVANTAGE_FN.register_module("ppo")
 class PPOAdvantageFn(AdvantageFn):
     def __init__(
         self,
diff --git a/trinity/algorithm/advantage_fn/reinforce_plus_plus_advantage.py b/trinity/algorithm/advantage_fn/reinforce_plus_plus_advantage.py
index 9c668f7640..d53052c83f 100644
--- a/trinity/algorithm/advantage_fn/reinforce_plus_plus_advantage.py
+++ b/trinity/algorithm/advantage_fn/reinforce_plus_plus_advantage.py
@@ -11,7 +11,7 @@
 from trinity.trainer.verl import core_algos


-@ADVANTAGE_FN.register_module("reinforceplusplus_adv_fn")
+@ADVANTAGE_FN.register_module("reinforceplusplus")
 class REINFORCEPLUSPLUSAdvantageFn(AdvantageFn):
     def __init__(self, gamma: float = 1.0) -> None:
         self.gamma = gamma
diff --git a/trinity/algorithm/advantage_fn/remax_advantage.py b/trinity/algorithm/advantage_fn/remax_advantage.py
index 05a13d7d60..516213c0c2 100644
--- a/trinity/algorithm/advantage_fn/remax_advantage.py
+++ b/trinity/algorithm/advantage_fn/remax_advantage.py
@@ -11,7 +11,7 @@
 from trinity.trainer.verl import core_algos


-@ADVANTAGE_FN.register_module("remax_adv_fn")
+@ADVANTAGE_FN.register_module("remax")
 class REMAXAdvantageFn(AdvantageFn):
     def __init__(self) -> None:
         pass
diff --git a/trinity/algorithm/advantage_fn/rloo_advantage.py b/trinity/algorithm/advantage_fn/rloo_advantage.py
index 3da61c9da4..c88276e836 100644
--- a/trinity/algorithm/advantage_fn/rloo_advantage.py
+++ b/trinity/algorithm/advantage_fn/rloo_advantage.py
@@ -11,7 +11,7 @@
 from trinity.trainer.verl import core_algos


-@ADVANTAGE_FN.register_module("rloo_adv_fn")
+@ADVANTAGE_FN.register_module("rloo")
 class RLOOAdvantageFn(AdvantageFn):
     def __init__(self) -> None:
         pass
diff --git a/trinity/algorithm/entropy_loss/__init__.py b/trinity/algorithm/entropy_loss/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/trinity/algorithm/entropy_loss_fn/__init__.py b/trinity/algorithm/entropy_loss_fn/__init__.py
new file mode 100644
index 0000000000..d932b94fde
--- /dev/null
+++ b/trinity/algorithm/entropy_loss_fn/__init__.py
@@ -0,0 +1,9 @@
+from trinity.algorithm.entropy_loss_fn.entropy_loss_fn import (
+    ENTROPY_LOSS_FN,
+    EntropyLossFn,
+)
+
+__all__ = [
+    "EntropyLossFn",
+    "ENTROPY_LOSS_FN",
+]
diff --git a/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py b/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py
new file mode 100644
index 0000000000..4df9272ca0
--- /dev/null
+++ b/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py
@@ -0,0 +1,63 @@
+from abc import ABC, abstractmethod
+from typing import Dict, Tuple
+
+import torch
+
+from trinity.algorithm.utils import masked_mean
+from trinity.utils.registry import Registry
+
+ENTROPY_LOSS_FN = Registry("entropy_loss_fn")
+
+
+class EntropyLossFn(ABC):
+    """
+    Entropy loss function.
+    """
+
+    @abstractmethod
+    def __call__(
+        self,
+        entropy: torch.Tensor,
+        action_mask: torch.Tensor,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Dict]:
+        """
+        Args:
+            entropy (`torch.Tensor`): The entropy generated by the policy model.
+            action_mask (`torch.Tensor`): The action mask.
+
+        Returns:
+            `torch.Tensor`: The calculated entropy loss.
+            `Dict`: The metrics for logging.
+        """
+
+    @classmethod
+    @abstractmethod
+    def default_args(cls) -> Dict:
+        """
+        Returns:
+            `Dict`: The default arguments for the entropy loss function.
+        """
+
+
+@ENTROPY_LOSS_FN.register_module("basic")
+class BasicEntropyLossFn(EntropyLossFn):
+    """
+    Basic entropy loss: masked mean of the token entropy, scaled by `entropy_coef`.
+    """
+
+    def __init__(self, entropy_coef: float):
+        self.entropy_coef = entropy_coef
+
+    def __call__(
+        self,
+        entropy: torch.Tensor,
+        action_mask: torch.Tensor,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Dict]:
+        entropy_loss = masked_mean(entropy, action_mask)
+        return entropy_loss * self.entropy_coef, {"entropy_loss": entropy_loss.detach().item()}
+
+    @classmethod
+    def default_args(cls) -> Dict:
+        return {"entropy_coef": 0.0}
diff --git a/trinity/algorithm/kl_fn/__init__.py b/trinity/algorithm/kl_fn/__init__.py
new file mode 100644
index 0000000000..875c620442
--- /dev/null
+++ b/trinity/algorithm/kl_fn/__init__.py
@@ -0,0 +1,3 @@
+from trinity.algorithm.kl_fn.kl_fn import KL_FN, KLFn
+
+__all__ = ["KLFn", "KL_FN"]
diff --git a/trinity/algorithm/kl_fn/kl_fn.py b/trinity/algorithm/kl_fn/kl_fn.py
new file mode 100644
index 0000000000..3901ea7f3c
--- /dev/null
+++ b/trinity/algorithm/kl_fn/kl_fn.py
@@ -0,0 +1,157 @@
+from abc import ABC, abstractmethod
+from typing import Any, Dict, Optional, Tuple
+
+import torch
+
+from trinity.algorithm.utils import masked_mean
+from trinity.utils.registry import Registry
+
+KL_FN = Registry("kl_fn")
+
+
+class KLFn(ABC):
+    """
+    Base class for KL penalty and KL loss computation.
+    """
+
+    def __init__(
+        self,
+        adaptive: bool = False,
+        kl_coef: float = 0.001,
+        target_kl: Optional[float] = None,
+        horizon: Optional[float] = None,
+    ) -> None:
+        self.kl_coef = kl_coef
+        self.adaptive = adaptive
+        self.target_kl = target_kl
+        self.horizon = horizon
+        if adaptive and (target_kl is None or horizon is None):
+            raise ValueError("Target KL and horizon must be provided for adaptive KL.")
+
+    def update_kl_coef(self, current_kl: float, batch_size: int) -> None:
+        """Update kl coefficient."""
+        if self.adaptive:
+            target_kl = self.target_kl
+            proportional_error = torch.clip(current_kl / target_kl - 1, -0.2, 0.2).item()  # type: ignore
+            multiplier = 1 + proportional_error * batch_size / self.horizon
+            self.kl_coef *= multiplier
+
+    def apply_kl_penalty_to_reward(self, experiences: Any) -> Tuple[Any, Dict]:
+        """Apply KL penalty to reward. Only supports DataProto input for now."""
+        responses = experiences.batch["responses"]
+        response_length = responses.size(1)
+        token_level_scores = experiences.batch["token_level_scores"]
+        batch_size = experiences.batch.batch_size[0]
+        attention_mask = experiences.batch["attention_mask"]
+        response_mask = experiences.batch["response_mask"]
+        assert response_mask.shape == attention_mask[:, -response_length:].shape
+        logprob = experiences.batch["old_log_probs"]
+
+        if "ref_log_prob" in experiences.batch.keys():
+            ref_logprob = experiences.batch["ref_log_prob"]
+            kl = self.calculate_kl(logprob, ref_logprob)
+            kl = kl * response_mask
+            kl_coef = self.kl_coef
+            experiences.batch["token_level_rewards"] = token_level_scores - kl_coef * kl
+        else:
+            kl_coef = 0.0
+            kl = torch.zeros_like(response_mask, dtype=torch.float32)
+            experiences.batch["token_level_rewards"] = token_level_scores
+
+        current_kl = masked_mean(kl, mask=response_mask, axis=-1).mean(dim=0).item()
+        self.update_kl_coef(current_kl=current_kl, batch_size=batch_size)
+
+        metrics = {
+            "kl": current_kl,
+            "kl_coef": kl_coef,
+        }
+
+        return experiences, metrics
+
+    def calculate_kl_loss(
+        self,
+        logprob: torch.Tensor,
+        ref_logprob: torch.Tensor,
+        response_mask: torch.Tensor,
+    ) -> Tuple[torch.Tensor, Dict]:
+        """Compute KL loss."""
+        kl = self.calculate_kl(logprob, ref_logprob)
+        kl_loss = masked_mean(kl, response_mask)
+        metrics = {
+            "kl_loss": kl_loss.detach().item(),
+            "kl_coef": self.kl_coef,
+        }
+        return kl_loss * self.kl_coef, metrics
+
+    @abstractmethod
+    def calculate_kl(self, logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
+        """Compute KL divergence between logprob and ref_logprob."""
+
+    @classmethod
+    def default_args(cls):
+        """Get the default initialization arguments."""
+        return {"adaptive": False, "kl_coef": 0.001}
+
+
+@KL_FN.register_module("none")
+class DummyFn(KLFn):
+    """
+    Dummy KL function that disables both the KL penalty and the KL loss.
+    """
+
+    def calculate_kl(self, logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
+        return torch.zeros_like(logprob)
+
+    def apply_kl_penalty_to_reward(self, experiences: Any) -> Tuple[Any, Dict]:
+        experiences.batch["token_level_rewards"] = experiences.batch["token_level_scores"]
+        return experiences, {}
+
+    def calculate_kl_loss(
+        self,
+        logprob: torch.Tensor,
+        ref_logprob: torch.Tensor,
+        response_mask: torch.Tensor,
+    ) -> Tuple[torch.Tensor, Dict]:
+        # return a zero tensor
+        return torch.tensor(0.0), {}
+
+
+@KL_FN.register_module("k1")
+class K1Fn(KLFn):
+    """
+    K1 KL estimator: logprob - ref_logprob.
+    """
+
+    def calculate_kl(self, logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
+        return logprob - ref_logprob
+
+
+@KL_FN.register_module("k2")
+class K2Fn(KLFn):
+    """
+    K2 KL estimator: 0.5 * (logprob - ref_logprob) ** 2.
+    """
+
+    def calculate_kl(self, logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
+        return (logprob - ref_logprob).square() * 0.5
+
+
+@KL_FN.register_module("k3")
+class K3Fn(KLFn):
+    """
+    K3 KL estimator: exp(-d) - 1 + d, where d = logprob - ref_logprob.
+    """
+
+    def calculate_kl(self, logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
+        logr = ref_logprob - logprob
+        return logr.exp() - 1 - logr
+
+
+@KL_FN.register_module("abs")
+class AbsFn(KLFn):
+    """
+    Absolute-value KL estimator: |logprob - ref_logprob|.
+    """
+
+    def calculate_kl(self, logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
+        return torch.abs(logprob - ref_logprob)
diff --git a/trinity/algorithm/kl_loss/__init__.py b/trinity/algorithm/kl_loss/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/trinity/algorithm/utils.py b/trinity/algorithm/utils.py
index d5cfb72d8c..01356cc066 100644
--- a/trinity/algorithm/utils.py
+++ b/trinity/algorithm/utils.py
@@ -12,3 +12,11 @@ def masked_sum(values, mask, axis=None):
 def masked_mean(values, mask, axis=None):
     """Compute mean of tensor with a masked values."""
     return (values * mask).sum(axis=axis) / (mask.sum(axis=axis) + 1e-8)
+
+
+def prefix_metrics(src_metrics: dict, prefix: str, dst_metrics: dict = None) -> dict:
+    if dst_metrics is None:
+        dst_metrics = {}
+    for k, v in src_metrics.items():
+        dst_metrics[f"{prefix}/{k}"] = v
+    return dst_metrics
diff --git a/trinity/common/config.py b/trinity/common/config.py
index 5d294abdfd..91c7790571 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -178,11 +178,24 @@ class AlgorithmConfig:
     # If not set, use PolicyLossFn.default_args()
     policy_loss_fn_args: Optional[dict] = None

-    advantage_fn_type: str = "ppo_adv_fn"
+    advantage_fn: str = "ppo"
     # If not set, use AdvantageFn.default_args()
     advantage_fn_args: Optional[dict] = None

-    # used for SFT
+    kl_penalty_fn: str = "none"  # set to "none" to disable kl penalty in reward
+    # If not set, use kl_penalty_fn.default_args()
+    kl_penalty_fn_args: Optional[dict] = None
+
+    kl_loss_fn: str = "k2"  # set to "none" to disable kl loss
+    # If not set, use kl_loss_fn.default_args()
+    kl_loss_fn_args: Optional[dict] = None
+
+    entropy_loss_fn: str = "basic"
+    # If not set, use entropy_loss_fn.default_args()
+    entropy_loss_fn_args: Optional[dict] = None
+
+    # used for SFT warmup
+    # TODO: move this to SFT warmup
     use_token_level_loss: bool = True


@@ -271,9 +284,6 @@ class TrainerConfig:
     enable_preview: bool = True  # enable rollout preview in wandb

     # trainer configs
-    actor_use_kl_loss: Optional[bool] = None
-    actor_kl_loss_coef: Optional[float] = None
-    actor_entropy_coef: Optional[float] = None
     actor_grad_clip: Optional[float] = None
     actor_clip_ratio: Optional[float] = None
     # TODO: extract more train-related params from underlying trainer engine
@@ -475,7 +485,12 @@ def _check_buffer(self) -> None:  # noqa: C901
         self.buffer.tokenizer_path = self.model.model_path

     def _check_algorithm(self) -> None:
-        from trinity.algorithm import ADVANTAGE_FN, POLICY_LOSS_FN
+        from trinity.algorithm import (
+            ADVANTAGE_FN,
+            ENTROPY_LOSS_FN,
+            KL_FN,
+            POLICY_LOSS_FN,
+        )

         policy_fn_cls = POLICY_LOSS_FN.get(self.algorithm.policy_loss_fn)
         if policy_fn_cls is None:
@@ -483,12 +498,30 @@ def _check_algorithm(self) -> None:
         if self.algorithm.policy_loss_fn_args is None:
             self.algorithm.policy_loss_fn_args = policy_fn_cls.default_args()

-        advantage_fn_cls = ADVANTAGE_FN.get(self.algorithm.advantage_fn_type)
+        advantage_fn_cls = ADVANTAGE_FN.get(self.algorithm.advantage_fn)
         if advantage_fn_cls is None:
-            raise ValueError(f"Invalid advantage_fn_type: {self.algorithm.advantage_fn_type}")
+            raise ValueError(f"Invalid advantage_fn: {self.algorithm.advantage_fn}")
         if self.algorithm.advantage_fn_args is None:
             self.algorithm.advantage_fn_args = advantage_fn_cls.default_args()

+        kl_loss_fn_cls = KL_FN.get(self.algorithm.kl_loss_fn)
+        if kl_loss_fn_cls is None:
+            raise ValueError(f"Invalid kl_loss_fn: {self.algorithm.kl_loss_fn}")
+        if self.algorithm.kl_loss_fn_args is None:
+            self.algorithm.kl_loss_fn_args = kl_loss_fn_cls.default_args()
+
+        kl_penalty_fn_cls = KL_FN.get(self.algorithm.kl_penalty_fn)
+        if kl_penalty_fn_cls is None:
+            raise ValueError(f"Invalid kl_penalty_fn: {self.algorithm.kl_penalty_fn}")
+        if self.algorithm.kl_penalty_fn_args is None:
+            self.algorithm.kl_penalty_fn_args = kl_penalty_fn_cls.default_args()
+
+        entropy_loss_fn_cls = ENTROPY_LOSS_FN.get(self.algorithm.entropy_loss_fn)
+        if entropy_loss_fn_cls is None:
+            raise ValueError(f"Invalid entropy_loss_fn: {self.algorithm.entropy_loss_fn}")
+        if self.algorithm.entropy_loss_fn_args is None:
+            self.algorithm.entropy_loss_fn_args = entropy_loss_fn_cls.default_args()
+
     def check_and_update(self) -> None:  # noqa: C901
         """Check and update the config."""
diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py
index fb9f810dee..e8180f4718 100644
--- a/trinity/common/verl_config.py
+++ b/trinity/common/verl_config.py
@@ -306,12 +306,6 @@ def synchronize_config(self, config: Config) -> None:  # noqa: C901
         self.critic.ppo_mini_batch_size = config.buffer.batch_size
         self.critic.rollout_n = self.actor_rollout_ref.rollout.n

-        if config.trainer.actor_use_kl_loss is not None:
-            self.actor_rollout_ref.actor.use_kl_loss = config.trainer.actor_use_kl_loss
-        if config.trainer.actor_kl_loss_coef is not None:
-            self.actor_rollout_ref.actor.kl_loss_coef = config.trainer.actor_kl_loss_coef
-        if config.trainer.actor_entropy_coef is not None:
-            self.actor_rollout_ref.actor.entropy_coeff = config.trainer.actor_entropy_coef
         if config.trainer.actor_grad_clip is not None:
             self.actor_rollout_ref.actor.grad_clip = config.trainer.actor_grad_clip
         if config.trainer.actor_clip_ratio is not None:
@@ -330,6 +324,11 @@ def synchronize_config(self, config: Config) -> None:  # noqa: C901
         elif config.algorithm.algorithm_type in (AlgorithmType.GRPO, AlgorithmType.OPMD):
             logger.info("Setting `adv_estimator` to 'grpo' for GRPO/OPMD")
             self.algorithm.adv_estimator = AdvantageEstimator.GRPO.value
+        self.actor_rollout_ref.actor.use_kl_loss = config.algorithm.kl_loss_fn != "none"
+        self.actor_rollout_ref.actor.kl_loss_coef = config.algorithm.kl_loss_fn_args["kl_coef"]  # type: ignore
+        self.actor_rollout_ref.actor.entropy_coeff = config.algorithm.entropy_loss_fn_args[  # type: ignore
+            "entropy_coef"
+        ]
         # TODO (yanxi): it seems that adv_estimator now only affects whether use_critic is set to
         # True or False in RayPPOTrainer.__init__() (and hence in VerlPPOTrainerWrapper).
         # Need to double check whether this is indeed the case,
diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py
index 7208f83fb4..c2ad5fec96 100644
--- a/trinity/trainer/trainer.py
+++ b/trinity/trainer/trainer.py
@@ -80,6 +80,10 @@ def train_step(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> Tuple[bool
                 policy_loss_fn_args={
                     "use_token_level_loss": self.config.algorithm.use_token_level_loss
                 },
+                kl_loss_fn="none",
+                kl_loss_fn_args={},
+                entropy_loss_fn="basic",
+                entropy_loss_fn_args=self.config.algorithm.entropy_loss_fn_args,
             )
             self.engine.set_algorithm(algorithm_config)
         else:
diff --git a/trinity/trainer/verl/dp_actor.py b/trinity/trainer/verl/dp_actor.py
index 97cd186c36..a7705fc6a0 100644
--- a/trinity/trainer/verl/dp_actor.py
+++ b/trinity/trainer/verl/dp_actor.py
@@ -26,14 +26,14 @@
 from verl import DataProto
 from verl.utils.py_functional import append_to_dict
 from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches
-from verl.utils.torch_functional import logprobs_from_logits, masked_mean
+from verl.utils.torch_functional import logprobs_from_logits
 from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs
 from verl.workers.actor import BasePPOActor

-from trinity.algorithm import POLICY_LOSS_FN
+from trinity.algorithm import ENTROPY_LOSS_FN, KL_FN, POLICY_LOSS_FN
+from trinity.algorithm.utils import prefix_metrics
 from trinity.common.config import AlgorithmConfig
 from trinity.common.constants import AlgorithmType
-from trinity.trainer.verl import core_algos


 __all__ = ["DataParallelPPOActor"]
@@ -63,6 +63,10 @@ def set_algorithm(self, algorithm_config: AlgorithmConfig):
         self.policy_loss_fn = POLICY_LOSS_FN.get(algorithm_config.policy_loss_fn)(
             **algorithm_config.policy_loss_fn_args
         )
+        self.kl_loss_fn = KL_FN.get(algorithm_config.kl_loss_fn)(**algorithm_config.kl_loss_fn_args)
+        self.entropy_loss_fn = ENTROPY_LOSS_FN.get(algorithm_config.entropy_loss_fn)(
+            **algorithm_config.entropy_loss_fn_args
+        )

     def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torch.Tensor, torch.Tensor]:
         """
@@ -347,6 +351,8 @@ def update_policy(self, data: DataProto):  # noqa: C901
                 self.actor_optimizer.zero_grad()

                 for data in micro_batches:
+                    micro_batch_metrics = {}
+
                     # Support all hardwares
                     if isinstance(data, DataProto):
                         data = {
@@ -362,7 +368,6 @@ def update_policy(self, data: DataProto):  # noqa: C901
                     attention_mask = data["attention_mask"]
                     response_mask = data["response_mask"]
                     assert response_mask.shape == attention_mask[:, -response_length:].shape
-                    entropy_coeff = self.config.entropy_coeff

                     # all return: (bsz, response_length)
                     entropy, log_prob = self._forward_micro_batch(
@@ -374,30 +379,37 @@ def update_policy(self, data: DataProto):  # noqa: C901
                         for verl_key, value in data.items()
                         if verl_key in select_keys_verl2trinity
                     }
-                    pg_loss, metric = self.policy_loss_fn(  # type: ignore
+                    pg_loss, pg_loss_metrics = self.policy_loss_fn(  # type: ignore
                         logprob=log_prob,
                         **kwargs,
                     )
+                    prefix_metrics(
+                        src_metrics=pg_loss_metrics, prefix="actor", dst_metrics=micro_batch_metrics
+                    )

                     # compute entropy loss from entropy
-                    entropy_loss = verl_F.masked_mean(entropy, response_mask)
+                    entropy_loss, entropy_loss_metrics = self.entropy_loss_fn(
+                        entropy=entropy,
+                        action_mask=response_mask,
+                    )
+                    prefix_metrics(
+                        src_metrics=entropy_loss_metrics,
+                        prefix="actor",
+                        dst_metrics=micro_batch_metrics,
+                    )

                     # compute policy loss
-                    policy_loss = pg_loss - entropy_loss * entropy_coeff
-
-                    if self.config.use_kl_loss:
-                        ref_log_prob = data["ref_log_prob"]
-                        # compute kl loss
-                        kld = core_algos.kl_penalty(
-                            logprob=log_prob,
-                            ref_logprob=ref_log_prob,
-                            kl_penalty=self.config.kl_loss_type,
-                        )
-                        kl_loss = masked_mean(kld, response_mask)
+                    policy_loss = pg_loss - entropy_loss

-                        policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef
-                        metrics["actor/kl_loss"] = kl_loss.detach().item()
-                        metrics["actor/kl_coef"] = self.config.kl_loss_coef
+                    kl_loss, kl_loss_metrics = self.kl_loss_fn.calculate_kl_loss(
+                        logprob=log_prob,
+                        ref_logprob=data["ref_log_prob"],
+                        response_mask=response_mask,
+                    )
+                    prefix_metrics(
+                        src_metrics=kl_loss_metrics, prefix="actor", dst_metrics=micro_batch_metrics
+                    )
+                    policy_loss = policy_loss + kl_loss

                     if self.config.use_dynamic_bsz:
                         # relative to the dynamic bsz
@@ -406,13 +418,10 @@ def update_policy(self, data: DataProto):  # noqa: C901
                         loss = policy_loss / self.gradient_accumulation
                     loss.backward()

-                    data = {f"actor/{key}": value for key, value in metric.items()}
-                    # TODO: refactor entropy loss
-                    data["actor/entropy_loss"] = entropy_loss.detach().item()
-                    append_to_dict(metrics, data)
+                    append_to_dict(metrics, micro_batch_metrics)

             grad_norm = self._optimizer_step()
             data = {"actor/grad_norm": grad_norm.detach().item()}
-            append_to_dict(metrics, data)
+            append_to_dict(metrics, data)
         self.actor_optimizer.zero_grad()
         return metrics
diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py
index ca02b6c288..83e3480dc3 100644
--- a/trinity/trainer/verl_trainer.py
+++ b/trinity/trainer/verl_trainer.py
@@ -22,7 +22,8 @@
 from verl.utils import hf_tokenizer
 from verl.utils.fs import copy_local_path_from_hdfs

-from trinity.algorithm import ADVANTAGE_FN
+from trinity.algorithm import ADVANTAGE_FN, KL_FN
+from trinity.algorithm.utils import prefix_metrics
 from trinity.common.config import AlgorithmConfig, Config
 from trinity.common.constants import AlgorithmType
 from trinity.common.experience import Experiences
@@ -34,7 +35,6 @@
     ResourcePoolManager,
     Role,
     _timer,
-    apply_kl_penalty,
     find_latest_ckpt_path,
 )
 from trinity.utils.monitor import Monitor
@@ -133,9 +133,10 @@ def __init__(
         # specify advantage function for various rft algorithms
         algo_config = global_config.algorithm
         if algo_config.algorithm_type.is_rft():
-            adv_fn_type = algo_config.advantage_fn_type
-            adv_fn_args = algo_config.advantage_fn_args
-            self.advantage_fn = ADVANTAGE_FN.get(adv_fn_type)(**adv_fn_args)
+            self.advantage_fn = ADVANTAGE_FN.get(algo_config.advantage_fn)(
+                **algo_config.advantage_fn_args
+            )
+            self.kl_fn = KL_FN.get(algo_config.kl_penalty_fn)(**algo_config.kl_penalty_fn_args)

         self.logger = Monitor(
             project=config.trainer.project_name,
@@ -373,17 +374,9 @@ def train_rft_step(self, experiences: Experiences) -> Tuple[bool, int]:
                     batch = batch.union(values)

                 with _timer("adv", timing_raw):
-                    # compute rewards. apply_kl_penalty if available
-                    if not self.config.actor_rollout_ref.actor.get("use_kl_loss", False):
-                        batch, kl_metrics = apply_kl_penalty(
-                            batch,
-                            kl_ctrl=self.kl_ctrl,
-                            kl_penalty=self.config.algorithm.kl_penalty,
-                        )
-                        metrics.update(kl_metrics)
-                    else:
-                        batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]
-
+                    # compute kl penalty
+                    batch, kl_metrics = self.kl_fn.apply_kl_penalty_to_reward(batch)
+                    metrics.update(prefix_metrics(kl_metrics, prefix="critic"))
                     # compute advantages, executed on the driver process
                     batch, _ = self.advantage_fn(batch)
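
The `k1`/`k2`/`k3`/`abs` options registered in `trinity/algorithm/kl_fn/kl_fn.py` are the usual sample-based KL approximations. The snippet below is a reviewer-side sanity check, not part of the patch: it only needs `torch`, and the toy categorical distributions and sample count are made up for illustration. It evaluates each `calculate_kl` formula on samples drawn from the "policy" and compares the batch means against the exact KL(p || q).

```python
# Sanity check for the KL estimators added in kl_fn.py (illustrative only).
import torch

torch.manual_seed(0)
p = torch.distributions.Categorical(logits=torch.randn(32))  # stands in for the policy
q = torch.distributions.Categorical(logits=torch.randn(32))  # stands in for the reference model
exact_kl = torch.distributions.kl_divergence(p, q)

actions = p.sample((200_000,))     # samples come from the policy, like old_log_probs
logprob = p.log_prob(actions)      # plays the role of `logprob`
ref_logprob = q.log_prob(actions)  # plays the role of `ref_logprob`

k1 = logprob - ref_logprob                     # K1Fn.calculate_kl
k2 = 0.5 * (logprob - ref_logprob).square()    # K2Fn.calculate_kl
logr = ref_logprob - logprob
k3 = logr.exp() - 1 - logr                     # K3Fn.calculate_kl
k_abs = (logprob - ref_logprob).abs()          # AbsFn.calculate_kl

print(f"exact KL(p||q) = {exact_kl:.4f}")
for name, est in [("k1", k1), ("k2", k2), ("k3", k3), ("abs", k_abs)]:
    # k1 and k3 match the exact KL in expectation; k3 is also non-negative
    # per sample, while k2 and abs are biased approximations.
    print(f"{name}: mean = {est.mean():.4f}, std = {est.std():.4f}")
```

For reference, the test template above selects `k3` for the in-reward penalty (`kl_penalty_fn: k3`) and `k2` for the KL loss (`kl_loss_fn: k2`), while the DPO example uses `k1` with `kl_coef` acting as the DPO beta.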
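
A minimal usage sketch of the new registry plumbing follows, assuming this branch of `trinity` is importable; the option values are copied from the DPO YAML example above and everything else is illustrative. It mirrors how `Config._check_algorithm` resolves the string options with `Registry.get`, falls back to `default_args()` when no args are given, and how the resulting objects are instantiated.

```python
# Illustrative resolution of the new algorithm.* options (not part of the patch).
from trinity.algorithm import ENTROPY_LOSS_FN, KL_FN

# Values as they would appear under `algorithm:` in the DPO example above.
kl_loss_fn, kl_loss_fn_args = "k1", {"kl_coef": 0.1}  # kl_coef acts as the DPO beta
kl_penalty_fn, kl_penalty_fn_args = "none", None      # no in-reward KL penalty

kl_loss_cls = KL_FN.get(kl_loss_fn)
if kl_loss_cls is None:
    raise ValueError(f"Invalid kl_loss_fn: {kl_loss_fn}")
if kl_loss_fn_args is None:
    kl_loss_fn_args = kl_loss_cls.default_args()
kl_loss = kl_loss_cls(**kl_loss_fn_args)

kl_penalty_cls = KL_FN.get(kl_penalty_fn)
if kl_penalty_fn_args is None:
    kl_penalty_fn_args = kl_penalty_cls.default_args()  # {"adaptive": False, "kl_coef": 0.001}
kl_penalty = kl_penalty_cls(**kl_penalty_fn_args)

entropy_cls = ENTROPY_LOSS_FN.get("basic")
entropy_loss = entropy_cls(**entropy_cls.default_args())  # entropy_coef defaults to 0.0
```

The verl wrapper then derives `actor.use_kl_loss` from `kl_loss_fn != "none"` and reads `kl_loss_coef`/`entropy_coeff` from the resolved argument dicts, which is why the old `trainer.actor_use_kl_loss`, `actor_kl_loss_coef`, and `actor_entropy_coef` overrides are removed in this patch.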