diff --git a/docs/sphinx_doc/source/conf.py b/docs/sphinx_doc/source/conf.py
index 4842a34557..ffaabf72c9 100644
--- a/docs/sphinx_doc/source/conf.py
+++ b/docs/sphinx_doc/source/conf.py
@@ -22,12 +22,13 @@
     "sphinx.ext.napoleon",
     "sphinx.ext.autosectionlabel",
     "myst_parser",
+    "sphinx.ext.mathjax",
 ]
 source_suffix = {
     ".rst": "restructuredtext",
     ".md": "markdown",
 }
-myst_enable_extensions = ["colon_fence"]
+myst_enable_extensions = ["colon_fence", "amsmath", "dollarmath"]
 
 # Prefix document path to section labels, otherwise autogenerated labels would
 # look like 'heading' rather than 'path/to/file:heading'
diff --git a/docs/sphinx_doc/source/index.rst b/docs/sphinx_doc/source/index.rst
index 5604faa15d..a1b6fde647 100644
--- a/docs/sphinx_doc/source/index.rst
+++ b/docs/sphinx_doc/source/index.rst
@@ -24,6 +24,7 @@ Welcome to Trinity-RFT's documentation!
    tutorial/example_data_functionalities.md
    tutorial/trinity_configs.md
    tutorial/trinity_programming_guide.md
+   tutorial/example_mix_algo.md
 
 .. toctree::
    :maxdepth: 1
diff --git a/docs/sphinx_doc/source/tutorial/example_mix_algo.md b/docs/sphinx_doc/source/tutorial/example_mix_algo.md
index 9dadc76b40..ee0010ba24 100644
--- a/docs/sphinx_doc/source/tutorial/example_mix_algo.md
+++ b/docs/sphinx_doc/source/tutorial/example_mix_algo.md
@@ -1,4 +1,4 @@
-# Integrate An New Algorithm
+# Integrate A New Algorithm
 
 This guide introduces how to integrate a new algorithm to Trinity-RFT.
 
@@ -6,7 +6,7 @@ As an example, we incorporate some "expert" data generated by a more advanced LL
 $$
 \mathcal{J}_{\text{Mix}}(\theta) =
-\mathcal{J}_{\text{GRPO}}(\theta)
+(1-\mu) \mathcal{J}_{\text{GRPO}}(\theta)
 + \mu \cdot \underbrace{\frac{1}{B'} \sum_{b=1}^{B'}
 \left[
@@ -170,6 +170,7 @@ We define a `MixPolicyLoss` class in `trinity/algorithm/policy_loss_fn/mix_polic
 class MIXPolicyLossFn(PolicyLossFn):
     def __init__(
         self,
+        backend: str = "verl",
         mu: float = 0.1,
         clip_range: Optional[float] = None,
         clip_range_low: Optional[float] = None,
@@ -183,6 +184,7 @@ class MIXPolicyLossFn(PolicyLossFn):
         read_batch_size_expert: Optional[int] = None,
         use_token_level_loss_in_sft: bool = True,
     ) -> None:
+        super().__init__(backend=backend)
         self.mu = mu
         self.use_dynamic_bsz = use_dynamic_bsz
         self.experience_per_gpu = ppo_mini_batch_size * repeat_times // ngpus_trainer  # type: ignore
@@ -204,11 +206,9 @@ class MIXPolicyLossFn(PolicyLossFn):
         old_logprob: torch.Tensor,
         action_mask: torch.Tensor,
         advantages: torch.Tensor,
+        is_expert_mask: torch.Tensor,
         **kwargs,
     ) -> Tuple[torch.Tensor, Dict]:
-        is_expert_mask = kwargs.get("is_expert_mask", None)
-        if is_expert_mask is None:
-            raise ValueError("is_expert_mask is required in MIX")
         assert (
             len(is_expert_mask) == logprob.shape[0]
         ), f"Error: {len(is_expert_mask)=} != {logprob.shape[0]=}"
@@ -271,10 +271,6 @@ class MIXPolicyLossFn(PolicyLossFn):
             "mu": 0.1,
             "clip_range": 0.2,
         }
-
-    @property
-    def select_keys(self) -> List[str]:
-        return ["old_logprob", "action_mask", "advantages", "is_expert_mask"]
 ```
 
 ## Step 4: Run the Experiment
diff --git a/examples/mix_math/README.md b/examples/mix_math/README.md
index 8e84f233bc..2ef160b0f2 100644
--- a/examples/mix_math/README.md
+++ b/examples/mix_math/README.md
@@ -4,4 +4,4 @@ This example shows the usage of a new algorithm MIX on the MATH dataset.
 
 For more detailed information, please refer to the [documentation](../../docs/sphinx_doc/source/tutorial/example_mix_algo.md).
 
-The config files are located in [`mix_math.yaml`](mix.yaml) and [`train_mix_math.yaml`](train_mix_math.yaml).
+The config files are located in [`mix_math.yaml`](mix_math.yaml) and [`train_mix_math.yaml`](train_mix_math.yaml).
diff --git a/examples/mix_math/mix_math.yaml b/examples/mix_math/mix_math.yaml
index 339d8df394..b92edd4b25 100644
--- a/examples/mix_math/mix_math.yaml
+++ b/examples/mix_math/mix_math.yaml
@@ -27,7 +27,6 @@ cluster:
 buffer:
   total_epochs: 1
   batch_size: 40
-  explore_batch_size: 36
   max_retry_times: 3
   max_retry_interval: 1
   explorer_input:
@@ -82,7 +81,7 @@ synchronizer:
   sync_timeout: 1200
 trainer:
   trainer_type: 'verl'
-  trainer_config_path: 'examples/mix_math/train_math.yaml'
+  trainer_config_path: 'examples/mix_math/train_mix_math.yaml'
   save_interval: 50
 monitor:
   monitor_type: wandb
diff --git a/tests/algorithm/policy_loss_test.py b/tests/algorithm/policy_loss_test.py
new file mode 100644
index 0000000000..ba88feb2d7
--- /dev/null
+++ b/tests/algorithm/policy_loss_test.py
@@ -0,0 +1,94 @@
+# -*- coding: utf-8 -*-
+"""Test for policy loss functions"""
+
+import unittest
+
+import torch
+from verl import DataProto
+
+from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN
+
+
+class VerlPolicyLossTest(unittest.TestCase):
+    def setUp(self):
+        seed = 42
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.benchmark = False
+
+        shape = (5, 20)
+        self.logprob = 2 * torch.rand(shape) - 1
+        self.input_data = DataProto.from_dict(
+            {
+                "old_log_probs": 2 * torch.rand(shape) - 1,
+                "ref_log_prob": 2 * torch.rand(shape) - 1,
+                "response_mask": torch.rand(shape) > 0.5,
+                "advantages": 2 * torch.rand(shape) - 1,
+                "is_expert_mask": torch.rand(shape[0]) > 0.5,
+            }
+        )
+
+    def test_ppo_policy_loss(self):
+        policy_loss_fn_cls = POLICY_LOSS_FN.get("ppo")
+        policy_loss_fn_args = policy_loss_fn_cls.default_args()
+        policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
+        loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch)
+        ppo_loss = torch.tensor(0.28560468554496765)
+        pg_clipfrac = torch.tensor(0.3541666567325592)
+        ppo_kl = torch.tensor(-0.21663446724414825)
+        self.assertTrue(torch.allclose(loss, ppo_loss))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["pg_clipfrac"]), pg_clipfrac))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["ppo_kl"]), ppo_kl))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["pg_loss"]), ppo_loss))
+
+    def test_sft_policy_loss(self):
+        policy_loss_fn_cls = POLICY_LOSS_FN.get("sft")
+        policy_loss_fn_args = policy_loss_fn_cls.default_args()
+        policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
+        loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch)
+        sft_loss = torch.tensor(-0.07560186833143234)
+        self.assertTrue(torch.allclose(loss, sft_loss))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["sft_loss"]), sft_loss))
+
+    def test_dpo_policy_loss(self):
+        policy_loss_fn_cls = POLICY_LOSS_FN.get("dpo")
+        policy_loss_fn_args = policy_loss_fn_cls.default_args()
+        policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
+        loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch)
+        dpo_loss = torch.tensor(0.5406752228736877)
+        chosen_reward = torch.tensor(0.7082431316375732)
+        rejected_reward = torch.tensor(0.3757950782775879)
+        accuracy_mean = torch.tensor(1.0)
+        self.assertTrue(torch.allclose(loss, dpo_loss))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["chosen_reward"]), chosen_reward))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["rejected_reward"]), rejected_reward))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["accuracy_mean"]), accuracy_mean))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["dpo_loss"]), dpo_loss))
+
+    def test_opmd_policy_loss(self):
+        policy_loss_fn_cls = POLICY_LOSS_FN.get("opmd")
+        policy_loss_fn_args = policy_loss_fn_cls.default_args()
+        policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
+        loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch)
+        opmd_loss = torch.tensor(-0.009589947760105133)
+        self.assertTrue(torch.allclose(loss, opmd_loss))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["opmd_loss"]), opmd_loss))
+
+    def test_mix_policy_loss(self):
+        policy_loss_fn_cls = POLICY_LOSS_FN.get("mix")
+        policy_loss_fn_args = policy_loss_fn_cls.default_args()
+        policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
+        loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch)
+        mix_loss = torch.tensor(0.6581965088844299)
+        pg_clipfrac = torch.tensor(0.7777777910232544)
+        ppo_kl = torch.tensor(-1.0737695693969727)
+        pg_loss = torch.tensor(0.7236452102661133)
+        sft_loss = torch.tensor(0.06915830634534359)
+        self.assertTrue(torch.allclose(loss, mix_loss))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["usual/pg_clipfrac"]), pg_clipfrac))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["usual/ppo_kl"]), ppo_kl))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["usual/pg_loss"]), pg_loss))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["expert/sft_loss"]), sft_loss))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["loss"]), mix_loss))
diff --git a/tests/common/config_test.py b/tests/common/config_test.py
index e1ac0aa7d4..da4fd914a0 100644
--- a/tests/common/config_test.py
+++ b/tests/common/config_test.py
@@ -47,6 +47,7 @@ def test_all_examples_are_valid(self):
                 config_path = os.path.join(example_dir, example_name, filename)
                 try:
                     config = load_config(config_path)
+                    config.checkpoint_root_dir = "./.cache/"
                     config.check_and_update()
                 except Exception as e:
                     print(f"Error loading config {config_path}: {e}")
diff --git a/trinity/algorithm/key_mapper.py b/trinity/algorithm/key_mapper.py
new file mode 100644
index 0000000000..09c1f988a6
--- /dev/null
+++ b/trinity/algorithm/key_mapper.py
@@ -0,0 +1,29 @@
+# -*- coding: utf-8 -*-
+"""Key Mapper"""
+
+from typing import Dict
+
+
+class KeyMapper:
+    def __init__(self, to_trinity_map: Dict[str, str]):
+        self.to_trinity_map = to_trinity_map
+        self.from_trinity_map = {v: k for k, v in self.to_trinity_map.items()}
+
+    def to_trinity(self, key: str) -> str:
+        return self.to_trinity_map.get(key, key)
+
+    def from_trinity(self, key: str) -> str:
+        return self.from_trinity_map.get(key, key)
+
+
+ALL_MAPPERS = {
+    "verl": KeyMapper(
+        {
+            "log_prob": "logprob",
+            "old_log_probs": "old_logprob",
+            "ref_log_prob": "ref_logprob",
+            "response_mask": "action_mask",
+            "advantages": "advantages",
+        }
+    ),
+}
diff --git a/trinity/algorithm/policy_loss_fn/dpo_loss.py b/trinity/algorithm/policy_loss_fn/dpo_loss.py
index 7dfbb7141d..0858cb7002 100644
--- a/trinity/algorithm/policy_loss_fn/dpo_loss.py
+++ b/trinity/algorithm/policy_loss_fn/dpo_loss.py
@@ -1,6 +1,6 @@
 """DPO loss function."""
 
-from typing import Dict, List, Tuple
+from typing import Dict, Tuple
 
 import torch
 import torch.nn.functional as F
@@ -13,9 +13,11 @@ class DPOLossFn(PolicyLossFn):
     def __init__(
         self,
+        backend: str = "verl",
         beta: float = 0.1,
         label_smoothing: float = 0.0,
     ) -> None:
+        super().__init__(backend=backend)
         self.beta = beta
         self.label_smoothing = label_smoothing
@@ -63,10 +65,3 @@ def default_args(cls) -> Dict:
             "beta": 0.1,
             "label_smoothing": 0.0,
         }
-
-    @property
-    def select_keys(self) -> List[str]:
-        return [
-            "ref_logprob",
-            "action_mask",
-        ]
diff --git a/trinity/algorithm/policy_loss_fn/mix_policy_loss.py b/trinity/algorithm/policy_loss_fn/mix_policy_loss.py
index 84679b0ea8..76c89c42d9 100644
--- a/trinity/algorithm/policy_loss_fn/mix_policy_loss.py
+++ b/trinity/algorithm/policy_loss_fn/mix_policy_loss.py
@@ -1,6 +1,6 @@
 """Mix policy loss function."""
 
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, Optional, Tuple
 
 import torch
 
@@ -26,27 +26,29 @@ class MIXPolicyLossFn(PolicyLossFn):
     def __init__(
         self,
+        backend: str = "verl",
         mu: float = 0.1,
         clip_range: Optional[float] = None,
         clip_range_low: Optional[float] = None,
         clip_range_high: Optional[float] = None,
         use_dynamic_bsz: Optional[bool] = None,
-        repeat_times: Optional[int] = None,
-        ppo_mini_batch_size: Optional[int] = None,
-        ppo_micro_batch_size_per_gpu: Optional[int] = None,
-        ngpus_trainer: Optional[int] = None,
-        read_batch_size_usual: Optional[int] = None,
-        read_batch_size_expert: Optional[int] = None,
+        repeat_times: int = 1,
+        ppo_mini_batch_size: int = 1,
+        ppo_micro_batch_size_per_gpu: int = 1,
+        ngpus_trainer: int = 1,
+        read_batch_size_usual: int = 1,
+        read_batch_size_expert: int = 1,
         use_token_level_loss_in_sft: bool = True,
     ) -> None:
+        super().__init__(backend=backend)
         self.mu = mu
         self.use_dynamic_bsz = use_dynamic_bsz
-        self.experience_per_gpu = ppo_mini_batch_size * repeat_times // ngpus_trainer  # type: ignore
+        self.experience_per_gpu = ppo_mini_batch_size * repeat_times // ngpus_trainer
         self.gradient_accumulation = (
-            ppo_mini_batch_size * repeat_times // ppo_micro_batch_size_per_gpu  # type: ignore
+            ppo_mini_batch_size * repeat_times // ppo_micro_batch_size_per_gpu
         )
-        self.read_batch_size_usual = read_batch_size_usual
-        self.read_batch_size_expert = read_batch_size_expert
+        self.read_batch_size_usual = read_batch_size_usual // ngpus_trainer
+        self.read_batch_size_expert = read_batch_size_expert // ngpus_trainer
         self.grpo_loss_fn = PPOPolicyLossFn(
             clip_range=clip_range,
             clip_range_low=clip_range_low,
@@ -60,11 +62,9 @@ def __call__(  # type: ignore
         old_logprob: torch.Tensor,
         action_mask: torch.Tensor,
         advantages: torch.Tensor,
+        is_expert_mask: torch.Tensor,
         **kwargs,
     ) -> Tuple[torch.Tensor, Dict]:
-        is_expert_mask = kwargs.get("is_expert_mask", None)
-        if is_expert_mask is None:
-            raise ValueError("is_expert_mask is required in MIX")
         assert (
             len(is_expert_mask) == logprob.shape[0]
         ), f"Error: {len(is_expert_mask)=} != {logprob.shape[0]=}"
@@ -127,7 +127,3 @@ def default_args(cls) -> Dict:
             "mu": 0.1,
             "clip_range": 0.2,
         }
-
-    @property
-    def select_keys(self) -> List[str]:
-        return ["old_logprob", "action_mask", "advantages", "is_expert_mask"]
diff --git a/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py b/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py
index 042d26b341..618301b319 100644
--- a/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py
+++ b/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py
@@ -1,6 +1,6 @@
 """OPMD policy loss function."""
 
-from typing import Dict, List, Tuple
+from typing import Dict, Tuple
 
 import torch
 
@@ -10,13 +10,13 @@
 @POLICY_LOSS_FN.register_module("opmd")
 class OPMDPolicyLossFn(PolicyLossFn):
-    def __init__(self, tau: float = 1.0) -> None:
+    def __init__(self, backend: str = "verl", tau: float = 1.0) -> None:
+        super().__init__(backend=backend)
         self.tau = tau
 
     def __call__(  # type: ignore
         self,
         logprob: torch.Tensor,
-        old_logprob: torch.Tensor,  # NOT USED!
         action_mask: torch.Tensor,
         advantages: torch.Tensor,
         **kwargs,
@@ -29,11 +29,3 @@ def __call__(  # type: ignore
     @classmethod
     def default_args(cls) -> Dict:
         return {"tau": 1.0}
-
-    @property
-    def select_keys(self) -> List[str]:
-        return [
-            "old_logprob",
-            "action_mask",
-            "advantages",
-        ]
diff --git a/trinity/algorithm/policy_loss_fn/policy_loss_fn.py b/trinity/algorithm/policy_loss_fn/policy_loss_fn.py
index 6c1a29b3e9..aa6025252e 100644
--- a/trinity/algorithm/policy_loss_fn/policy_loss_fn.py
+++ b/trinity/algorithm/policy_loss_fn/policy_loss_fn.py
@@ -1,18 +1,92 @@
-from abc import ABC, abstractmethod
-from typing import Dict, List, Tuple
+import inspect
+from abc import ABC, ABCMeta, abstractmethod
+from typing import Dict, Tuple
 
 import torch
 
+from trinity.algorithm.key_mapper import ALL_MAPPERS
 from trinity.utils.registry import Registry
 
 POLICY_LOSS_FN = Registry("policy_loss_fn")
 
 
-class PolicyLossFn(ABC):
+class PolicyLossFnMeta(ABCMeta):
+    """Metaclass for policy loss functions that handles parameter name mapping and filtering."""
+
+    ignore_keys = {"self", "kwargs", "logprob"}  # Keys to exclude from parameter selection
+
+    def __new__(cls, name, bases, dct):
+        """
+        Metaclass constructor that automatically generates parameter handling logic.
+
+        For example with `PPOPolicyLossFn` class:
+        .. code-block:: python
+            class PPOPolicyLossFn(PolicyLossFn):
+                ...
+                def __call__(
+                    self,
+                    logprob: torch.Tensor,
+                    old_logprob: torch.Tensor,
+                    action_mask: torch.Tensor,
+                    advantages: torch.Tensor,
+                    **kwargs,
+                ) -> Tuple[torch.Tensor, Dict]:
+                    ...
+
+        This metaclass analyzes the __call__ method's parameters to:
+        1. Generate _select_keys containing all non-ignored parameters
+        2. Create select_keys property that maps parameters to trainer-specific names
+        3. Apply decorator to automatically convert input parameter names using the mapper
+        """
+        signature = inspect.signature(dct["__call__"])
+        param_names = [
+            key for key in signature.parameters.keys() if key not in PolicyLossFnMeta.ignore_keys
+        ]
+        dct["_select_keys"] = param_names
+
+        # Property to return trainer-specific parameter names
+        def select_keys(self):
+            """Returns parameter keys mapped to the specific training framework's naming convention."""
+            keys = [self.mapper.from_trinity(key) for key in self._select_keys]
+            return keys
+
+        # Decorator to handle parameter name conversion before calling __call__
+        def decorator(func):
+            def wrapper(self, *args, **kwargs):
+                """Filters and converts parameter names according to the training framework's convention."""
+                new_kwargs = {}
+                for key, value in kwargs.items():
+                    key = self.mapper.to_trinity(key)
+                    if key == "logprob" or key in self._select_keys:  # remove unused keys
+                        new_kwargs[key] = value
+                return func(self, *args, **new_kwargs)
+
+            return wrapper
+
+        # Add the property and decorated method to the class
+        dct["select_keys"] = property(select_keys)
+        dct["__call__"] = decorator(dct["__call__"])
+        return super().__new__(cls, name, bases, dct)
+
+
+class PolicyLossFn(ABC, metaclass=PolicyLossFnMeta):
     """
-    Policy Loss Function
+    Abstract base class for policy loss functions.
+
+    This class provides the interface for implementing different policy gradient loss functions
+    while handling parameter name mapping between different training frameworks.
""" + def __init__(self, backend: str = "verl"): + """ + Initialize the policy loss function. + + Args: + backend: The training framework/backend to use (e.g., "verl") + """ + self.backend = backend + self.mapper = ALL_MAPPERS[self.backend] + @abstractmethod def __call__( self, @@ -20,8 +94,12 @@ def __call__( **kwargs, ) -> Tuple[torch.Tensor, Dict]: """ + Calculate the policy loss. + Args: logprob (`torch.Tensor`): The log probability generated by the policy model. + + Kwargs (optional): old_logprob (`torch.Tensor`): The log probability generated by the reference model. action_mask (`torch.Tensor`): The action mask. advantages (`torch.Tensor`): The advantages. @@ -36,14 +114,8 @@ def __call__( @abstractmethod def default_args(cls) -> Dict: """ - Returns: - `Dict`: The default init arguments for the policy loss function. - """ + Get default initialization arguments for this loss function. - @property - @abstractmethod - def select_keys(self) -> List[str]: - """ Returns: - `List[str]`: The keys to select from input data. + `Dict`: The default init arguments for the policy loss function. """ diff --git a/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py b/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py index 5c735d4d6a..a4cc0b2d03 100644 --- a/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py +++ b/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py @@ -3,7 +3,7 @@ Modified from https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py """ -from typing import Dict, List, Optional, Tuple +from typing import Dict, Optional, Tuple import torch @@ -15,10 +15,12 @@ class PPOPolicyLossFn(PolicyLossFn): def __init__( self, + backend: str = "verl", clip_range: Optional[float] = None, clip_range_low: Optional[float] = None, clip_range_high: Optional[float] = None, ) -> None: + super().__init__(backend=backend) if clip_range_low is None: self.clip_range_low = clip_range else: @@ -61,11 +63,3 @@ def default_args(cls) -> Dict: return { "clip_range": 0.2, } - - @property - def select_keys(self) -> List[str]: - return [ - "old_logprob", - "action_mask", - "advantages", - ] diff --git a/trinity/algorithm/policy_loss_fn/sft_loss.py b/trinity/algorithm/policy_loss_fn/sft_loss.py index dd1c75a4a2..2c824f1c09 100644 --- a/trinity/algorithm/policy_loss_fn/sft_loss.py +++ b/trinity/algorithm/policy_loss_fn/sft_loss.py @@ -1,6 +1,6 @@ """SFT loss function.""" -from typing import Dict, List, Tuple +from typing import Dict, Tuple import torch @@ -10,7 +10,8 @@ @POLICY_LOSS_FN.register_module("sft") class SFTLossFn(PolicyLossFn): - def __init__(self, use_token_level_loss: bool = True) -> None: + def __init__(self, backend: str = "verl", use_token_level_loss: bool = True) -> None: + super().__init__(backend=backend) self.use_token_level_loss = use_token_level_loss def __call__( # type: ignore @@ -30,7 +31,3 @@ def default_args(cls): return { "use_token_level_loss": True, } - - @property - def select_keys(self) -> List[str]: - return ["action_mask"] diff --git a/trinity/trainer/verl/dp_actor.py b/trinity/trainer/verl/dp_actor.py index 6a57e58144..e7eb34ea17 100644 --- a/trinity/trainer/verl/dp_actor.py +++ b/trinity/trainer/verl/dp_actor.py @@ -56,7 +56,7 @@ def __init__( def set_algorithm(self, algorithm_config: AlgorithmConfig): self.policy_loss_fn = POLICY_LOSS_FN.get(algorithm_config.policy_loss_fn)( - **algorithm_config.policy_loss_fn_args + backend="verl", **algorithm_config.policy_loss_fn_args ) self.kl_loss_fn = 
KL_FN.get(algorithm_config.kl_loss_fn)(**algorithm_config.kl_loss_fn_args) self.entropy_loss_fn = ENTROPY_LOSS_FN.get(algorithm_config.entropy_loss_fn)( @@ -152,21 +152,7 @@ def update_policy(self, data: DataProto): "responses", "response_mask", ] - select_keys_verl2trinity = { - "old_log_probs": "old_logprob", - "ref_log_prob": "ref_logprob", - "response_mask": "action_mask", - "advantages": "advantages", - } - select_keys_trinity2verl = {value: key for key, value in select_keys_verl2trinity.items()} - for trinity_key in self.policy_loss_fn.select_keys: - if trinity_key in select_keys_trinity2verl: - verl_key = select_keys_trinity2verl[trinity_key] - else: - verl_key = trinity_key - select_keys_verl2trinity.update({verl_key: trinity_key}) - select_keys_trinity2verl.update({trinity_key: verl_key}) - select_keys.append(verl_key) + select_keys.extend(self.policy_loss_fn.select_keys) if not isinstance(self.kl_loss_fn, DummyKLFn): select_keys.append("ref_log_prob") select_keys = list(set(select_keys)) @@ -240,14 +226,8 @@ def update_policy(self, data: DataProto): calculate_entropy=calculate_entropy, ) - kwargs = { - select_keys_verl2trinity[verl_key]: value - for verl_key, value in data.items() - if verl_key in select_keys_verl2trinity - } pg_loss, pg_loss_metrics = self.policy_loss_fn( # type: ignore - logprob=log_prob, - **kwargs, + logprob=log_prob, **data ) prefix_metrics( src_metrics=pg_loss_metrics, prefix="actor", dst_metrics=micro_batch_metrics
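
Illustrative sketch (not part of the patch): the snippet below shows what the PolicyLossFnMeta refactor looks like from an algorithm author's side, assuming the policy_loss_fn.py and key_mapper.py versions from this diff. A new loss only lists the tensors it needs in __call__; _select_keys is derived from that signature, select_keys translates the names to verl's batch fields via KeyMapper, and the call-time wrapper filters **data automatically. The ToyPolicyLossFn class, its "toy" registry name, and the tau parameter are hypothetical, invented for illustration; only POLICY_LOSS_FN, PolicyLossFn, and the registration pattern come from the diff above.

# Hypothetical example -- not code from the repository.
from typing import Dict, Tuple

import torch

from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn


@POLICY_LOSS_FN.register_module("toy")
class ToyPolicyLossFn(PolicyLossFn):
    """A made-up REINFORCE-style loss, used only to show the metaclass machinery."""

    def __init__(self, backend: str = "verl", tau: float = 1.0) -> None:
        super().__init__(backend=backend)
        self.tau = tau

    def __call__(  # type: ignore
        self,
        logprob: torch.Tensor,
        action_mask: torch.Tensor,
        advantages: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict]:
        # Advantage-weighted log-likelihood, averaged over valid action tokens.
        loss = -(self.tau * advantages * logprob * action_mask).sum() / action_mask.sum()
        return loss, {"toy_loss": loss.detach().item()}

    @classmethod
    def default_args(cls) -> Dict:
        return {"tau": 1.0}


loss_fn = ToyPolicyLossFn(**ToyPolicyLossFn.default_args())
# PolicyLossFnMeta inspects __call__ and drops {"self", "kwargs", "logprob"},
# so _select_keys is ["action_mask", "advantages"]; select_keys maps them to the
# verl-side names that dp_actor appends to its select_keys list.
print(loss_fn.select_keys)  # ['response_mask', 'advantages']

Because KeyMapper.from_trinity falls back to the key itself for unknown names, algorithm-specific tensors such as is_expert_mask in MIXPolicyLossFn pass through unchanged, which is why the per-class select_keys properties and the hand-maintained verl2trinity tables in dp_actor.py could be deleted in this diff.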