diff --git a/.github/workflows/sphinx-doc.yaml b/.github/workflows/sphinx-doc.yaml
index fc7963bbbf..934d71060b 100644
--- a/.github/workflows/sphinx-doc.yaml
+++ b/.github/workflows/sphinx-doc.yaml
@@ -21,6 +21,10 @@ jobs:
       OS: ${{ matrix.os }}
       PYTHON: '3.10'
     steps:
+      - name: Free up disk space
+        run: |
+          sudo rm -rf /usr/share/dotnet /opt/ghc /usr/local/lib/android
+          docker system prune -af
      - name: Checkout PR branch
        if: github.event_name == 'pull_request'
        uses: actions/checkout@v4
@@ -28,7 +32,6 @@
          repository: ${{ github.event.pull_request.head.repo.full_name }}
          ref: ${{ github.event.pull_request.head.ref }}
          fetch-depth: 0
-
      - name: Checkout main branch
        if: github.event_name != 'pull_request'
        uses: actions/checkout@v4
@@ -41,7 +44,7 @@
          python-version: ${{ matrix.python-version }}
      - name: Install Dependencies
        run: |
-          pip install -q -e .[doc]
+          pip install -e .[doc]
      - id: build
        name: Build Documentation
        run: |
diff --git a/tests/algorithm/policy_loss_test.py b/tests/algorithm/policy_loss_test.py
index 134635c05a..b1484d4f41 100644
--- a/tests/algorithm/policy_loss_test.py
+++ b/tests/algorithm/policy_loss_test.py
@@ -35,13 +35,17 @@ def test_ppo_policy_loss(self):
         policy_loss_fn_args = policy_loss_fn_cls.default_args()
         policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
         loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch)
-        ppo_loss = torch.tensor(0.28560468554496765)
+        ppo_loss = torch.tensor(0.26889559626579285)
         pg_clipfrac = torch.tensor(0.3541666567325592)
         ppo_kl = torch.tensor(-0.21663446724414825)
+        pg_clipfrac_lower = torch.tensor(0.0625)
         self.assertTrue(torch.allclose(loss, ppo_loss))
         self.assertTrue(torch.allclose(torch.tensor(metrics["pg_clipfrac"]), pg_clipfrac))
         self.assertTrue(torch.allclose(torch.tensor(metrics["ppo_kl"]), ppo_kl))
         self.assertTrue(torch.allclose(torch.tensor(metrics["pg_loss"]), ppo_loss))
+        self.assertTrue(
+            torch.allclose(torch.tensor(metrics["pg_clipfrac_lower"]), pg_clipfrac_lower)
+        )

     def test_gspo_policy_loss(self):
         policy_loss_fn_cls = POLICY_LOSS_FN.get("gspo")
@@ -52,7 +56,6 @@ def test_gspo_policy_loss(self):
         pg_clipfrac_expected = torch.tensor(0.375)
         ppo_kl_seq_expected = torch.tensor(-0.21027061343193054)
         ppo_kl_expected = torch.tensor(-0.21663446724414825)
-        print(f"{loss.item()=}, {metrics=}")
         self.assertTrue(torch.allclose(loss, gspo_loss_expected))
         self.assertTrue(torch.allclose(torch.tensor(metrics["pg_clipfrac"]), pg_clipfrac_expected))
         self.assertTrue(torch.allclose(torch.tensor(metrics["ppo_kl_seq"]), ppo_kl_seq_expected))
@@ -97,14 +100,18 @@ def test_mix_policy_loss(self):
         policy_loss_fn_args = policy_loss_fn_cls.default_args()
         policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
         loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch)
-        mix_loss = torch.tensor(0.6581965088844299)
+        mix_loss = torch.tensor(0.6298247575759888)
         pg_clipfrac = torch.tensor(0.7777777910232544)
         ppo_kl = torch.tensor(-1.0737695693969727)
-        pg_loss = torch.tensor(0.7236452102661133)
+        pg_loss = torch.tensor(0.6921210885047913)
         sft_loss = torch.tensor(0.06915830634534359)
+        pg_clipfrac_lower = torch.tensor(0.2222222238779068)
         self.assertTrue(torch.allclose(loss, mix_loss))
         self.assertTrue(torch.allclose(torch.tensor(metrics["usual/pg_clipfrac"]), pg_clipfrac))
         self.assertTrue(torch.allclose(torch.tensor(metrics["usual/ppo_kl"]), ppo_kl))
         self.assertTrue(torch.allclose(torch.tensor(metrics["usual/pg_loss"]), pg_loss))
+        self.assertTrue(
+            torch.allclose(torch.tensor(metrics["usual/pg_clipfrac_lower"]), pg_clipfrac_lower)
+        )
         self.assertTrue(torch.allclose(torch.tensor(metrics["expert/sft_loss"]), sft_loss))
         self.assertTrue(torch.allclose(torch.tensor(metrics["loss"]), mix_loss))
diff --git a/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py b/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py
index 7a81fe96bf..d5bf55dc1b 100644
--- a/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py
+++ b/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py
@@ -3,7 +3,7 @@

 import torch

-from trinity.algorithm.utils import masked_mean
+from trinity.algorithm.utils import aggregate_loss
 from trinity.utils.registry import Registry

 ENTROPY_LOSS_FN = Registry("entropy_loss_fn")
@@ -53,9 +53,10 @@ def __call__(
         self,
         entropy: torch.Tensor,
         action_mask: torch.Tensor,
+        loss_agg_mode: str = "token-mean",
         **kwargs,
     ) -> Tuple[torch.Tensor, Dict]:
-        entropy_loss = masked_mean(entropy, action_mask)
+        entropy_loss = aggregate_loss(entropy, action_mask, loss_agg_mode=loss_agg_mode)
         return entropy_loss * self.entropy_coef, {"entropy_loss": entropy_loss.detach().item()}


@@ -73,6 +74,7 @@ def __call__(
         entropy: torch.Tensor,
         action_mask: torch.Tensor,
         expert_mask: torch.Tensor = None,
+        loss_agg_mode: str = "token-mean",
         **kwargs,
     ) -> Tuple[torch.Tensor, Dict]:
         if expert_mask is None:
@@ -82,7 +84,7 @@
             ), f"Error: {len(expert_mask)=} != {entropy.shape[0]=}"
             entropy = entropy[~expert_mask]
             action_mask = action_mask[~expert_mask]
-        entropy_loss = masked_mean(entropy, action_mask)
+        entropy_loss = aggregate_loss(entropy, action_mask, loss_agg_mode=loss_agg_mode)
         return entropy_loss * self.entropy_coef, {"entropy_loss": entropy_loss.detach().item()}
diff --git a/trinity/algorithm/kl_fn/kl_fn.py b/trinity/algorithm/kl_fn/kl_fn.py
index 62ed48cd49..49c59f2367 100644
--- a/trinity/algorithm/kl_fn/kl_fn.py
+++ b/trinity/algorithm/kl_fn/kl_fn.py
@@ -11,7 +11,7 @@

 import torch

-from trinity.algorithm.utils import masked_mean
+from trinity.algorithm.utils import aggregate_loss, masked_mean
 from trinity.utils.registry import Registry

 KL_FN = Registry("kl_fn")
@@ -81,10 +81,11 @@ def calculate_kl_loss(
         logprob: torch.Tensor,
         ref_logprob: torch.Tensor,
         response_mask: torch.Tensor,
+        loss_agg_mode: str,
     ) -> Tuple[torch.Tensor, Dict]:
         """Compute KL loss."""
         kl = self.calculate_kl(logprob, ref_logprob)
-        kl_loss = masked_mean(kl, response_mask)
+        kl_loss = aggregate_loss(kl, response_mask, loss_agg_mode=loss_agg_mode)
         metrics = {
             "kl_loss": kl_loss.detach().item(),
             "kl_coef": self.kl_coef,
@@ -119,6 +120,7 @@ def calculate_kl_loss(
         logprob: torch.Tensor,
         ref_logprob: torch.Tensor,
         response_mask: torch.Tensor,
+        loss_agg_mode: str,
     ) -> Tuple[torch.Tensor, Dict]:
         # return a zero tensor
         return torch.tensor(0.0), {}
@@ -155,6 +157,20 @@ def calculate_kl(self, logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
         return logr.exp() - 1 - logr


+@KL_FN.register_module("low_var_kl")
+class LowVarKLFn(KLFn):
+    """
+    Low Variance KL function.
+    """
+
+    def calculate_kl(self, logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
+        kl = ref_logprob - logprob
+        kl = torch.clamp(kl, min=-20, max=20)
+        ratio = torch.exp(kl)
+        kld = (ratio - kl - 1).contiguous()
+        return torch.clamp(kld, min=-10, max=10)
+
+
 @KL_FN.register_module("abs")
 class AbsFn(KLFn):
     """
diff --git a/trinity/algorithm/policy_loss_fn/chord_policy_loss.py b/trinity/algorithm/policy_loss_fn/chord_policy_loss.py
index c3b0d007cb..1a6a893041 100644
--- a/trinity/algorithm/policy_loss_fn/chord_policy_loss.py
+++ b/trinity/algorithm/policy_loss_fn/chord_policy_loss.py
@@ -8,7 +8,7 @@
 from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
 from trinity.algorithm.policy_loss_fn.ppo_policy_loss import PPOPolicyLossFn
 from trinity.algorithm.policy_loss_fn.sft_loss import SFTLossFn
-from trinity.algorithm.utils import masked_loss
+from trinity.algorithm.utils import aggregate_loss


 def mu_schedule_function(
@@ -48,7 +48,7 @@ def __call__(  # type: ignore
         **kwargs,
     ) -> Tuple[torch.Tensor, Dict]:
         token_prob = torch.exp(logprob)
-        sft_loss = masked_loss(
+        sft_loss = aggregate_loss(
             -logprob * token_prob.detach(), action_mask, loss_agg_mode=self.loss_agg_mode
         )
         return sft_loss, {"sft_is_loss": sft_loss.detach().item()}
@@ -94,7 +94,7 @@ def __call__(  # type: ignore

         weighted_phi = phi_function(token_prob)

-        sft_loss = masked_loss(
+        sft_loss = aggregate_loss(
             -logprob * weighted_phi.detach(), action_mask, loss_agg_mode=self.loss_agg_mode
         )
         return sft_loss, {"sft_phi_loss": sft_loss.detach().item()}
@@ -141,8 +141,9 @@ def __init__(
         ngpus_trainer: int = 1,
         train_batch_size_usual: int = 1,
         train_batch_size_expert: int = 1,
-        sft_loss_agg_mode: str = "token-mean",
-        grpo_loss_agg_mode: str = "token-mean",
+        loss_agg_mode: str = "token-mean",
+        sft_loss_agg_mode: Optional[str] = None,
+        grpo_loss_agg_mode: Optional[str] = None,
     ) -> None:
         super().__init__(backend=backend)
         self.mu_warmup_steps = mu_warmup_steps
@@ -159,12 +160,12 @@ def __init__(
             clip_range=clip_range,
             clip_range_low=clip_range_low,
             clip_range_high=clip_range_high,
-            loss_agg_mode=grpo_loss_agg_mode,
+            loss_agg_mode=grpo_loss_agg_mode or loss_agg_mode,
         )
         if enable_phi_function:
-            self.sft_loss_fn = SFTPhiLossFn(loss_agg_mode=sft_loss_agg_mode)
+            self.sft_loss_fn = SFTPhiLossFn(loss_agg_mode=sft_loss_agg_mode or loss_agg_mode)
         else:
-            self.sft_loss_fn = SFTLossFn(loss_agg_mode=sft_loss_agg_mode)
+            self.sft_loss_fn = SFTLossFn(loss_agg_mode=sft_loss_agg_mode or loss_agg_mode)

     def __call__(  # type: ignore
         self,
@@ -255,4 +256,5 @@ def default_args(cls) -> Dict:
             "mu_valley": 0.1,
             "clip_range": 0.2,
             "enable_phi_function": True,
+            "loss_agg_mode": "token-mean",
         }
diff --git a/trinity/algorithm/policy_loss_fn/cispo_policy_loss.py b/trinity/algorithm/policy_loss_fn/cispo_policy_loss.py
index 7b9c16d203..6da3d07526 100644
--- a/trinity/algorithm/policy_loss_fn/cispo_policy_loss.py
+++ b/trinity/algorithm/policy_loss_fn/cispo_policy_loss.py
@@ -7,7 +7,7 @@
 import torch

 from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
-from trinity.algorithm.utils import masked_loss, masked_mean
+from trinity.algorithm.utils import aggregate_loss, masked_mean


 @POLICY_LOSS_FN.register_module("cispo")
@@ -63,7 +63,7 @@ def __call__(  # type: ignore

         cispo_loss = -advantages * ratio_clamped.detach() * mask.detach() * logprob

-        loss = masked_loss(cispo_loss, action_mask, loss_agg_mode=self.loss_agg_mode)
+        loss = aggregate_loss(cispo_loss, action_mask, loss_agg_mode=self.loss_agg_mode)

         unmasked_frac = masked_mean(mask, action_mask)

         metrics = {
diff --git a/trinity/algorithm/policy_loss_fn/gspo_policy_loss.py b/trinity/algorithm/policy_loss_fn/gspo_policy_loss.py
index 6ebd5af243..14e76dc02b 100644
--- a/trinity/algorithm/policy_loss_fn/gspo_policy_loss.py
+++ b/trinity/algorithm/policy_loss_fn/gspo_policy_loss.py
@@ -8,7 +8,7 @@
 import torch

 from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
-from trinity.algorithm.utils import masked_loss, masked_mean
+from trinity.algorithm.utils import aggregate_loss, masked_mean


 @POLICY_LOSS_FN.register_module("gspo")
@@ -54,7 +54,7 @@ def __call__(  # type: ignore
             ratio, 1.0 - self.clip_range_low, 1.0 + self.clip_range_high
         )  # [batch_size, seq_len]

-        pg_loss = masked_loss(
+        pg_loss = aggregate_loss(
             values=torch.max(pg_losses, pg_losses_clipped),
             mask=action_mask,
             loss_agg_mode=self.loss_agg_mode,
diff --git a/trinity/algorithm/policy_loss_fn/mix_policy_loss.py b/trinity/algorithm/policy_loss_fn/mix_policy_loss.py
index 7c0e051371..ae3e3ffb84 100644
--- a/trinity/algorithm/policy_loss_fn/mix_policy_loss.py
+++ b/trinity/algorithm/policy_loss_fn/mix_policy_loss.py
@@ -37,8 +37,9 @@ def __init__(
         ngpus_trainer: int = 1,
         train_batch_size_usual: int = 1,
         train_batch_size_expert: int = 1,
-        sft_loss_agg_mode: str = "token-mean",
-        grpo_loss_agg_mode: str = "token-mean",
+        loss_agg_mode: str = "token-mean",
+        sft_loss_agg_mode: Optional[str] = None,
+        grpo_loss_agg_mode: Optional[str] = None,
     ) -> None:
         super().__init__(backend=backend)
         self.mu = mu
@@ -51,9 +52,9 @@ def __init__(
             clip_range=clip_range,
             clip_range_low=clip_range_low,
             clip_range_high=clip_range_high,
-            loss_agg_mode=grpo_loss_agg_mode,
+            loss_agg_mode=grpo_loss_agg_mode or loss_agg_mode,
         )
-        self.sft_loss_fn = SFTLossFn(loss_agg_mode=sft_loss_agg_mode)
+        self.sft_loss_fn = SFTLossFn(loss_agg_mode=sft_loss_agg_mode or loss_agg_mode)

     def __call__(  # type: ignore
         self,
@@ -125,4 +126,5 @@ def default_args(cls) -> Dict:
         return {
             "mu": 0.1,
             "clip_range": 0.2,
+            "loss_agg_mode": "token-mean",
         }
diff --git a/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py b/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py
index d3246d71f8..b83a960bba 100644
--- a/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py
+++ b/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py
@@ -5,7 +5,7 @@
 import torch

 from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
-from trinity.algorithm.utils import masked_loss
+from trinity.algorithm.utils import aggregate_loss


 @POLICY_LOSS_FN.register_module("opmd")
@@ -25,7 +25,7 @@ def __call__(  # type: ignore
         **kwargs,
     ) -> Tuple[torch.Tensor, Dict]:
         pg_losses = -advantages * logprob
-        opmd_loss = masked_loss(pg_losses, action_mask, loss_agg_mode=self.loss_agg_mode)
+        opmd_loss = aggregate_loss(pg_losses, action_mask, loss_agg_mode=self.loss_agg_mode)
         opmd_loss = opmd_loss / (1.0 + self.tau)  # for regularization (w.r.t. current pi_theta)
         return opmd_loss, {"opmd_loss": opmd_loss.detach().item()}

diff --git a/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py b/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py
index 9c9bbaf2a5..f2c812a0b5 100644
--- a/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py
+++ b/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py
@@ -8,7 +8,7 @@
 import torch

 from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
-from trinity.algorithm.utils import masked_loss, masked_mean
+from trinity.algorithm.utils import aggregate_loss, masked_mean


 @POLICY_LOSS_FN.register_module("ppo")
@@ -19,6 +19,7 @@ def __init__(
         clip_range: Optional[float] = None,
         clip_range_low: Optional[float] = None,
         clip_range_high: Optional[float] = None,
+        clip_ratio_c: float = 3.0,
         loss_agg_mode: Optional[str] = "token-mean",
     ) -> None:
         super().__init__(backend=backend)
@@ -30,6 +31,8 @@ def __init__(
             self.clip_range_high = clip_range
         else:
             self.clip_range_high = clip_range_high
+        self.clip_ratio_c = clip_ratio_c
+        assert clip_ratio_c > 1.0, "clip_ratio_c must be greater than 1.0."
         assert self.clip_range_low is not None, "clip_range_low must be specified."
         assert self.clip_range_high is not None, "clip_range_high must be specified."
         self.loss_agg_mode = loss_agg_mode
@@ -43,20 +46,30 @@ def __call__(  # type: ignore
         **kwargs,
     ) -> Tuple[torch.Tensor, Dict]:
         negative_approx_kl = logprob - old_logprob
+        # Clamp negative_approx_kl for stability
+        negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)
         ratio = torch.exp(negative_approx_kl)
         ppo_kl = masked_mean(-negative_approx_kl, action_mask)

-        pg_losses = -advantages * ratio
+        pg_losses1 = -advantages * ratio
         pg_losses2 = -advantages * torch.clamp(
             ratio, 1.0 - self.clip_range_low, 1.0 + self.clip_range_high  # type: ignore
         )
-        pg_loss = masked_loss(
-            torch.max(pg_losses, pg_losses2), action_mask, loss_agg_mode=self.loss_agg_mode
+        clip_pg_losses1 = torch.maximum(pg_losses1, pg_losses2)
+
+        pg_clip_frac = masked_mean(torch.gt(pg_losses2, pg_losses1).float(), action_mask)
+
+        pg_losses3 = -advantages * self.clip_ratio_c
+        clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
+        pg_clipfrac_lower = masked_mean(
+            torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), action_mask
         )
-        pg_clipfrac = masked_mean(torch.gt(pg_losses2, pg_losses).float(), action_mask)
+        pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
+        pg_loss = aggregate_loss(pg_losses, action_mask, loss_agg_mode=self.loss_agg_mode)

         metrics = {
-            "pg_clipfrac": pg_clipfrac.detach().item(),
+            "pg_clipfrac": pg_clip_frac.detach().item(),
+            "pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
             "ppo_kl": ppo_kl.detach().item(),
             "pg_loss": pg_loss.detach().item(),
         }
@@ -66,5 +79,6 @@ def default_args(cls) -> Dict:
         return {
             "clip_range": 0.2,
+            "clip_ratio_c": 3.0,
             "loss_agg_mode": "token-mean",
         }
diff --git a/trinity/algorithm/policy_loss_fn/sft_loss.py b/trinity/algorithm/policy_loss_fn/sft_loss.py
index f7e1daddf4..bd81d15380 100644
--- a/trinity/algorithm/policy_loss_fn/sft_loss.py
+++ b/trinity/algorithm/policy_loss_fn/sft_loss.py
@@ -5,7 +5,7 @@
 import torch

 from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
-from trinity.algorithm.utils import masked_loss
+from trinity.algorithm.utils import aggregate_loss


 @POLICY_LOSS_FN.register_module("sft")
@@ -20,7 +20,7 @@ def __call__(  # type: ignore
         action_mask: torch.Tensor,
         **kwargs,
     ) -> Tuple[torch.Tensor, Dict]:
-        sft_loss = masked_loss(-logprob, action_mask, loss_agg_mode=self.loss_agg_mode)
+        sft_loss = aggregate_loss(-logprob, action_mask, loss_agg_mode=self.loss_agg_mode)
         return sft_loss, {"sft_loss": sft_loss.detach().item()}

diff --git a/trinity/algorithm/policy_loss_fn/sppo_loss_fn.py b/trinity/algorithm/policy_loss_fn/sppo_loss_fn.py
index 647bd27d96..068a201c26 100644
--- a/trinity/algorithm/policy_loss_fn/sppo_loss_fn.py
+++ b/trinity/algorithm/policy_loss_fn/sppo_loss_fn.py
@@ -7,7 +7,7 @@
 import torch

 from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
-from trinity.algorithm.utils import masked_loss, masked_mean
+from trinity.algorithm.utils import aggregate_loss, masked_mean


 @POLICY_LOSS_FN.register_module("sppo")
@@ -41,7 +41,7 @@ def __call__(  # type: ignore
         is_in_range = (ratio >= (1 / (1 + self.epsilon))) * (ratio <= (1 + self.epsilon))
         is_clipped_mask = ~is_in_range
         pg_losses = -advantages * (logprob - old_logprob) * is_in_range.float()
-        pg_loss = masked_loss(pg_losses, action_mask, loss_agg_mode=self.loss_agg_mode)
+        pg_loss = aggregate_loss(pg_losses, action_mask, loss_agg_mode=self.loss_agg_mode)
         pg_clipfrac = masked_mean(is_clipped_mask.float(), action_mask)
         metrics = {
             "pg_clipfrac": pg_clipfrac.item(),
diff --git a/trinity/algorithm/policy_loss_fn/topr_policy_loss.py b/trinity/algorithm/policy_loss_fn/topr_policy_loss.py
index 1ae23d57ec..cb6500754a 100644
--- a/trinity/algorithm/policy_loss_fn/topr_policy_loss.py
+++ b/trinity/algorithm/policy_loss_fn/topr_policy_loss.py
@@ -6,7 +6,7 @@
 import torch

 from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
-from trinity.algorithm.utils import masked_loss, masked_mean
+from trinity.algorithm.utils import aggregate_loss, masked_mean


 @POLICY_LOSS_FN.register_module("topr")
@@ -56,7 +56,7 @@ def __call__(  # type: ignore
         topr_loss = -alpha.detach() * rewards * logprob  # detach alpha as it's used with stop-grad

         # Apply masking and compute mean
-        loss = masked_loss(topr_loss, action_mask, loss_agg_mode=self.loss_agg_mode)
+        loss = aggregate_loss(topr_loss, action_mask, loss_agg_mode=self.loss_agg_mode)

         # Average alpha value for monitoring
         avg_alpha = masked_mean(alpha, action_mask)
diff --git a/trinity/algorithm/utils.py b/trinity/algorithm/utils.py
index 271cc00352..707639f4cc 100644
--- a/trinity/algorithm/utils.py
+++ b/trinity/algorithm/utils.py
@@ -6,7 +6,7 @@
 import torch

-def masked_loss(values, mask, loss_agg_mode="token-mean", normalizer=None):
+def aggregate_loss(values, mask, loss_agg_mode="token-mean", normalizer=None):
     """
     Compute loss from values and mask with various aggregation modes.

     Modified from: https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py
@@ -51,12 +51,13 @@ def masked_loss(values, mask, loss_agg_mode="token-mean", normalizer=None):

 def masked_sum(values, mask, axis=None):
     """Compute mean of tensor with a masked values."""
-    return (values * mask).sum(axis=axis)
+    valid_values = torch.where(mask.bool(), values, 0.0)
+    return (valid_values * mask).sum(axis=axis)


 def masked_mean(values, mask, axis=None):
     """Compute mean of tensor with a masked values."""
-    return (values * mask).sum(axis=axis) / (mask.sum(axis=axis) + 1e-8)
+    return masked_sum(values, mask, axis=axis) / (mask.sum(axis=axis) + 1e-8)


 def masked_var(values, mask, unbiased=True):
diff --git a/trinity/common/config.py b/trinity/common/config.py
index c722959b96..7d37f74bd1 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -533,6 +533,10 @@ class AlgorithmConfig:
     # If not set, use entropy_loss_fn.default_args()
     entropy_loss_fn_args: Optional[dict] = None

+    # aggregation mode for losses: 'token-mean' or 'seq-mean-token-sum' or 'seq-mean-token-mean' or 'seq-mean-token-sum-norm'
+    # If not set, use 'token-mean'
+    loss_agg_mode: Optional[str] = None
+

 @dataclass
 class ClusterConfig:
@@ -1042,6 +1046,7 @@ def _check_algorithm(self) -> None:
             "kl_penalty_fn": "none",
             "kl_loss_fn": "k2",
             "entropy_loss_fn": "default",
+            "loss_agg_mode": "token-mean",
         }
         default_config.update(algorithm.default_config())
         for key, value in default_config.items():
@@ -1060,6 +1065,9 @@ def check_and_set(name, registry, args_attr):
         check_and_set("kl_loss_fn", KL_FN, "kl_loss_fn_args")
         check_and_set("kl_penalty_fn", KL_FN, "kl_penalty_fn_args")
         check_and_set("entropy_loss_fn", ENTROPY_LOSS_FN, "entropy_loss_fn_args")
+        if "loss_agg_mode" in self.algorithm.policy_loss_fn_args:  # type: ignore [operator]
+            # override loss_agg_mode in policy_loss_fn_args
+            self.algorithm.policy_loss_fn_args["loss_agg_mode"] = self.algorithm.loss_agg_mode  # type: ignore [index]

     def _check_model(self) -> None:
         model = self.model
diff --git a/trinity/trainer/verl/dp_actor.py b/trinity/trainer/verl/dp_actor.py
index 5b9a5256c7..0d571ed497 100644
--- a/trinity/trainer/verl/dp_actor.py
+++ b/trinity/trainer/verl/dp_actor.py
@@ -53,6 +53,7 @@ def __init__(
         self.entropy_loss_fn = None

     def set_algorithm(self, algorithm_config: AlgorithmConfig):
+        self.loss_agg_mode = algorithm_config.loss_agg_mode
         self.policy_loss_fn = POLICY_LOSS_FN.get(algorithm_config.policy_loss_fn)(
             backend="verl", **algorithm_config.policy_loss_fn_args
         )
@@ -89,13 +90,8 @@ def update_policy(self, data: DataProto):  # noqa: C901
         mini_batches = data.split(self.config.ppo_mini_batch_size)

         # EXPERIMENTAL: apply loss scale fix
-        loss_agg_mode = (
-            self.policy_loss_fn.loss_agg_mode
-            if hasattr(self.policy_loss_fn, "loss_agg_mode")
-            else "token-mean"
-        )
         do_fix_actor_microbatch_loss_scale = self.config.fix_actor_microbatch_loss_scale and (
-            loss_agg_mode == "token-mean"
+            self.loss_agg_mode == "token-mean"
         )

         metrics = {}
@@ -149,6 +145,7 @@ def update_policy(self, data: DataProto):  # noqa: C901
                     entropy_loss, entropy_loss_metrics = self.entropy_loss_fn(  # type: ignore
                         entropy=entropy,
                         action_mask=response_mask,
+                        loss_agg_mode=self.loss_agg_mode,
                         **model_inputs,
                     )
                     prefix_metrics(
@@ -164,6 +161,7 @@ def update_policy(self, data: DataProto):  # noqa: C901
                         logprob=log_prob,
                         ref_logprob=model_inputs.get("ref_log_prob", None),
                         response_mask=response_mask,
+                        loss_agg_mode=self.loss_agg_mode,
                     )
                     prefix_metrics(
                         src_metrics=kl_loss_metrics,
@@ -185,6 +183,7 @@ def update_policy(self, data: DataProto):  # noqa: C901
                        loss_scale = torch.sum(response_mask).item() / (mini_batch_token_num + 1e-6)
                        loss = policy_loss * loss_scale

+                    micro_batch_metrics["actor/final_loss"] = loss.detach().item()
                    loss.backward()

                append_to_dict(metrics, micro_batch_metrics)
diff --git a/trinity/trainer/verl/megatron_actor.py b/trinity/trainer/verl/megatron_actor.py
index 4f4633ceca..4e9183609f 100644
--- a/trinity/trainer/verl/megatron_actor.py
+++ b/trinity/trainer/verl/megatron_actor.py
@@ -64,6 +64,7 @@ def __init__(
         self.entropy_loss_fn = None

     def set_algorithm(self, algorithm_config: AlgorithmConfig):
+        self.loss_agg_mode = algorithm_config.loss_agg_mode
         self.policy_loss_fn = POLICY_LOSS_FN.get(algorithm_config.policy_loss_fn)(
             backend="verl", **algorithm_config.policy_loss_fn_args
         )
@@ -187,6 +188,7 @@ def loss_func(output, data, meta_info):
             entropy_loss, entropy_loss_metrics = self.entropy_loss_fn(  # type: ignore
                 entropy=entropy,
                 action_mask=response_mask,
+                loss_agg_mode=self.loss_agg_mode,
                 **data,
             )
             prefix_metrics(
@@ -207,6 +209,7 @@ def loss_func(output, data, meta_info):
                 logprob=log_prob,
                 ref_logprob=data.get("ref_log_prob", None),
                 response_mask=response_mask,
+                loss_agg_mode=self.loss_agg_mode,
             )
             prefix_metrics(
                 src_metrics=kl_loss_metrics,
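The snippet below is a minimal, standalone sketch of the dual-clip PPO behavior added to ppo_policy_loss.py above, written with plain torch only. The function name, argument names, and the inline token-mean aggregation are illustrative assumptions for this sketch, not the repository's API; the real implementation lives in PPOPolicyLossFn and aggregate_loss.

import torch

def dual_clip_ppo_loss_sketch(logprob, old_logprob, advantages, action_mask,
                              clip_range=0.2, clip_ratio_c=3.0):
    # ratio = pi_theta / pi_theta_old, with the log-ratio clamped for numerical stability
    log_ratio = torch.clamp(logprob - old_logprob, min=-20.0, max=20.0)
    ratio = torch.exp(log_ratio)

    # Standard PPO clipping
    pg_losses1 = -advantages * ratio
    pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - clip_range, 1.0 + clip_range)
    clip_pg_losses1 = torch.maximum(pg_losses1, pg_losses2)

    # Dual-clip: for negative advantages, additionally cap the per-token loss at -A * clip_ratio_c
    pg_losses3 = -advantages * clip_ratio_c
    clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
    pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)

    # "token-mean" aggregation over valid (unmasked) tokens
    return (pg_losses * action_mask).sum() / (action_mask.sum() + 1e-8)

if __name__ == "__main__":
    torch.manual_seed(0)
    logprob, old_logprob = torch.randn(2, 4), torch.randn(2, 4)
    advantages, action_mask = torch.randn(2, 4), torch.ones(2, 4)
    print(dual_clip_ppo_loss_sketch(logprob, old_logprob, advantages, action_mask))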