From 51d229f5ccb1d5ed5cd376c39f6d576789f54d4a Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Tue, 17 Jun 2025 11:49:29 +0800 Subject: [PATCH 1/9] Refactor on `select_keys` --- trinity/algorithm/policy_loss_fn/dpo_loss.py | 9 +----- .../policy_loss_fn/opmd_policy_loss.py | 10 +------ .../policy_loss_fn/policy_loss_fn.py | 29 ++++++++++++------- .../policy_loss_fn/ppo_policy_loss.py | 10 +------ trinity/algorithm/policy_loss_fn/sft_loss.py | 6 +--- 5 files changed, 22 insertions(+), 42 deletions(-) diff --git a/trinity/algorithm/policy_loss_fn/dpo_loss.py b/trinity/algorithm/policy_loss_fn/dpo_loss.py index 7dfbb7141d..321c58bf9f 100644 --- a/trinity/algorithm/policy_loss_fn/dpo_loss.py +++ b/trinity/algorithm/policy_loss_fn/dpo_loss.py @@ -1,6 +1,6 @@ """DPO loss function.""" -from typing import Dict, List, Tuple +from typing import Dict, Tuple import torch import torch.nn.functional as F @@ -63,10 +63,3 @@ def default_args(cls) -> Dict: "beta": 0.1, "label_smoothing": 0.0, } - - @property - def select_keys(self) -> List[str]: - return [ - "ref_logprob", - "action_mask", - ] diff --git a/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py b/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py index 042d26b341..1d5db89b71 100644 --- a/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py +++ b/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py @@ -1,6 +1,6 @@ """OPMD policy loss function.""" -from typing import Dict, List, Tuple +from typing import Dict, Tuple import torch @@ -29,11 +29,3 @@ def __call__( # type: ignore @classmethod def default_args(cls) -> Dict: return {"tau": 1.0} - - @property - def select_keys(self) -> List[str]: - return [ - "old_logprob", - "action_mask", - "advantages", - ] diff --git a/trinity/algorithm/policy_loss_fn/policy_loss_fn.py b/trinity/algorithm/policy_loss_fn/policy_loss_fn.py index 6c1a29b3e9..1935ae3c4f 100644 --- a/trinity/algorithm/policy_loss_fn/policy_loss_fn.py +++ b/trinity/algorithm/policy_loss_fn/policy_loss_fn.py @@ -1,5 +1,6 @@ -from abc import ABC, abstractmethod -from typing import Dict, List, Tuple +import inspect +from abc import ABC, ABCMeta, abstractmethod +from typing import Dict, Tuple import torch @@ -8,7 +9,21 @@ POLICY_LOSS_FN = Registry("policy_loss_fn") -class PolicyLossFn(ABC): +class PolicyLossFnMeta(ABCMeta): + """Meta class for policy loss function.""" + + ignore_keys = {"self", "kwargs", "logprob"} + + def __new__(cls, name, bases, dct): + signature = inspect.signature(dct["__call__"]) + param_names = [ + key for key in signature.parameters.keys() if key not in PolicyLossFnMeta.ignore_keys + ] + dct["select_keys"] = property(lambda self: param_names) + return super().__new__(cls, name, bases, dct) + + +class PolicyLossFn(ABC, metaclass=PolicyLossFnMeta): """ Policy Loss Function """ @@ -39,11 +54,3 @@ def default_args(cls) -> Dict: Returns: `Dict`: The default init arguments for the policy loss function. """ - - @property - @abstractmethod - def select_keys(self) -> List[str]: - """ - Returns: - `List[str]`: The keys to select from input data. 
- """ diff --git a/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py b/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py index 5c735d4d6a..da971c5266 100644 --- a/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py +++ b/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py @@ -3,7 +3,7 @@ Modified from https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py """ -from typing import Dict, List, Optional, Tuple +from typing import Dict, Optional, Tuple import torch @@ -61,11 +61,3 @@ def default_args(cls) -> Dict: return { "clip_range": 0.2, } - - @property - def select_keys(self) -> List[str]: - return [ - "old_logprob", - "action_mask", - "advantages", - ] diff --git a/trinity/algorithm/policy_loss_fn/sft_loss.py b/trinity/algorithm/policy_loss_fn/sft_loss.py index dd1c75a4a2..bbc70c0490 100644 --- a/trinity/algorithm/policy_loss_fn/sft_loss.py +++ b/trinity/algorithm/policy_loss_fn/sft_loss.py @@ -1,6 +1,6 @@ """SFT loss function.""" -from typing import Dict, List, Tuple +from typing import Dict, Tuple import torch @@ -30,7 +30,3 @@ def default_args(cls): return { "use_token_level_loss": True, } - - @property - def select_keys(self) -> List[str]: - return ["action_mask"] From 68c2997a2e41b06d67f3a415db68996fc4050c40 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Tue, 17 Jun 2025 15:00:04 +0800 Subject: [PATCH 2/9] Add `key_mapper` --- trinity/algorithm/key_mapper.py | 28 +++++++++++++++++++ trinity/algorithm/policy_loss_fn/dpo_loss.py | 2 ++ .../policy_loss_fn/opmd_policy_loss.py | 3 +- .../policy_loss_fn/policy_loss_fn.py | 27 +++++++++++++++++- .../policy_loss_fn/ppo_policy_loss.py | 2 ++ trinity/algorithm/policy_loss_fn/sft_loss.py | 3 +- trinity/trainer/verl/dp_actor.py | 21 ++------------ 7 files changed, 65 insertions(+), 21 deletions(-) create mode 100644 trinity/algorithm/key_mapper.py diff --git a/trinity/algorithm/key_mapper.py b/trinity/algorithm/key_mapper.py new file mode 100644 index 0000000000..94e1abdb4a --- /dev/null +++ b/trinity/algorithm/key_mapper.py @@ -0,0 +1,28 @@ +"""Key Mapper""" + + +from typing import Dict + + +class KeyMapper: + def __init__(self, to_trinity_map: Dict[str, str]): + self.to_trinity_map = to_trinity_map + self.from_trinity_map = {v: k for k, v in self.to_trinity_map.items()} + + def to_trinity(self, key: str) -> str: + return self.to_trinity_map.get(key, key) + + def from_trinity(self, key: str) -> str: + return self.from_trinity_map.get(key, key) + + +ALL_MAPPERS = { + "verl": KeyMapper( + { + "old_log_probs": "old_logprob", + "ref_log_prob": "ref_logprob", + "response_mask": "action_mask", + "advantages": "advantages", + } + ), +} \ No newline at end of file diff --git a/trinity/algorithm/policy_loss_fn/dpo_loss.py b/trinity/algorithm/policy_loss_fn/dpo_loss.py index 321c58bf9f..0858cb7002 100644 --- a/trinity/algorithm/policy_loss_fn/dpo_loss.py +++ b/trinity/algorithm/policy_loss_fn/dpo_loss.py @@ -13,9 +13,11 @@ class DPOLossFn(PolicyLossFn): def __init__( self, + backend: str = "verl", beta: float = 0.1, label_smoothing: float = 0.0, ) -> None: + super().__init__(backend=backend) self.beta = beta self.label_smoothing = label_smoothing diff --git a/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py b/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py index 1d5db89b71..3c29d0ca2d 100644 --- a/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py +++ b/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py @@ -10,7 +10,8 @@ @POLICY_LOSS_FN.register_module("opmd") class OPMDPolicyLossFn(PolicyLossFn): 
- def __init__(self, tau: float = 1.0) -> None: + def __init__(self, backend: str = "verl", tau: float = 1.0) -> None: + super().__init__(backend=backend) self.tau = tau def __call__( # type: ignore diff --git a/trinity/algorithm/policy_loss_fn/policy_loss_fn.py b/trinity/algorithm/policy_loss_fn/policy_loss_fn.py index 1935ae3c4f..f513c4c38f 100644 --- a/trinity/algorithm/policy_loss_fn/policy_loss_fn.py +++ b/trinity/algorithm/policy_loss_fn/policy_loss_fn.py @@ -4,6 +4,7 @@ import torch +from trinity.algorithm.key_mapper import ALL_MAPPERS from trinity.utils.registry import Registry POLICY_LOSS_FN = Registry("policy_loss_fn") @@ -19,7 +20,28 @@ def __new__(cls, name, bases, dct): param_names = [ key for key in signature.parameters.keys() if key not in PolicyLossFnMeta.ignore_keys ] - dct["select_keys"] = property(lambda self: param_names) + dct["_select_keys"] = param_names + + def select_keys(self): + mapper = ALL_MAPPERS[self.backend] + keys = [mapper.from_trinity(key) for key in self._select_keys] + return keys + + def decorator(func): + def wrapper(self, *args, **kwargs): + mapper = ALL_MAPPERS[self.backend] + new_kwargs = {} + for key, value in kwargs.items(): + key = mapper.from_trinity(key) + if key in self._select_keys: # remove unused keys + new_kwargs[key] = value + kwargs = new_kwargs + return func(self, *args, **new_kwargs) + + return wrapper + + dct["select_keys"] = property(select_keys) + dct["__call__"] = decorator(dct["__call__"]) return super().__new__(cls, name, bases, dct) @@ -28,6 +50,9 @@ class PolicyLossFn(ABC, metaclass=PolicyLossFnMeta): Policy Loss Function """ + def __init__(self, backend: str = "verl"): + self.backend = backend + @abstractmethod def __call__( self, diff --git a/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py b/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py index da971c5266..a4cc0b2d03 100644 --- a/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py +++ b/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py @@ -15,10 +15,12 @@ class PPOPolicyLossFn(PolicyLossFn): def __init__( self, + backend: str = "verl", clip_range: Optional[float] = None, clip_range_low: Optional[float] = None, clip_range_high: Optional[float] = None, ) -> None: + super().__init__(backend=backend) if clip_range_low is None: self.clip_range_low = clip_range else: diff --git a/trinity/algorithm/policy_loss_fn/sft_loss.py b/trinity/algorithm/policy_loss_fn/sft_loss.py index bbc70c0490..2c824f1c09 100644 --- a/trinity/algorithm/policy_loss_fn/sft_loss.py +++ b/trinity/algorithm/policy_loss_fn/sft_loss.py @@ -10,7 +10,8 @@ @POLICY_LOSS_FN.register_module("sft") class SFTLossFn(PolicyLossFn): - def __init__(self, use_token_level_loss: bool = True) -> None: + def __init__(self, backend: str = "verl", use_token_level_loss: bool = True) -> None: + super().__init__(backend=backend) self.use_token_level_loss = use_token_level_loss def __call__( # type: ignore diff --git a/trinity/trainer/verl/dp_actor.py b/trinity/trainer/verl/dp_actor.py index 0d750c8303..5d65990411 100644 --- a/trinity/trainer/verl/dp_actor.py +++ b/trinity/trainer/verl/dp_actor.py @@ -56,7 +56,7 @@ def __init__( def set_algorithm(self, algorithm_config: AlgorithmConfig): self.policy_loss_fn = POLICY_LOSS_FN.get(algorithm_config.policy_loss_fn)( - **algorithm_config.policy_loss_fn_args + backend="verl", **algorithm_config.policy_loss_fn_args ) self.kl_loss_fn = KL_FN.get(algorithm_config.kl_loss_fn)(**algorithm_config.kl_loss_fn_args) self.entropy_loss_fn = 
ENTROPY_LOSS_FN.get(algorithm_config.entropy_loss_fn)( @@ -152,16 +152,7 @@ def update_policy(self, data: DataProto): "responses", "response_mask", ] - select_keys_verl2trinity = { - "old_log_probs": "old_logprob", - "ref_log_prob": "ref_logprob", - "response_mask": "action_mask", - "advantages": "advantages", - } - select_keys_trinity2verl = {value: key for key, value in select_keys_verl2trinity.items()} - for trinity_key in self.policy_loss_fn.select_keys: - verl_key = select_keys_trinity2verl[trinity_key] - select_keys.append(verl_key) + select_keys.extend(self.policy_loss_fn.select_keys) if not isinstance(self.kl_loss_fn, DummyKLFn): select_keys.append("ref_log_prob") select_keys = list(set(select_keys)) @@ -235,14 +226,8 @@ def update_policy(self, data: DataProto): calculate_entropy=calculate_entropy, ) - kwargs = { - select_keys_verl2trinity[verl_key]: value - for verl_key, value in data.items() - if verl_key in select_keys_verl2trinity - } pg_loss, pg_loss_metrics = self.policy_loss_fn( # type: ignore - logprob=log_prob, - **kwargs, + logprob=log_prob, **data # [self.policy_loss_fn.select_keys] # .items(), ) prefix_metrics( src_metrics=pg_loss_metrics, prefix="actor", dst_metrics=micro_batch_metrics From 950be68d19e908bf66a9a774aab1984c969bbcb3 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Tue, 17 Jun 2025 15:02:57 +0800 Subject: [PATCH 3/9] bug fix --- trinity/algorithm/key_mapper.py | 4 ++-- trinity/algorithm/policy_loss_fn/policy_loss_fn.py | 7 +++---- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/trinity/algorithm/key_mapper.py b/trinity/algorithm/key_mapper.py index 94e1abdb4a..1c332fb3a4 100644 --- a/trinity/algorithm/key_mapper.py +++ b/trinity/algorithm/key_mapper.py @@ -1,6 +1,6 @@ +# -*- coding: utf-8 -*- """Key Mapper""" - from typing import Dict @@ -25,4 +25,4 @@ def from_trinity(self, key: str) -> str: "advantages": "advantages", } ), -} \ No newline at end of file +} diff --git a/trinity/algorithm/policy_loss_fn/policy_loss_fn.py b/trinity/algorithm/policy_loss_fn/policy_loss_fn.py index f513c4c38f..285004e1dd 100644 --- a/trinity/algorithm/policy_loss_fn/policy_loss_fn.py +++ b/trinity/algorithm/policy_loss_fn/policy_loss_fn.py @@ -23,16 +23,14 @@ def __new__(cls, name, bases, dct): dct["_select_keys"] = param_names def select_keys(self): - mapper = ALL_MAPPERS[self.backend] - keys = [mapper.from_trinity(key) for key in self._select_keys] + keys = [self.mapper.from_trinity(key) for key in self._select_keys] return keys def decorator(func): def wrapper(self, *args, **kwargs): - mapper = ALL_MAPPERS[self.backend] new_kwargs = {} for key, value in kwargs.items(): - key = mapper.from_trinity(key) + key = self.mapper.to_trinity(key) if key in self._select_keys: # remove unused keys new_kwargs[key] = value kwargs = new_kwargs @@ -52,6 +50,7 @@ class PolicyLossFn(ABC, metaclass=PolicyLossFnMeta): def __init__(self, backend: str = "verl"): self.backend = backend + self.mapper = ALL_MAPPERS[self.backend] @abstractmethod def __call__( From c10d41451df68b4c8ebafaf6424dad1ab03919e2 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Tue, 17 Jun 2025 15:04:43 +0800 Subject: [PATCH 4/9] fix mix_policy_loss --- trinity/algorithm/policy_loss_fn/mix_policy_loss.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/trinity/algorithm/policy_loss_fn/mix_policy_loss.py b/trinity/algorithm/policy_loss_fn/mix_policy_loss.py index 84679b0ea8..3e64302fa3 100644 --- a/trinity/algorithm/policy_loss_fn/mix_policy_loss.py 
+++ b/trinity/algorithm/policy_loss_fn/mix_policy_loss.py @@ -26,6 +26,7 @@ class MIXPolicyLossFn(PolicyLossFn): def __init__( self, + backend: str = "verl", mu: float = 0.1, clip_range: Optional[float] = None, clip_range_low: Optional[float] = None, @@ -39,6 +40,7 @@ def __init__( read_batch_size_expert: Optional[int] = None, use_token_level_loss_in_sft: bool = True, ) -> None: + super().__init__(backend=backend) self.mu = mu self.use_dynamic_bsz = use_dynamic_bsz self.experience_per_gpu = ppo_mini_batch_size * repeat_times // ngpus_trainer # type: ignore @@ -127,7 +129,3 @@ def default_args(cls) -> Dict: "mu": 0.1, "clip_range": 0.2, } - - @property - def select_keys(self) -> List[str]: - return ["old_logprob", "action_mask", "advantages", "is_expert_mask"] From e279c7555946a744927b1555aeee2e873112e2d8 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Tue, 17 Jun 2025 15:11:05 +0800 Subject: [PATCH 5/9] doc fix --- docs/sphinx_doc/source/tutorial/example_mix_algo.md | 10 +++------- trinity/algorithm/policy_loss_fn/mix_policy_loss.py | 6 ++---- trinity/trainer/verl/dp_actor.py | 2 +- 3 files changed, 6 insertions(+), 12 deletions(-) diff --git a/docs/sphinx_doc/source/tutorial/example_mix_algo.md b/docs/sphinx_doc/source/tutorial/example_mix_algo.md index 9dadc76b40..bbe3135c90 100644 --- a/docs/sphinx_doc/source/tutorial/example_mix_algo.md +++ b/docs/sphinx_doc/source/tutorial/example_mix_algo.md @@ -170,6 +170,7 @@ We define a `MixPolicyLoss` class in `trinity/algorithm/policy_loss_fn/mix_polic class MIXPolicyLossFn(PolicyLossFn): def __init__( self, + backend: str = "verl", mu: float = 0.1, clip_range: Optional[float] = None, clip_range_low: Optional[float] = None, @@ -183,6 +184,7 @@ class MIXPolicyLossFn(PolicyLossFn): read_batch_size_expert: Optional[int] = None, use_token_level_loss_in_sft: bool = True, ) -> None: + super().__init__(backend=backend) self.mu = mu self.use_dynamic_bsz = use_dynamic_bsz self.experience_per_gpu = ppo_mini_batch_size * repeat_times // ngpus_trainer # type: ignore @@ -204,11 +206,9 @@ class MIXPolicyLossFn(PolicyLossFn): old_logprob: torch.Tensor, action_mask: torch.Tensor, advantages: torch.Tensor, + is_expert_mask: torch.Tensor, **kwargs, ) -> Tuple[torch.Tensor, Dict]: - is_expert_mask = kwargs.get("is_expert_mask", None) - if is_expert_mask is None: - raise ValueError("is_expert_mask is required in MIX") assert ( len(is_expert_mask) == logprob.shape[0] ), f"Error: {len(is_expert_mask)=} != {logprob.shape[0]=}" @@ -271,10 +271,6 @@ class MIXPolicyLossFn(PolicyLossFn): "mu": 0.1, "clip_range": 0.2, } - - @property - def select_keys(self) -> List[str]: - return ["old_logprob", "action_mask", "advantages", "is_expert_mask"] ``` ## Step 4: Run the Experiment diff --git a/trinity/algorithm/policy_loss_fn/mix_policy_loss.py b/trinity/algorithm/policy_loss_fn/mix_policy_loss.py index 3e64302fa3..1b9446d59a 100644 --- a/trinity/algorithm/policy_loss_fn/mix_policy_loss.py +++ b/trinity/algorithm/policy_loss_fn/mix_policy_loss.py @@ -1,6 +1,6 @@ """Mix policy loss function.""" -from typing import Dict, List, Optional, Tuple +from typing import Dict, Optional, Tuple import torch @@ -62,11 +62,9 @@ def __call__( # type: ignore old_logprob: torch.Tensor, action_mask: torch.Tensor, advantages: torch.Tensor, + is_expert_mask: torch.Tensor, **kwargs, ) -> Tuple[torch.Tensor, Dict]: - is_expert_mask = kwargs.get("is_expert_mask", None) - if is_expert_mask is None: - raise ValueError("is_expert_mask is required in MIX") assert ( len(is_expert_mask) == 
logprob.shape[0] ), f"Error: {len(is_expert_mask)=} != {logprob.shape[0]=}" diff --git a/trinity/trainer/verl/dp_actor.py b/trinity/trainer/verl/dp_actor.py index 5d65990411..e7eb34ea17 100644 --- a/trinity/trainer/verl/dp_actor.py +++ b/trinity/trainer/verl/dp_actor.py @@ -227,7 +227,7 @@ def update_policy(self, data: DataProto): ) pg_loss, pg_loss_metrics = self.policy_loss_fn( # type: ignore - logprob=log_prob, **data # [self.policy_loss_fn.select_keys] # .items(), + logprob=log_prob, **data ) prefix_metrics( src_metrics=pg_loss_metrics, prefix="actor", dst_metrics=micro_batch_metrics From 18960d4eea98ceed25cba954ffeb2da132057e8c Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Tue, 17 Jun 2025 15:45:06 +0800 Subject: [PATCH 6/9] Doc fix --- docs/sphinx_doc/source/conf.py | 3 ++- docs/sphinx_doc/source/index.rst | 1 + docs/sphinx_doc/source/tutorial/example_mix_algo.md | 2 +- examples/mix_math/mix_math.yaml | 2 +- 4 files changed, 5 insertions(+), 3 deletions(-) diff --git a/docs/sphinx_doc/source/conf.py b/docs/sphinx_doc/source/conf.py index 4842a34557..ffaabf72c9 100644 --- a/docs/sphinx_doc/source/conf.py +++ b/docs/sphinx_doc/source/conf.py @@ -22,12 +22,13 @@ "sphinx.ext.napoleon", "sphinx.ext.autosectionlabel", "myst_parser", + "sphinx.ext.mathjax", ] source_suffix = { ".rst": "restructuredtext", ".md": "markdown", } -myst_enable_extensions = ["colon_fence"] +myst_enable_extensions = ["colon_fence", "amsmath", "dollarmath"] # Prefix document path to section labels, otherwise autogenerated labels would # look like 'heading' rather than 'path/to/file:heading' diff --git a/docs/sphinx_doc/source/index.rst b/docs/sphinx_doc/source/index.rst index 5604faa15d..a1b6fde647 100644 --- a/docs/sphinx_doc/source/index.rst +++ b/docs/sphinx_doc/source/index.rst @@ -24,6 +24,7 @@ Welcome to Trinity-RFT's documentation! tutorial/example_data_functionalities.md tutorial/trinity_configs.md tutorial/trinity_programming_guide.md + tutorial/example_mix_algo.md .. toctree:: :maxdepth: 1 diff --git a/docs/sphinx_doc/source/tutorial/example_mix_algo.md b/docs/sphinx_doc/source/tutorial/example_mix_algo.md index bbe3135c90..72c75d6509 100644 --- a/docs/sphinx_doc/source/tutorial/example_mix_algo.md +++ b/docs/sphinx_doc/source/tutorial/example_mix_algo.md @@ -1,4 +1,4 @@ -# Integrate An New Algorithm +# Integrate A New Algorithm This guide introduces how to integrate a new algorithm to Trinity-RFT. 
diff --git a/examples/mix_math/mix_math.yaml b/examples/mix_math/mix_math.yaml index 339d8df394..b680db6fb3 100644 --- a/examples/mix_math/mix_math.yaml +++ b/examples/mix_math/mix_math.yaml @@ -82,7 +82,7 @@ synchronizer: sync_timeout: 1200 trainer: trainer_type: 'verl' - trainer_config_path: 'examples/mix_math/train_math.yaml' + trainer_config_path: 'examples/mix_math/train_mix_math.yaml' save_interval: 50 monitor: monitor_type: wandb From d2be8f517f1ba536bd9bf3192cac55e624e9c094 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Tue, 17 Jun 2025 18:57:16 +0800 Subject: [PATCH 7/9] bug fix && add unittest for policy_loss --- .../source/tutorial/example_mix_algo.md | 2 +- examples/mix_math/README.md | 2 +- tests/algorithm/policy_loss_test.py | 94 +++++++++++++++++++ trinity/algorithm/key_mapper.py | 1 + .../policy_loss_fn/mix_policy_loss.py | 20 ++-- .../policy_loss_fn/policy_loss_fn.py | 49 +++++++++- 6 files changed, 152 insertions(+), 16 deletions(-) create mode 100644 tests/algorithm/policy_loss_test.py diff --git a/docs/sphinx_doc/source/tutorial/example_mix_algo.md b/docs/sphinx_doc/source/tutorial/example_mix_algo.md index 72c75d6509..ee0010ba24 100644 --- a/docs/sphinx_doc/source/tutorial/example_mix_algo.md +++ b/docs/sphinx_doc/source/tutorial/example_mix_algo.md @@ -6,7 +6,7 @@ As an example, we incorporate some "expert" data generated by a more advanced LL $$ \mathcal{J}_{\text{Mix}}(\theta) = -\mathcal{J}_{\text{GRPO}}(\theta) +(1-\mu) \mathcal{J}_{\text{GRPO}}(\theta) + \mu \cdot \underbrace{\frac{1}{B'} \sum_{b=1}^{B'} \left[ diff --git a/examples/mix_math/README.md b/examples/mix_math/README.md index 8e84f233bc..2ef160b0f2 100644 --- a/examples/mix_math/README.md +++ b/examples/mix_math/README.md @@ -4,4 +4,4 @@ This example shows the usage of a new algorithm MIX on the MATH dataset. For more detailed information, please refer to the [documentation](../../docs/sphinx_doc/source/tutorial/example_mix_algo.md). -The config files are located in [`mix_math.yaml`](mix.yaml) and [`train_mix_math.yaml`](train_mix_math.yaml). +The config files are located in [`mix_math.yaml`](mix_math.yaml) and [`train_mix_math.yaml`](train_mix_math.yaml). 
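[Illustrative note] The docs change above rescales the on-policy term so that the two losses are explicitly traded off by `mu`. Writing the expert SFT-style term from the tutorial as J_Expert (an abbreviation used here for brevity, not the tutorial's own notation), the corrected objective is schematically a convex combination:

$$
\mathcal{J}_{\text{Mix}}(\theta) = (1-\mu)\,\mathcal{J}_{\text{GRPO}}(\theta) + \mu\,\mathcal{J}_{\text{Expert}}(\theta)
$$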
diff --git a/tests/algorithm/policy_loss_test.py b/tests/algorithm/policy_loss_test.py new file mode 100644 index 0000000000..ba88feb2d7 --- /dev/null +++ b/tests/algorithm/policy_loss_test.py @@ -0,0 +1,94 @@ +# -*- coding: utf-8 -*- +"""Test for policy loss functions""" + +import unittest + +import torch +from verl import DataProto + +from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN + + +class VerlPolicyLossTest(unittest.TestCase): + def setUp(self): + seed = 42 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + shape = (5, 20) + self.logprob = 2 * torch.rand(shape) - 1 + self.input_data = DataProto.from_dict( + { + "old_log_probs": 2 * torch.rand(shape) - 1, + "ref_log_prob": 2 * torch.rand(shape) - 1, + "response_mask": torch.rand(shape) > 0.5, + "advantages": 2 * torch.rand(shape) - 1, + "is_expert_mask": torch.rand(shape[0]) > 0.5, + } + ) + + def test_ppo_policy_loss(self): + policy_loss_fn_cls = POLICY_LOSS_FN.get("ppo") + policy_loss_fn_args = policy_loss_fn_cls.default_args() + policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args) + loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch) + ppo_loss = torch.tensor(0.28560468554496765) + pg_clipfrac = torch.tensor(0.3541666567325592) + ppo_kl = torch.tensor(-0.21663446724414825) + self.assertTrue(torch.allclose(loss, ppo_loss)) + self.assertTrue(torch.allclose(torch.tensor(metrics["pg_clipfrac"]), pg_clipfrac)) + self.assertTrue(torch.allclose(torch.tensor(metrics["ppo_kl"]), ppo_kl)) + self.assertTrue(torch.allclose(torch.tensor(metrics["pg_loss"]), ppo_loss)) + + def test_sft_policy_loss(self): + policy_loss_fn_cls = POLICY_LOSS_FN.get("sft") + policy_loss_fn_args = policy_loss_fn_cls.default_args() + policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args) + loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch) + sft_loss = torch.tensor(-0.07560186833143234) + self.assertTrue(torch.allclose(loss, sft_loss)) + self.assertTrue(torch.allclose(torch.tensor(metrics["sft_loss"]), sft_loss)) + + def test_dpo_policy_loss(self): + policy_loss_fn_cls = POLICY_LOSS_FN.get("dpo") + policy_loss_fn_args = policy_loss_fn_cls.default_args() + policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args) + loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch) + dpo_loss = torch.tensor(0.5406752228736877) + chosen_reward = torch.tensor(0.7082431316375732) + rejected_reward = torch.tensor(0.3757950782775879) + accuracy_mean = torch.tensor(1.0) + self.assertTrue(torch.allclose(loss, dpo_loss)) + self.assertTrue(torch.allclose(torch.tensor(metrics["chosen_reward"]), chosen_reward)) + self.assertTrue(torch.allclose(torch.tensor(metrics["rejected_reward"]), rejected_reward)) + self.assertTrue(torch.allclose(torch.tensor(metrics["accuracy_mean"]), accuracy_mean)) + self.assertTrue(torch.allclose(torch.tensor(metrics["dpo_loss"]), dpo_loss)) + + def test_opmd_policy_loss(self): + policy_loss_fn_cls = POLICY_LOSS_FN.get("opmd") + policy_loss_fn_args = policy_loss_fn_cls.default_args() + policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args) + loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch) + opmd_loss = torch.tensor(-0.009589947760105133) + self.assertTrue(torch.allclose(loss, opmd_loss)) + self.assertTrue(torch.allclose(torch.tensor(metrics["opmd_loss"]), opmd_loss)) + + def 
test_mix_policy_loss(self): + policy_loss_fn_cls = POLICY_LOSS_FN.get("mix") + policy_loss_fn_args = policy_loss_fn_cls.default_args() + policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args) + loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch) + mix_loss = torch.tensor(0.6581965088844299) + pg_clipfrac = torch.tensor(0.7777777910232544) + ppo_kl = torch.tensor(-1.0737695693969727) + pg_loss = torch.tensor(0.7236452102661133) + sft_loss = torch.tensor(0.06915830634534359) + self.assertTrue(torch.allclose(loss, mix_loss)) + self.assertTrue(torch.allclose(torch.tensor(metrics["usual/pg_clipfrac"]), pg_clipfrac)) + self.assertTrue(torch.allclose(torch.tensor(metrics["usual/ppo_kl"]), ppo_kl)) + self.assertTrue(torch.allclose(torch.tensor(metrics["usual/pg_loss"]), pg_loss)) + self.assertTrue(torch.allclose(torch.tensor(metrics["expert/sft_loss"]), sft_loss)) + self.assertTrue(torch.allclose(torch.tensor(metrics["loss"]), mix_loss)) diff --git a/trinity/algorithm/key_mapper.py b/trinity/algorithm/key_mapper.py index 1c332fb3a4..09c1f988a6 100644 --- a/trinity/algorithm/key_mapper.py +++ b/trinity/algorithm/key_mapper.py @@ -19,6 +19,7 @@ def from_trinity(self, key: str) -> str: ALL_MAPPERS = { "verl": KeyMapper( { + "log_prob": "logprob", "old_log_probs": "old_logprob", "ref_log_prob": "ref_logprob", "response_mask": "action_mask", diff --git a/trinity/algorithm/policy_loss_fn/mix_policy_loss.py b/trinity/algorithm/policy_loss_fn/mix_policy_loss.py index 1b9446d59a..76c89c42d9 100644 --- a/trinity/algorithm/policy_loss_fn/mix_policy_loss.py +++ b/trinity/algorithm/policy_loss_fn/mix_policy_loss.py @@ -32,23 +32,23 @@ def __init__( clip_range_low: Optional[float] = None, clip_range_high: Optional[float] = None, use_dynamic_bsz: Optional[bool] = None, - repeat_times: Optional[int] = None, - ppo_mini_batch_size: Optional[int] = None, - ppo_micro_batch_size_per_gpu: Optional[int] = None, - ngpus_trainer: Optional[int] = None, - read_batch_size_usual: Optional[int] = None, - read_batch_size_expert: Optional[int] = None, + repeat_times: int = 1, + ppo_mini_batch_size: int = 1, + ppo_micro_batch_size_per_gpu: int = 1, + ngpus_trainer: int = 1, + read_batch_size_usual: int = 1, + read_batch_size_expert: int = 1, use_token_level_loss_in_sft: bool = True, ) -> None: super().__init__(backend=backend) self.mu = mu self.use_dynamic_bsz = use_dynamic_bsz - self.experience_per_gpu = ppo_mini_batch_size * repeat_times // ngpus_trainer # type: ignore + self.experience_per_gpu = ppo_mini_batch_size * repeat_times // ngpus_trainer self.gradient_accumulation = ( - ppo_mini_batch_size * repeat_times // ppo_micro_batch_size_per_gpu # type: ignore + ppo_mini_batch_size * repeat_times // ppo_micro_batch_size_per_gpu ) - self.read_batch_size_usual = read_batch_size_usual - self.read_batch_size_expert = read_batch_size_expert + self.read_batch_size_usual = read_batch_size_usual // ngpus_trainer + self.read_batch_size_expert = read_batch_size_expert // ngpus_trainer self.grpo_loss_fn = PPOPolicyLossFn( clip_range=clip_range, clip_range_low=clip_range_low, diff --git a/trinity/algorithm/policy_loss_fn/policy_loss_fn.py b/trinity/algorithm/policy_loss_fn/policy_loss_fn.py index 285004e1dd..c8ca241cd8 100644 --- a/trinity/algorithm/policy_loss_fn/policy_loss_fn.py +++ b/trinity/algorithm/policy_loss_fn/policy_loss_fn.py @@ -11,33 +11,59 @@ class PolicyLossFnMeta(ABCMeta): - """Meta class for policy loss function.""" + """Metaclass for policy loss functions that handles parameter name mapping 
and filtering.""" - ignore_keys = {"self", "kwargs", "logprob"} + ignore_keys = {"self", "kwargs"} # Keys to exclude from parameter selection def __new__(cls, name, bases, dct): + """ + Metaclass constructor that automatically generates parameter handling logic. + + For example with `PPOPolicyLossFn` class: + .. code-block:: python + class PPOPolicyLossFn(PolicyLossFn): + ... + def __call__( + self, + logprob: torch.Tensor, + old_logprob: torch.Tensor, + action_mask: torch.Tensor, + advantages: torch.Tensor, + **kwargs, + ) -> Tuple[torch.Tensor, Dict]: + ... + + This metaclass analyzes the __call__ method's parameters to: + 1. Generate _select_keys containing all non-ignored parameters + 2. Create select_keys property that maps parameters to trainer-specific names + 3. Apply decorator to automatically convert input parameter names using the mapper + """ signature = inspect.signature(dct["__call__"]) param_names = [ key for key in signature.parameters.keys() if key not in PolicyLossFnMeta.ignore_keys ] dct["_select_keys"] = param_names + # Property to return trainer-specific parameter names def select_keys(self): + """Returns parameter keys mapped to the specific training framework's naming convention.""" keys = [self.mapper.from_trinity(key) for key in self._select_keys] return keys + # Decorator to handle parameter name conversion before calling __call__ def decorator(func): def wrapper(self, *args, **kwargs): + """Filters and converts parameter names according to the training framework's convention.""" new_kwargs = {} for key, value in kwargs.items(): key = self.mapper.to_trinity(key) if key in self._select_keys: # remove unused keys new_kwargs[key] = value - kwargs = new_kwargs return func(self, *args, **new_kwargs) return wrapper + # Add the property and decorated method to the class dct["select_keys"] = property(select_keys) dct["__call__"] = decorator(dct["__call__"]) return super().__new__(cls, name, bases, dct) @@ -45,10 +71,19 @@ def wrapper(self, *args, **kwargs): class PolicyLossFn(ABC, metaclass=PolicyLossFnMeta): """ - Policy Loss Function + Abstract base class for policy loss functions. + + This class provides the interface for implementing different policy gradient loss functions + while handling parameter name mapping between different training frameworks. """ def __init__(self, backend: str = "verl"): + """ + Initialize the policy loss function. + + Args: + backend: The training framework/backend to use (e.g., "verl") + """ self.backend = backend self.mapper = ALL_MAPPERS[self.backend] @@ -59,8 +94,12 @@ def __call__( **kwargs, ) -> Tuple[torch.Tensor, Dict]: """ + Calculate the policy loss. + Args: logprob (`torch.Tensor`): The log probability generated by the policy model. + + Kwargs (optional): old_logprob (`torch.Tensor`): The log probability generated by the reference model. action_mask (`torch.Tensor`): The action mask. advantages (`torch.Tensor`): The advantages. @@ -75,6 +114,8 @@ def __call__( @abstractmethod def default_args(cls) -> Dict: """ + Get default initialization arguments for this loss function. + Returns: `Dict`: The default init arguments for the policy loss function. 
""" From ae96b5f7a4d6330e1f05a5200cd118486c9b711c Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Tue, 17 Jun 2025 19:40:36 +0800 Subject: [PATCH 8/9] bug fix --- examples/mix_math/mix_math.yaml | 1 - tests/common/config_test.py | 1 + trinity/algorithm/policy_loss_fn/opmd_policy_loss.py | 1 - trinity/algorithm/policy_loss_fn/policy_loss_fn.py | 4 ++-- 4 files changed, 3 insertions(+), 4 deletions(-) diff --git a/examples/mix_math/mix_math.yaml b/examples/mix_math/mix_math.yaml index b680db6fb3..b92edd4b25 100644 --- a/examples/mix_math/mix_math.yaml +++ b/examples/mix_math/mix_math.yaml @@ -27,7 +27,6 @@ cluster: buffer: total_epochs: 1 batch_size: 40 - explore_batch_size: 36 max_retry_times: 3 max_retry_interval: 1 explorer_input: diff --git a/tests/common/config_test.py b/tests/common/config_test.py index e1ac0aa7d4..8aa4118a3c 100644 --- a/tests/common/config_test.py +++ b/tests/common/config_test.py @@ -47,6 +47,7 @@ def test_all_examples_are_valid(self): config_path = os.path.join(example_dir, example_name, filename) try: config = load_config(config_path) + config.checkpoint_root_dir = './.cache/' config.check_and_update() except Exception as e: print(f"Error loading config {config_path}: {e}") diff --git a/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py b/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py index 3c29d0ca2d..618301b319 100644 --- a/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py +++ b/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py @@ -17,7 +17,6 @@ def __init__(self, backend: str = "verl", tau: float = 1.0) -> None: def __call__( # type: ignore self, logprob: torch.Tensor, - old_logprob: torch.Tensor, # NOT USED! action_mask: torch.Tensor, advantages: torch.Tensor, **kwargs, diff --git a/trinity/algorithm/policy_loss_fn/policy_loss_fn.py b/trinity/algorithm/policy_loss_fn/policy_loss_fn.py index c8ca241cd8..aa6025252e 100644 --- a/trinity/algorithm/policy_loss_fn/policy_loss_fn.py +++ b/trinity/algorithm/policy_loss_fn/policy_loss_fn.py @@ -13,7 +13,7 @@ class PolicyLossFnMeta(ABCMeta): """Metaclass for policy loss functions that handles parameter name mapping and filtering.""" - ignore_keys = {"self", "kwargs"} # Keys to exclude from parameter selection + ignore_keys = {"self", "kwargs", "logprob"} # Keys to exclude from parameter selection def __new__(cls, name, bases, dct): """ @@ -57,7 +57,7 @@ def wrapper(self, *args, **kwargs): new_kwargs = {} for key, value in kwargs.items(): key = self.mapper.to_trinity(key) - if key in self._select_keys: # remove unused keys + if key == "logprob" or key in self._select_keys: # remove unused keys new_kwargs[key] = value return func(self, *args, **new_kwargs) From 37386d050788bc01b40e7e807b02f42f3b3385d3 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Tue, 17 Jun 2025 20:00:02 +0800 Subject: [PATCH 9/9] doc fix --- tests/common/config_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/common/config_test.py b/tests/common/config_test.py index 8aa4118a3c..da4fd914a0 100644 --- a/tests/common/config_test.py +++ b/tests/common/config_test.py @@ -47,7 +47,7 @@ def test_all_examples_are_valid(self): config_path = os.path.join(example_dir, example_name, filename) try: config = load_config(config_path) - config.checkpoint_root_dir = './.cache/' + config.checkpoint_root_dir = "./.cache/" config.check_and_update() except Exception as e: print(f"Error loading config {config_path}: {e}")