diff --git a/tests/template/config.yaml b/tests/template/config.yaml
index 09b6f9ca0d..a83a82655f 100644
--- a/tests/template/config.yaml
+++ b/tests/template/config.yaml
@@ -5,6 +5,9 @@ checkpoint_root_dir: ''
 algorithm:
   algorithm_type: ppo
   repeat_times: 1
+  policy_loss_fn: ppo
+  policy_loss_fn_args:
+    clip_range: 0.2
 model:
   model_path: ''
   max_prompt_tokens: 2048
diff --git a/trinity/algorithm/__init__.py b/trinity/algorithm/__init__.py
index f65ec67b47..51d3da8317 100644
--- a/trinity/algorithm/__init__.py
+++ b/trinity/algorithm/__init__.py
@@ -1,5 +1,5 @@
 from trinity.algorithm.advantage_fn.advantage_fn import ADVANTAGE_FN, AdvantageFn
-from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
+from trinity.algorithm.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn

 __all__ = [
     "AdvantageFn",
diff --git a/trinity/algorithm/policy_loss_fn/__init__.py b/trinity/algorithm/policy_loss_fn/__init__.py
index e69de29bb2..66dce16cab 100644
--- a/trinity/algorithm/policy_loss_fn/__init__.py
+++ b/trinity/algorithm/policy_loss_fn/__init__.py
@@ -0,0 +1,14 @@
+from trinity.algorithm.policy_loss_fn.dpo_loss import DPOLossFn
+from trinity.algorithm.policy_loss_fn.opmd_policy_loss import OPMDPolicyLossFn
+from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
+from trinity.algorithm.policy_loss_fn.ppo_policy_loss import PPOPolicyLossFn
+from trinity.algorithm.policy_loss_fn.sft_loss import SFTLossFn
+
+__all__ = [
+    "POLICY_LOSS_FN",
+    "PolicyLossFn",
+    "PPOPolicyLossFn",
+    "OPMDPolicyLossFn",
+    "DPOLossFn",
+    "SFTLossFn",
+]
diff --git a/trinity/algorithm/policy_loss_fn/dpo_loss.py b/trinity/algorithm/policy_loss_fn/dpo_loss.py
new file mode 100644
index 0000000000..3a9ea92f5c
--- /dev/null
+++ b/trinity/algorithm/policy_loss_fn/dpo_loss.py
@@ -0,0 +1,67 @@
+"""DPO loss function."""
+
+from typing import Any, Dict, Tuple
+
+import torch
+import torch.nn.functional as F
+
+from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
+from trinity.algorithm.utils import masked_sum
+
+
+@POLICY_LOSS_FN.register_module("dpo")
+class DPOLossFn(PolicyLossFn):
+    def __init__(
+        self,
+        beta: float = 0.1,
+        label_smoothing: float = 0.0,
+    ) -> None:
+        self.beta = beta
+        self.label_smoothing = label_smoothing
+
+    def __call__(
+        self,
+        logprob: torch.Tensor,
+        old_logprob: torch.Tensor,
+        action_mask: torch.Tensor,
+        advantages: torch.Tensor,
+        experiences: Any,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Dict]:
+        chosen_logprob = logprob[::2]
+        rejected_logprob = logprob[1::2]
+        chosen_mask = action_mask[::2]
+        rejected_mask = action_mask[1::2]
+        chosen_logprob_sum = masked_sum(chosen_logprob, chosen_mask)
+        rejected_logprob_sum = masked_sum(rejected_logprob, rejected_mask)
+
+        chosen_ref_logprob = old_logprob[::2]
+        rejected_ref_logprob = old_logprob[1::2]
+        chosen_ref_logprob_sum = masked_sum(chosen_ref_logprob, chosen_mask)
+        rejected_ref_logprob_sum = masked_sum(rejected_ref_logprob, rejected_mask)
+
+        chosen_ratios = chosen_logprob_sum - chosen_ref_logprob_sum
+        rejected_ratios = rejected_logprob_sum - rejected_ref_logprob_sum
+        logits = chosen_ratios - rejected_ratios
+        # TODO: support other loss functions
+        losses = (
+            -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
+            - F.logsigmoid(-self.beta * logits) * self.label_smoothing
+        )
+        loss = losses.mean()
+        chosen_reward = self.beta * chosen_ratios.detach().mean().item()
+        rejected_reward = self.beta * rejected_ratios.detach().mean().item()
+        accuracy_mean = (chosen_ratios.detach() > rejected_ratios.detach()).float().mean().item()
+        return loss, {
+            "chosen_reward": chosen_reward,
+            "rejected_reward": rejected_reward,
+            "accuracy_mean": accuracy_mean,
+            "dpo_loss": loss.detach().item(),
+        }
+
+    @classmethod
+    def default_args(cls) -> Dict:
+        return {
+            "beta": 0.1,
+            "label_smoothing": 0.0,
+        }
diff --git a/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py b/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py
new file mode 100644
index 0000000000..dd521f9ee0
--- /dev/null
+++ b/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py
@@ -0,0 +1,35 @@
+"""OPMD policy loss function.
+
+Modified from https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py
+"""
+
+from typing import Any, Dict, Tuple
+
+import torch
+
+from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
+from trinity.algorithm.utils import masked_mean
+
+
+@POLICY_LOSS_FN.register_module("opmd")
+class OPMDPolicyLossFn(PolicyLossFn):
+    def __init__(self, tau: float = 1.0) -> None:
+        self.tau = tau
+
+    def __call__(
+        self,
+        logprob: torch.Tensor,
+        old_logprob: torch.Tensor,
+        action_mask: torch.Tensor,
+        advantages: torch.Tensor,
+        experiences: Any,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Dict]:
+        pg_losses = -advantages * logprob
+        opmd_loss = masked_mean(pg_losses, action_mask)
+        opmd_loss = opmd_loss / (1.0 + self.tau)  # for regularization (w.r.t. current pi_theta)
+        return opmd_loss, {"opmd_loss": opmd_loss.detach().item()}
+
+    @classmethod
+    def default_args(cls) -> Dict:
+        return {"tau": 1.0}
diff --git a/trinity/algorithm/policy_loss_fn/policy_loss_fn.py b/trinity/algorithm/policy_loss_fn/policy_loss_fn.py
index 392f80e521..eb02c49b46 100644
--- a/trinity/algorithm/policy_loss_fn/policy_loss_fn.py
+++ b/trinity/algorithm/policy_loss_fn/policy_loss_fn.py
@@ -36,3 +36,11 @@ def __call__(
             `torch.Tensor`: Policy loss
             `Dict`: The metrics for logging.
         """
+
+    @classmethod
+    @abstractmethod
+    def default_args(cls) -> Dict:
+        """
+        Returns:
+            `Dict`: The default init arguments for the policy loss function.
+        """
diff --git a/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py b/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py
new file mode 100644
index 0000000000..9831f048d6
--- /dev/null
+++ b/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py
@@ -0,0 +1,64 @@
+"""PPO policy loss function.
+
+Modified from https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py
+"""
+
+from typing import Any, Dict, Optional, Tuple
+
+import torch
+
+from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
+from trinity.algorithm.utils import masked_mean
+
+
+@POLICY_LOSS_FN.register_module("ppo")
+class PPOPolicyLossFn(PolicyLossFn):
+    def __init__(
+        self,
+        clip_range: Optional[float] = None,
+        clip_range_low: Optional[float] = None,
+        clip_range_high: Optional[float] = None,
+    ) -> None:
+        if clip_range_low is None:
+            self.clip_range_low = clip_range
+        else:
+            self.clip_range_low = clip_range_low
+        if clip_range_high is None:
+            self.clip_range_high = clip_range
+        else:
+            self.clip_range_high = clip_range_high
+        assert self.clip_range_low is not None, "clip_range_low must be specified."
+        assert self.clip_range_high is not None, "clip_range_high must be specified."
+
+    def __call__(
+        self,
+        logprob: torch.Tensor,
+        old_logprob: torch.Tensor,
+        action_mask: torch.Tensor,
+        advantages: torch.Tensor,
+        experiences: Any,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Dict]:
+        negative_approx_kl = logprob - old_logprob
+        ratio = torch.exp(negative_approx_kl)
+        ppo_kl = masked_mean(-negative_approx_kl, action_mask)
+
+        pg_losses = -advantages * ratio
+        pg_losses2 = -advantages * torch.clamp(
+            ratio, 1.0 - self.clip_range_low, 1.0 + self.clip_range_high  # type: ignore
+        )
+
+        pg_loss = masked_mean(torch.max(pg_losses, pg_losses2), action_mask)
+        pg_clipfrac = masked_mean(torch.gt(pg_losses2, pg_losses).float(), action_mask)
+        metrics = {
+            "pg_clipfrac": pg_clipfrac.detach().item(),
+            "ppo_kl": ppo_kl.detach().item(),
+            "pg_loss": pg_loss.detach().item(),
+        }
+        return pg_loss, metrics
+
+    @classmethod
+    def default_args(cls) -> Dict:
+        return {
+            "clip_range": 0.2,
+        }
diff --git a/trinity/algorithm/policy_loss_fn/sft_loss.py b/trinity/algorithm/policy_loss_fn/sft_loss.py
new file mode 100644
index 0000000000..c04f775fa3
--- /dev/null
+++ b/trinity/algorithm/policy_loss_fn/sft_loss.py
@@ -0,0 +1,35 @@
+"""SFT loss function."""
+
+from typing import Any, Dict, Tuple
+
+import torch
+
+from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
+from trinity.algorithm.utils import masked_mean
+
+
+@POLICY_LOSS_FN.register_module("sft")
+class SFTLossFn(PolicyLossFn):
+    def __init__(self, use_token_level_loss: bool = True) -> None:
+        self.use_token_level_loss = use_token_level_loss
+
+    def __call__(
+        self,
+        logprob: torch.Tensor,
+        old_logprob: torch.Tensor,
+        action_mask: torch.Tensor,
+        advantages: torch.Tensor,
+        experiences: Any,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Dict]:
+        if self.use_token_level_loss:
+            sft_loss = masked_mean(-logprob, action_mask)
+        else:
+            sft_loss = masked_mean(-logprob, action_mask, axis=1).mean()
+        return sft_loss, {"sft_loss": sft_loss.detach().item()}
+
+    @classmethod
+    def default_args(cls):
+        return {
+            "use_token_level_loss": True,
+        }
diff --git a/trinity/algorithm/utils.py b/trinity/algorithm/utils.py
new file mode 100644
index 0000000000..d5cfb72d8c
--- /dev/null
+++ b/trinity/algorithm/utils.py
@@ -0,0 +1,14 @@
+"""Common utils for algorithm module.
+
+Modified from https://github.com/volcengine/verl/blob/main/verl/utils/torch_functional.py
+"""
+
+
+def masked_sum(values, mask, axis=None):
+    """Compute sum of tensor with masked values."""
+    return (values * mask).sum(axis=axis)
+
+
+def masked_mean(values, mask, axis=None):
+    """Compute mean of tensor with masked values."""
+    return (values * mask).sum(axis=axis) / (mask.sum(axis=axis) + 1e-8)
diff --git a/trinity/common/config.py b/trinity/common/config.py
index e0660ab03a..9c3b582618 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -175,7 +175,10 @@ class AlgorithmConfig:
     repeat_times: int = 1
     gamma: Optional[float] = None
     lam: Optional[float] = None
-    # TODO: add more algorithm params here
+
+    policy_loss_fn: str = "ppo"
+    # If not set, use PolicyLossFn.default_args()
+    policy_loss_fn_args: Optional[dict] = None


 @dataclass
@@ -466,6 +469,15 @@ def _check_buffer(self) -> None:  # noqa: C901
             self.buffer.pad_token_id = 0
         self.buffer.tokenizer_path = self.model.model_path

+    def _check_algorithm(self) -> None:
+        from trinity.algorithm import POLICY_LOSS_FN
+
+        policy_fn_cls = POLICY_LOSS_FN.get(self.algorithm.policy_loss_fn)
+        if policy_fn_cls is None:
+            raise ValueError(f"Invalid policy_loss_fn: {self.algorithm.policy_loss_fn}")
+        if self.algorithm.policy_loss_fn_args is None:
+            self.algorithm.policy_loss_fn_args = policy_fn_cls.default_args()
+
     def check_and_update(self) -> None:  # noqa: C901
         """Check and update the config."""
         self._check_deprecated()
@@ -489,6 +501,9 @@ def check_and_update(self) -> None:  # noqa: C901
         if not self.model.critic_model_path:
             self.model.critic_model_path = self.model.model_path

+        # check algorithm
+        self._check_algorithm()
+
         # check explorer
         if (
             self.explorer.rollout_model.engine_type != "vllm_async"
diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py
index e5d0d9d55f..a4d9a6e8d9 100644
--- a/trinity/common/verl_config.py
+++ b/trinity/common/verl_config.py
@@ -190,6 +190,10 @@ class Algorithm:
     kl_penalty: str = "kl"
     kl_ctrl: KL_Ctrl = field(default_factory=KL_Ctrl)

+    # ! DO NOT SET THE FOLLOWING PARAMETERS
+    policy_loss_fn: str = "ppo"
+    policy_loss_fn_args: Optional[dict] = None
+

 @dataclass
 class Trainer:
diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py
index 36d23e7628..876ca2835f 100644
--- a/trinity/trainer/trainer.py
+++ b/trinity/trainer/trainer.py
@@ -13,7 +13,7 @@
 import ray

 from trinity.buffer import get_buffer_reader
-from trinity.common.config import Config
+from trinity.common.config import AlgorithmConfig, Config
 from trinity.common.constants import AlgorithmType, SyncMethod
 from trinity.common.experience import Experiences
 from trinity.utils.log import get_logger
@@ -73,7 +73,7 @@ def train_step(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> Tuple[bool
         Returns:
             bool: Whether to continue training.
""" - self.engine.set_mode(algo_type) + self.engine.set_algorithm(self.config.algorithm) if algo_type.is_rft() and self.config.buffer.trainer_input.read_experience_strategy: strategy = self.config.buffer.trainer_input.read_experience_strategy else: @@ -157,8 +157,8 @@ def sync_weight(self) -> None: """Sync the model weight.""" @abstractmethod - def set_mode(self, algo_type: AlgorithmType) -> None: - """Set training mode.""" + def set_algorithm(self, algorithm_config: AlgorithmConfig) -> None: + """Set training algorithm config.""" @abstractmethod def shutdown(self) -> None: diff --git a/trinity/trainer/verl/dp_actor.py b/trinity/trainer/verl/dp_actor.py index 246cd1f21c..b598bb6dad 100644 --- a/trinity/trainer/verl/dp_actor.py +++ b/trinity/trainer/verl/dp_actor.py @@ -30,6 +30,8 @@ from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs from verl.workers.actor import BasePPOActor +from trinity.algorithm import POLICY_LOSS_FN +from trinity.common.config import AlgorithmConfig from trinity.common.constants import AlgorithmType from trinity.trainer.verl import core_algos @@ -54,9 +56,13 @@ def __init__( self.compute_entropy_from_logits = torch.compile(verl_F.entropy_from_logits, dynamic=True) self.algorithm_type = AlgorithmType.PPO + self.policy_loss_fn = None - def set_mode(self, algorithm_type: AlgorithmType = AlgorithmType.PPO): - self.algorithm_type = algorithm_type + def set_algorithm(self, algorithm_config: AlgorithmConfig): + self.algorithm_type = algorithm_config.algorithm_type + self.policy_loss_fn = POLICY_LOSS_FN.get(algorithm_config.policy_loss_fn)( + **algorithm_config.policy_loss_fn_args + ) def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torch.Tensor, torch.Tensor]: """ @@ -129,27 +135,6 @@ def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torch.Tensor, use_cache=False, ) # prevent model thinks we are generating logits_rmpad = output.logits.squeeze(0) # (total_nnz, vocab_size) - if self.algorithm_type.is_sft(): # SFT - loss_fct = nn.CrossEntropyLoss(reduction="none") - loss = loss_fct(logits_rmpad, input_ids_rmpad_rolled) - if self.use_ulysses_sp: - loss = gather_outpus_and_unpad( - loss, gather_dim=0, unpad_dim=0, padding_size=pad_size - ) - response_mask = attention_mask[:, -response_length:].bool() - # pad back to (bsz, seqlen) - full_loss = pad_input( - hidden_states=loss.unsqueeze(-1), - indices=indices, - batch=batch_size, - seqlen=seqlen, - ).squeeze(-1) - full_loss = torch.where( - response_mask, full_loss[:, -response_length - 1 : -1], 0.0 - ) - full_loss = full_loss.sum(-1) / response_mask.sum(-1) - full_loss = full_loss.mean() - return full_loss logits_rmpad.div_(temperature) @@ -201,21 +186,6 @@ def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torch.Tensor, use_cache=False, ) # prevent model thinks we are generating logits = output.logits - if self.algorithm_type.is_sft(): - loss_fct = nn.CrossEntropyLoss(reduction="none", ignore_index=-100) - response_mask = attention_mask[:, -response_length:].bool() - response_labels = torch.where( - response_mask, input_ids[:, -response_length:], -100 - ) - response_logits = logits[:, -response_length - 1 : -1, :] - loss = loss_fct( - response_logits.reshape(-1, response_logits.shape[-1]), - response_labels.reshape(-1), - ) - loss = loss.view(response_labels.shape) - loss = loss.sum(-1) / response_mask.sum(-1) - loss = loss.mean() - return loss logits.div_(temperature) logits = logits[ :, -response_length - 1 : -1, : @@ -308,57 +278,25 @@ def 
         temperature = data.meta_info[
             "temperature"
         ]  # temperature must be in the data.meta_info to avoid slient error
-
-        algorithm_type: AlgorithmType = self.config.get("algorithm_type", AlgorithmType.PPO)
-        if self.algorithm_type.is_rft():
-            select_keys = [
-                "responses",
-                "input_ids",
-                "attention_mask",
-                "position_ids",
-                "old_log_probs",
-                "advantages",
-                "response_mask",
-            ]
-            if self.config.use_kl_loss:
-                select_keys.append("ref_log_prob")
-
-            if algorithm_type == AlgorithmType.PAIRWISE_OPMD:
-                select_keys.append("token_level_scores")
-        elif self.algorithm_type.is_dpo():
-            select_keys = [
-                "attention_mask",
-                "input_ids",
-                "position_ids",
-                "response_mask",
-                "responses",
-                "ref_log_prob",
-            ]
-        else:  # sft
-            select_keys = [
-                "attention_mask",
-                "input_ids",
-                "position_ids",
-                "response_mask",
-                "responses",
-            ]
-        use_uid = self.config.get("use_uid", False)
-
+        select_keys = [
+            "responses",
+            "input_ids",
+            "attention_mask",
+            "position_ids",
+            "old_log_probs",
+            "advantages",
+            "response_mask",
+        ]
+        if self.config.use_kl_loss:
+            select_keys.append("ref_log_prob")
         batch = data.select(batch_keys=select_keys).batch
         has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys()

         # Split to make minibatch iterator for updating the actor
         # See PPO paper for details. https://arxiv.org/abs/1707.06347
-        if has_multi_modal_inputs or ((algorithm_type == AlgorithmType.PAIRWISE_OPMD) and use_uid):
-            # TODO: for now, we treat algorithm_type == AlgorithmType.PAIRWISE_OPMD in the same way that
-            # has_multi_modal_inputs was treated originally (to handle non_tensor_select_keys);
-            # need to double check if this is the best approach.
+        if has_multi_modal_inputs:
             num_mini_batches = data.batch.batch_size[0] // self.config.ppo_mini_batch_size
-            non_tensor_select_keys = []
-            if has_multi_modal_inputs:
-                non_tensor_select_keys.append("multi_modal_inputs")
-            if (algorithm_type == AlgorithmType.PAIRWISE_OPMD) and use_uid:
-                non_tensor_select_keys.append("uid")
+            non_tensor_select_keys = ["multi_modal_inputs"]
             dataloader = data.select(select_keys, non_tensor_select_keys).chunk(num_mini_batches)
         else:
             dataloader = batch.split(self.config.ppo_mini_batch_size)
@@ -373,9 +311,7 @@ def update_policy(self, data: DataProto):  # noqa: C901
         for batch_idx, data in enumerate(dataloader):
             # split batch into micro_batches
             mini_batch = data
-            if has_multi_modal_inputs or (
-                (algorithm_type == AlgorithmType.PAIRWISE_OPMD) and use_uid
-            ):
+            if has_multi_modal_inputs:
                 self.gradient_accumulation = (
                     self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu
                 )
@@ -412,93 +348,48 @@ def update_policy(self, data: DataProto):  # noqa: C901
                 data = data.to(
                     torch.cuda.current_device()
                 )  # actor device is cpu when using offload
+                responses = data["responses"]
+                response_length = responses.size(1)
+                attention_mask = data["attention_mask"]
+                # response_mask = attention_mask[:, -response_length:]
+                response_mask = data["response_mask"]
+                assert response_mask.shape == attention_mask[:, -response_length:].shape
+                old_log_prob = data["old_log_probs"]
+                advantages = data["advantages"]
+                entropy_coeff = self.config.entropy_coeff
+
+                # all return: (bsz, response_length)
+                entropy, log_prob = self._forward_micro_batch(
+                    micro_batch=data, temperature=temperature
+                )
-
-                # TODO: it is better to unify the returns of several modes (sft, dpo)
-                if self.algorithm_type.is_sft():
-                    policy_loss = self._forward_micro_batch(
-                        micro_batch=data, temperature=temperature
-                    )
+                pg_loss, metric = self.policy_loss_fn(  # type: ignore
+                    logprob=log_prob,
+                    old_logprob=old_log_prob,
+                    action_mask=response_mask,
+                    advantages=advantages,
+                    experiences=data,
+                )
-                elif self.algorithm_type.is_dpo():
-                    response_mask = data["response_mask"]
+                # compute entropy loss from entropy
+                entropy_loss = verl_F.masked_mean(entropy, response_mask)
-                    _, log_prob = self._forward_micro_batch(
-                        micro_batch=data, temperature=temperature
-                    )
-                    if self.config.use_kl_loss:
-                        ref_log_prob = data["ref_log_prob"]
-                    else:
-                        ref_log_prob = None
-
-                    (
-                        policy_loss,
-                        chosen_reward,
-                        rejected_reward,
-                    ) = core_algos.compute_policy_loss_dpo(
-                        log_prob=log_prob,
-                        ref_log_prob=ref_log_prob,
-                        eos_mask=response_mask,
-                        beta=self.config.kl_loss_coef,
-                        # label_smoothing=self.config.label_smoothing  # TODO: add configs for dpo
-                    )
+                # compute policy loss
+                policy_loss = pg_loss - entropy_loss * entropy_coeff
-                else:  # rft
-                    responses = data["responses"]
-                    response_length = responses.size(1)
-                    attention_mask = data["attention_mask"]
-                    # response_mask = attention_mask[:, -response_length:]
-                    response_mask = data["response_mask"]
-                    assert response_mask.shape == attention_mask[:, -response_length:].shape
-                    old_log_prob = data["old_log_probs"]
-                    advantages = data["advantages"]
-
-                    clip_ratio = self.config.clip_ratio
-                    entropy_coeff = self.config.entropy_coeff
-
-                    tau = self.config.get("tau", 1.0)
-                    token_level_scores = None
-                    index = None
-                    if algorithm_type == AlgorithmType.PAIRWISE_OPMD:
-                        token_level_scores = data["token_level_scores"]
-                        if use_uid:
-                            index = data["uid"]
-
-                    # all return: (bsz, response_length)
-                    entropy, log_prob = self._forward_micro_batch(
-                        micro_batch=data, temperature=temperature
-                    )
+                if self.config.use_kl_loss:
+                    ref_log_prob = data["ref_log_prob"]
+                    # compute kl loss
+                    kld = core_algos.kl_penalty(
+                        logprob=log_prob,
+                        ref_logprob=ref_log_prob,
+                        kl_penalty=self.config.kl_loss_type,
+                    )
+                    kl_loss = masked_mean(kld, response_mask)
-                    pg_loss, pg_clipfrac, ppo_kl = core_algos.compute_policy_loss(
-                        old_log_prob=old_log_prob,
-                        log_prob=log_prob,
-                        eos_mask=response_mask,
-                        algorithm_type=algorithm_type,
-                        advantages=advantages,
-                        cliprange=clip_ratio,
-                        # for opmd / pairwise_opmd
-                        tau=tau,
-                        token_level_scores=token_level_scores,
-                        index=index,
-                    )
-                    # compute entropy loss from entropy
-                    entropy_loss = verl_F.masked_mean(entropy, response_mask)
-
-                    # compute policy loss
-                    policy_loss = pg_loss - entropy_loss * entropy_coeff
-
-                    if self.config.use_kl_loss:
-                        ref_log_prob = data["ref_log_prob"]
-                        # compute kl loss
-                        kld = core_algos.kl_penalty(
-                            logprob=log_prob,
-                            ref_logprob=ref_log_prob,
-                            kl_penalty=self.config.kl_loss_type,
-                        )
-                        kl_loss = masked_mean(kld, response_mask)
-
-                        policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef
-                        metrics["actor/kl_loss"] = kl_loss.detach().item()
-                        metrics["actor/kl_coef"] = self.config.kl_loss_coef
+                    policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef
+                    metrics["actor/kl_loss"] = kl_loss.detach().item()
+                    metrics["actor/kl_coef"] = self.config.kl_loss_coef

                 if self.config.use_dynamic_bsz:
                     # relative to the dynamic bsz
@@ -507,28 +398,9 @@ def update_policy(self, data: DataProto):  # noqa: C901
                     loss = policy_loss / self.gradient_accumulation
                 loss.backward()

-                if self.algorithm_type.is_rft():
-                    data = {
-                        "actor/entropy_loss": entropy_loss.detach().item(),
-                        "actor/pg_loss": pg_loss.detach().item(),
-                        "actor/pg_clipfrac": pg_clipfrac.detach().item(),
-                        "actor/ppo_kl": ppo_kl.detach().item(),
-                    }
-                elif self.algorithm_type.is_dpo():
-                    data = {
-                        "dpo/loss": policy_loss.detach().item(),
-                        "dpo/loss_mean": loss.detach().item(),
-                        "dpo/chosen_reward": chosen_reward.detach().mean().item(),
-                        "dpo/rejected_reward": rejected_reward.detach().mean().item(),
-                        "dpo/accuracy_mean": (chosen_reward > rejected_reward)
-                        .float()
-                        .mean()
-                        .item(),
-                    }
-                else:
-                    data = {
-                        "sft/loss": loss.detach().item(),
-                    }
+                data = {f"actor/{key}": value for key, value in metric.items()}
+                # TODO: refactor entropy loss
+                data["actor/entropy_loss"] = entropy_loss.detach().item()
                 append_to_dict(metrics, data)

             grad_norm = self._optimizer_step()
diff --git a/trinity/trainer/verl/fsdp_workers.py b/trinity/trainer/verl/fsdp_workers.py
index 26b640e871..c0af427b4a 100644
--- a/trinity/trainer/verl/fsdp_workers.py
+++ b/trinity/trainer/verl/fsdp_workers.py
@@ -50,7 +50,8 @@
 from verl.utils.model import compute_position_id_with_mask
 from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager

-from trinity.common.constants import AlgorithmType, SyncMethod
+from trinity.common.config import AlgorithmConfig
+from trinity.common.constants import SyncMethod
 from trinity.utils.distributed import init_process_group, is_ipv6_address

 logger = logging.getLogger(__file__)
@@ -623,8 +624,8 @@ def sync_weight(self):
         torch.cuda.empty_cache()

     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
-    def set_mode(self, algo_type: AlgorithmType = AlgorithmType.PPO):
-        self.actor.set_mode(algo_type)
+    def set_algorithm(self, algo_config: AlgorithmConfig):
+        self.actor.set_algorithm(algo_config)

     @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
     def update_actor(self, data: DataProto):
diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py
index 7590d6075b..5324a13f7c 100644
--- a/trinity/trainer/verl_trainer.py
+++ b/trinity/trainer/verl_trainer.py
@@ -13,7 +13,7 @@
 from verl.utils import hf_tokenizer
 from verl.utils.fs import copy_local_path_from_hdfs

-from trinity.common.config import Config
+from trinity.common.config import AlgorithmConfig, Config
 from trinity.common.constants import AlgorithmType
 from trinity.common.experience import Experiences
 from trinity.trainer.trainer import TrainEngineWrapper
@@ -125,9 +125,7 @@ def __init__(
             ray_worker_group_cls,
         )
         self.init_workers()
-        self.algorithm_type = (
-            AlgorithmType.PPO
-        )  # TODO: initialize algorithm_type according to config
+        self.algorithm_type = AlgorithmType.PPO
         self.logger = Monitor(
             project=config.trainer.project_name,
             name=config.trainer.experiment_name,
@@ -499,11 +497,11 @@ def save_checkpoint(self) -> None:
     def sync_weight(self) -> None:
         self.actor_rollout_wg.sync_weight()

-    def set_mode(self, algorithm_type: AlgorithmType = AlgorithmType.PPO) -> None:
-        self.actor_rollout_wg.set_mode(algorithm_type)
-        if self.algorithm_type.is_sft() and (not algorithm_type.is_sft()):
+    def set_algorithm(self, algorithm_config: AlgorithmConfig) -> None:
+        self.actor_rollout_wg.set_algorithm(algorithm_config)
+        if self.algorithm_type.is_sft() and (not algorithm_config.algorithm_type.is_sft()):
             self.sft_to_rft()
-        self.algorithm_type = algorithm_type
+        self.algorithm_type = algorithm_config.algorithm_type

     def sft_to_rft(self) -> None:
         # load from hdfs
diff --git a/trinity/utils/registry.py b/trinity/utils/registry.py
index 70fb2930c9..b31f6872bd 100644
--- a/trinity/utils/registry.py
+++ b/trinity/utils/registry.py
@@ -22,6 +22,8 @@
 logger = get_logger(__name__)


+# TODO: support lazy load
+# e.g. @MODULES.register_module("name", lazy=True)
 class Registry(object):
     """This class is used to register some modules to registry by a repo name."""