15 changes: 11 additions & 4 deletions tests/algorithm/policy_loss_test.py
@@ -35,13 +35,17 @@ def test_ppo_policy_loss(self):
policy_loss_fn_args = policy_loss_fn_cls.default_args()
policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch)
- ppo_loss = torch.tensor(0.28560468554496765)
+ ppo_loss = torch.tensor(0.26889559626579285)
pg_clipfrac = torch.tensor(0.3541666567325592)
ppo_kl = torch.tensor(-0.21663446724414825)
+ pg_clipfrac_lower = torch.tensor(0.0625)
self.assertTrue(torch.allclose(loss, ppo_loss))
self.assertTrue(torch.allclose(torch.tensor(metrics["pg_clipfrac"]), pg_clipfrac))
self.assertTrue(torch.allclose(torch.tensor(metrics["ppo_kl"]), ppo_kl))
self.assertTrue(torch.allclose(torch.tensor(metrics["pg_loss"]), ppo_loss))
+ self.assertTrue(
+     torch.allclose(torch.tensor(metrics["pg_clipfrac_lower"]), pg_clipfrac_lower)
+ )

def test_gspo_policy_loss(self):
policy_loss_fn_cls = POLICY_LOSS_FN.get("gspo")
@@ -52,7 +56,6 @@ def test_gspo_policy_loss(self):
pg_clipfrac_expected = torch.tensor(0.375)
ppo_kl_seq_expected = torch.tensor(-0.21027061343193054)
ppo_kl_expected = torch.tensor(-0.21663446724414825)
print(f"{loss.item()=}, {metrics=}")
self.assertTrue(torch.allclose(loss, gspo_loss_expected))
self.assertTrue(torch.allclose(torch.tensor(metrics["pg_clipfrac"]), pg_clipfrac_expected))
self.assertTrue(torch.allclose(torch.tensor(metrics["ppo_kl_seq"]), ppo_kl_seq_expected))
@@ -97,14 +100,18 @@ def test_mix_policy_loss(self):
policy_loss_fn_args = policy_loss_fn_cls.default_args()
policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch)
- mix_loss = torch.tensor(0.6581965088844299)
+ mix_loss = torch.tensor(0.6298247575759888)
pg_clipfrac = torch.tensor(0.7777777910232544)
ppo_kl = torch.tensor(-1.0737695693969727)
- pg_loss = torch.tensor(0.7236452102661133)
+ pg_loss = torch.tensor(0.6921210885047913)
sft_loss = torch.tensor(0.06915830634534359)
+ pg_clipfrac_lower = torch.tensor(0.2222222238779068)
self.assertTrue(torch.allclose(loss, mix_loss))
self.assertTrue(torch.allclose(torch.tensor(metrics["usual/pg_clipfrac"]), pg_clipfrac))
self.assertTrue(torch.allclose(torch.tensor(metrics["usual/ppo_kl"]), ppo_kl))
self.assertTrue(torch.allclose(torch.tensor(metrics["usual/pg_loss"]), pg_loss))
+ self.assertTrue(
+     torch.allclose(torch.tensor(metrics["usual/pg_clipfrac_lower"]), pg_clipfrac_lower)
+ )
self.assertTrue(torch.allclose(torch.tensor(metrics["expert/sft_loss"]), sft_loss))
self.assertTrue(torch.allclose(torch.tensor(metrics["loss"]), mix_loss))
8 changes: 5 additions & 3 deletions trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py
@@ -3,7 +3,7 @@

import torch

- from trinity.algorithm.utils import masked_mean
+ from trinity.algorithm.utils import aggregate_loss
from trinity.utils.registry import Registry

ENTROPY_LOSS_FN = Registry("entropy_loss_fn")
@@ -53,9 +53,10 @@ def __call__(
self,
entropy: torch.Tensor,
action_mask: torch.Tensor,
loss_agg_mode: str = "token-mean",
**kwargs,
) -> Tuple[torch.Tensor, Dict]:
- entropy_loss = masked_mean(entropy, action_mask)
+ entropy_loss = aggregate_loss(entropy, action_mask, loss_agg_mode=loss_agg_mode)
return entropy_loss * self.entropy_coef, {"entropy_loss": entropy_loss.detach().item()}


@@ -73,6 +74,7 @@ def __call__(
entropy: torch.Tensor,
action_mask: torch.Tensor,
expert_mask: torch.Tensor = None,
loss_agg_mode: str = "token-mean",
**kwargs,
) -> Tuple[torch.Tensor, Dict]:
if expert_mask is None:
@@ -82,7 +84,7 @@
), f"Error: {len(expert_mask)=} != {entropy.shape[0]=}"
entropy = entropy[~expert_mask]
action_mask = action_mask[~expert_mask]
- entropy_loss = masked_mean(entropy, action_mask)
+ entropy_loss = aggregate_loss(entropy, action_mask, loss_agg_mode=loss_agg_mode)
return entropy_loss * self.entropy_coef, {"entropy_loss": entropy_loss.detach().item()}


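Across this PR, masked_mean / masked_loss calls are replaced by a single aggregate_loss helper, and a loss_agg_mode argument is threaded through each loss function. The helper itself lives in trinity/algorithm/utils.py and is not part of this diff, so the sketch below is only an illustration of the reduction semantics the call sites imply; the mode names other than "token-mean" are assumptions, not the library's confirmed API.

# Hypothetical sketch of the aggregate_loss helper assumed by these call sites.
# Only "token-mean" is confirmed by the diff; the other modes are assumed.
import torch

def aggregate_loss(
    values: torch.Tensor,                 # per-token losses, [batch, seq_len]
    mask: torch.Tensor,                   # action/response mask, [batch, seq_len]
    loss_agg_mode: str = "token-mean",
) -> torch.Tensor:
    if loss_agg_mode == "token-mean":
        # Mean over all valid tokens in the batch (what masked_mean computed).
        return (values * mask).sum() / mask.sum().clamp(min=1)
    if loss_agg_mode == "seq-mean-token-sum":
        # Sum tokens within each sequence, then average over sequences.
        return (values * mask).sum(dim=-1).mean()
    if loss_agg_mode == "seq-mean-token-mean":
        # Average tokens within each sequence, then average over sequences.
        return ((values * mask).sum(dim=-1) / mask.sum(dim=-1).clamp(min=1)).mean()
    raise ValueError(f"Unknown loss_agg_mode: {loss_agg_mode}")

Note that the diffs below keep masked_mean imported where it still computes metrics such as pg_clipfrac and ppo_kl; only the loss reduction itself goes through aggregate_loss.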
20 changes: 18 additions & 2 deletions trinity/algorithm/kl_fn/kl_fn.py
@@ -11,7 +11,7 @@

import torch

- from trinity.algorithm.utils import masked_mean
+ from trinity.algorithm.utils import aggregate_loss, masked_mean
from trinity.utils.registry import Registry

KL_FN = Registry("kl_fn")
@@ -81,10 +81,11 @@ def calculate_kl_loss(
logprob: torch.Tensor,
ref_logprob: torch.Tensor,
response_mask: torch.Tensor,
+ loss_agg_mode: str,
) -> Tuple[torch.Tensor, Dict]:
"""Compute KL loss."""
kl = self.calculate_kl(logprob, ref_logprob)
- kl_loss = masked_mean(kl, response_mask)
+ kl_loss = aggregate_loss(kl, response_mask, loss_agg_mode=loss_agg_mode)
metrics = {
"kl_loss": kl_loss.detach().item(),
"kl_coef": self.kl_coef,
@@ -119,6 +120,7 @@ def calculate_kl_loss(
logprob: torch.Tensor,
ref_logprob: torch.Tensor,
response_mask: torch.Tensor,
+ loss_agg_mode: str,
) -> Tuple[torch.Tensor, Dict]:
# return a zero tensor
return torch.tensor(0.0), {}
@@ -155,6 +157,20 @@ def calculate_kl(self, logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
return logr.exp() - 1 - logr


@KL_FN.register_module("low_var_kl")
class LowVarKLFn(KLFn):
"""
Low Variance KL function.
"""

def calculate_kl(self, logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
kl = ref_logprob - logprob
kl = torch.clamp(kl, min=-20, max=20)
ratio = torch.exp(kl)
kld = (ratio - kl - 1).contiguous()
return torch.clamp(kld, min=-10, max=10)


@KL_FN.register_module("abs")
class AbsFn(KLFn):
"""
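The new "low_var_kl" entry is the standard low-variance ("k3") KL estimator: with the per-token log-ratio kl = ref_logprob - logprob, the returned quantity exp(kl) - kl - 1 is non-negative, and its expectation over tokens sampled from the current policy equals KL(pi_theta || pi_ref) (ignoring the clamps at +/-20 and +/-10, which only guard against overflow on outlier tokens). A quick numerical check on a toy categorical distribution, not part of the PR:

import torch

torch.manual_seed(0)
pi_theta = torch.tensor([0.5, 0.3, 0.2])   # current policy
pi_ref = torch.tensor([0.4, 0.4, 0.2])     # reference policy

# Sample tokens from pi_theta, as the trainer effectively does.
samples = torch.multinomial(pi_theta, num_samples=200_000, replacement=True)
logprob = torch.log(pi_theta)[samples]
ref_logprob = torch.log(pi_ref)[samples]

# Same formula as LowVarKLFn.calculate_kl.
kl = torch.clamp(ref_logprob - logprob, min=-20, max=20)
k3 = torch.exp(kl) - kl - 1

exact_kl = (pi_theta * (pi_theta / pi_ref).log()).sum()
print(f"k3 estimate: {k3.mean().item():.4f}, exact KL: {exact_kl.item():.4f}")  # both ~0.025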
18 changes: 10 additions & 8 deletions trinity/algorithm/policy_loss_fn/chord_policy_loss.py
@@ -8,7 +8,7 @@
from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
from trinity.algorithm.policy_loss_fn.ppo_policy_loss import PPOPolicyLossFn
from trinity.algorithm.policy_loss_fn.sft_loss import SFTLossFn
- from trinity.algorithm.utils import masked_loss
+ from trinity.algorithm.utils import aggregate_loss


def mu_schedule_function(
@@ -48,7 +48,7 @@ def __call__( # type: ignore
**kwargs,
) -> Tuple[torch.Tensor, Dict]:
token_prob = torch.exp(logprob)
- sft_loss = masked_loss(
+ sft_loss = aggregate_loss(
-logprob * token_prob.detach(), action_mask, loss_agg_mode=self.loss_agg_mode
)
return sft_loss, {"sft_is_loss": sft_loss.detach().item()}
@@ -94,7 +94,7 @@ def __call__( # type: ignore

weighted_phi = phi_function(token_prob)

- sft_loss = masked_loss(
+ sft_loss = aggregate_loss(
-logprob * weighted_phi.detach(), action_mask, loss_agg_mode=self.loss_agg_mode
)
return sft_loss, {"sft_phi_loss": sft_loss.detach().item()}
@@ -141,8 +141,9 @@ def __init__(
ngpus_trainer: int = 1,
train_batch_size_usual: int = 1,
train_batch_size_expert: int = 1,
sft_loss_agg_mode: str = "token-mean",
grpo_loss_agg_mode: str = "token-mean",
loss_agg_mode: str = "token-mean",
sft_loss_agg_mode: Optional[str] = None,
grpo_loss_agg_mode: Optional[str] = None,
) -> None:
super().__init__(backend=backend)
self.mu_warmup_steps = mu_warmup_steps
@@ -159,12 +160,12 @@ def __init__(
clip_range=clip_range,
clip_range_low=clip_range_low,
clip_range_high=clip_range_high,
- loss_agg_mode=grpo_loss_agg_mode,
+ loss_agg_mode=grpo_loss_agg_mode or loss_agg_mode,
)
if enable_phi_function:
- self.sft_loss_fn = SFTPhiLossFn(loss_agg_mode=sft_loss_agg_mode)
+ self.sft_loss_fn = SFTPhiLossFn(loss_agg_mode=sft_loss_agg_mode or loss_agg_mode)
else:
- self.sft_loss_fn = SFTLossFn(loss_agg_mode=sft_loss_agg_mode)
+ self.sft_loss_fn = SFTLossFn(loss_agg_mode=sft_loss_agg_mode or loss_agg_mode)

def __call__( # type: ignore
self,
@@ -255,4 +256,5 @@ def default_args(cls) -> Dict:
"mu_valley": 0.1,
"clip_range": 0.2,
"enable_phi_function": True,
"loss_agg_mode": "token-mean",
}
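With the constructor change above, CHORD exposes one shared loss_agg_mode in default_args() while keeping the per-branch knobs as optional overrides: the expression "x or loss_agg_mode" picks the branch-specific value when it is set and otherwise falls back to the shared one. A minimal illustration with placeholder values (the non-default mode string is only an assumption):

# Fallback behaviour of the new Optional[str] per-branch arguments.
loss_agg_mode = "token-mean"                  # shared default, now in default_args()
sft_loss_agg_mode = None                      # unset, falls back to the shared mode
grpo_loss_agg_mode = "seq-mean-token-mean"    # placeholder override for the GRPO branch

effective_sft_mode = sft_loss_agg_mode or loss_agg_mode    # "token-mean"
effective_grpo_mode = grpo_loss_agg_mode or loss_agg_mode  # "seq-mean-token-mean"

Because the fallback uses "or" rather than an "is None" check, an explicitly passed empty string would also fall back to the shared mode.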
4 changes: 2 additions & 2 deletions trinity/algorithm/policy_loss_fn/cispo_policy_loss.py
@@ -7,7 +7,7 @@
import torch

from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
- from trinity.algorithm.utils import masked_loss, masked_mean
+ from trinity.algorithm.utils import aggregate_loss, masked_mean


@POLICY_LOSS_FN.register_module("cispo")
@@ -63,7 +63,7 @@ def __call__( # type: ignore

cispo_loss = -advantages * ratio_clamped.detach() * mask.detach() * logprob

- loss = masked_loss(cispo_loss, action_mask, loss_agg_mode=self.loss_agg_mode)
+ loss = aggregate_loss(cispo_loss, action_mask, loss_agg_mode=self.loss_agg_mode)
unmasked_frac = masked_mean(mask, action_mask)

metrics = {
4 changes: 2 additions & 2 deletions trinity/algorithm/policy_loss_fn/gspo_policy_loss.py
@@ -8,7 +8,7 @@
import torch

from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
- from trinity.algorithm.utils import masked_loss, masked_mean
+ from trinity.algorithm.utils import aggregate_loss, masked_mean


@POLICY_LOSS_FN.register_module("gspo")
@@ -54,7 +54,7 @@ def __call__( # type: ignore
ratio, 1.0 - self.clip_range_low, 1.0 + self.clip_range_high
) # [batch_size, seq_len]

- pg_loss = masked_loss(
+ pg_loss = aggregate_loss(
values=torch.max(pg_losses, pg_losses_clipped),
mask=action_mask,
loss_agg_mode=self.loss_agg_mode,
10 changes: 6 additions & 4 deletions trinity/algorithm/policy_loss_fn/mix_policy_loss.py
@@ -37,8 +37,9 @@ def __init__(
ngpus_trainer: int = 1,
train_batch_size_usual: int = 1,
train_batch_size_expert: int = 1,
sft_loss_agg_mode: str = "token-mean",
grpo_loss_agg_mode: str = "token-mean",
loss_agg_mode: str = "token-mean",
sft_loss_agg_mode: Optional[str] = None,
grpo_loss_agg_mode: Optional[str] = None,
) -> None:
super().__init__(backend=backend)
self.mu = mu
@@ -51,9 +52,9 @@
clip_range=clip_range,
clip_range_low=clip_range_low,
clip_range_high=clip_range_high,
- loss_agg_mode=grpo_loss_agg_mode,
+ loss_agg_mode=grpo_loss_agg_mode or loss_agg_mode,
)
- self.sft_loss_fn = SFTLossFn(loss_agg_mode=sft_loss_agg_mode)
+ self.sft_loss_fn = SFTLossFn(loss_agg_mode=sft_loss_agg_mode or loss_agg_mode)

def __call__( # type: ignore
self,
@@ -125,4 +126,5 @@ def default_args(cls) -> Dict:
return {
"mu": 0.1,
"clip_range": 0.2,
"loss_agg_mode": "token-mean",
}
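MIXPolicyLossFn gets the same treatment, and its default_args() now reports the shared loss_agg_mode. The tests at the top of this PR build these losses through the registry; a sketch of that flow is below. The registry key "mix" is assumed here, since the visible test code only confirms the "gspo" key.

from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN

policy_loss_fn_cls = POLICY_LOSS_FN.get("mix")             # assumed key; "gspo" is shown above
policy_loss_fn_args = policy_loss_fn_cls.default_args()    # {"mu": 0.1, "clip_range": 0.2, "loss_agg_mode": "token-mean"}
policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
# loss, metrics = policy_loss_fn(log_prob=log_prob, **batch)   # as in policy_loss_test.py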
4 changes: 2 additions & 2 deletions trinity/algorithm/policy_loss_fn/opmd_policy_loss.py
@@ -5,7 +5,7 @@
import torch

from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
- from trinity.algorithm.utils import masked_loss
+ from trinity.algorithm.utils import aggregate_loss


@POLICY_LOSS_FN.register_module("opmd")
@@ -25,7 +25,7 @@ def __call__( # type: ignore
**kwargs,
) -> Tuple[torch.Tensor, Dict]:
pg_losses = -advantages * logprob
- opmd_loss = masked_loss(pg_losses, action_mask, loss_agg_mode=self.loss_agg_mode)
+ opmd_loss = aggregate_loss(pg_losses, action_mask, loss_agg_mode=self.loss_agg_mode)
opmd_loss = opmd_loss / (1.0 + self.tau) # for regularization (w.r.t. current pi_theta)
return opmd_loss, {"opmd_loss": opmd_loss.detach().item()}

26 changes: 20 additions & 6 deletions trinity/algorithm/policy_loss_fn/ppo_policy_loss.py
@@ -8,7 +8,7 @@
import torch

from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
- from trinity.algorithm.utils import masked_loss, masked_mean
+ from trinity.algorithm.utils import aggregate_loss, masked_mean


@POLICY_LOSS_FN.register_module("ppo")
@@ -19,6 +19,7 @@ def __init__(
clip_range: Optional[float] = None,
clip_range_low: Optional[float] = None,
clip_range_high: Optional[float] = None,
+ clip_ratio_c: float = 3.0,
loss_agg_mode: Optional[str] = "token-mean",
) -> None:
super().__init__(backend=backend)
@@ -30,6 +31,8 @@
self.clip_range_high = clip_range
else:
self.clip_range_high = clip_range_high
+ self.clip_ratio_c = clip_ratio_c
+ assert clip_ratio_c > 1.0, "clip_ratio_c must be greater than 1.0."
assert self.clip_range_low is not None, "clip_range_low must be specified."
assert self.clip_range_high is not None, "clip_range_high must be specified."
self.loss_agg_mode = loss_agg_mode
@@ -43,20 +46,30 @@ def __call__( # type: ignore
**kwargs,
) -> Tuple[torch.Tensor, Dict]:
negative_approx_kl = logprob - old_logprob
+ # Clamp negative_approx_kl for stability
+ negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)
ratio = torch.exp(negative_approx_kl)
ppo_kl = masked_mean(-negative_approx_kl, action_mask)

- pg_losses = -advantages * ratio
+ pg_losses1 = -advantages * ratio
pg_losses2 = -advantages * torch.clamp(
ratio, 1.0 - self.clip_range_low, 1.0 + self.clip_range_high # type: ignore
)

- pg_loss = masked_loss(
-     torch.max(pg_losses, pg_losses2), action_mask, loss_agg_mode=self.loss_agg_mode
+ clip_pg_losses1 = torch.maximum(pg_losses1, pg_losses2)
+
+ pg_clip_frac = masked_mean(torch.gt(pg_losses2, pg_losses1).float(), action_mask)
+
+ pg_losses3 = -advantages * self.clip_ratio_c
+ clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
+ pg_clipfrac_lower = masked_mean(
+     torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), action_mask
  )
- pg_clipfrac = masked_mean(torch.gt(pg_losses2, pg_losses).float(), action_mask)
+ pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
+ pg_loss = aggregate_loss(pg_losses, action_mask, loss_agg_mode=self.loss_agg_mode)
  metrics = {
- "pg_clipfrac": pg_clipfrac.detach().item(),
+ "pg_clipfrac": pg_clip_frac.detach().item(),
+ "pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
"ppo_kl": ppo_kl.detach().item(),
"pg_loss": pg_loss.detach().item(),
}
@@ -66,5 +79,6 @@ def __call__( # type: ignore
def default_args(cls) -> Dict:
return {
"clip_range": 0.2,
"clip_ratio_c": 3.0,
"loss_agg_mode": "token-mean",
}
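The rewritten __call__ adds a dual-clip lower bound on top of the usual PPO ratio clip: for tokens with negative advantage, the per-token loss is additionally capped at -advantage * clip_ratio_c (with clip_ratio_c > 1), so a single token with a very large ratio cannot dominate the batch, and the new pg_clipfrac_lower metric reports how often that cap fires. A self-contained numeric illustration with made-up values (with an all-ones action mask, masked_mean reduces to a plain mean):

import torch

clip_range_low, clip_range_high, clip_ratio_c = 0.2, 0.2, 3.0

advantages = torch.tensor([[-1.0, -1.0, 1.0, 1.0]])
ratio = torch.tensor([[0.5, 5.0, 0.5, 5.0]])   # exp(logprob - old_logprob)

pg_losses1 = -advantages * ratio
pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - clip_range_low, 1.0 + clip_range_high)
clip_pg_losses1 = torch.maximum(pg_losses1, pg_losses2)   # standard PPO clipping

pg_losses3 = -advantages * clip_ratio_c                   # dual-clip cap, relevant only for A < 0
clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
print(pg_losses)            # [[0.8, 3.0, -0.5, -1.2]]: token 2 is capped at 3.0 instead of 5.0

pg_clipfrac_lower = (torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float()).mean()
print(pg_clipfrac_lower)    # 0.25: the cap fired on 1 of 4 tokens

This is also consistent with why the expected constants in policy_loss_test.py changed: with the default clip_ratio_c of 3.0 the lower clip engages on some negative-advantage tokens, shifting ppo_loss and mix_loss and producing the new pg_clipfrac_lower values.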
4 changes: 2 additions & 2 deletions trinity/algorithm/policy_loss_fn/sft_loss.py
@@ -5,7 +5,7 @@
import torch

from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
- from trinity.algorithm.utils import masked_loss
+ from trinity.algorithm.utils import aggregate_loss


@POLICY_LOSS_FN.register_module("sft")
@@ -20,7 +20,7 @@ def __call__( # type: ignore
action_mask: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict]:
- sft_loss = masked_loss(-logprob, action_mask, loss_agg_mode=self.loss_agg_mode)
+ sft_loss = aggregate_loss(-logprob, action_mask, loss_agg_mode=self.loss_agg_mode)

return sft_loss, {"sft_loss": sft_loss.detach().item()}

4 changes: 2 additions & 2 deletions trinity/algorithm/policy_loss_fn/sppo_loss_fn.py
@@ -7,7 +7,7 @@
import torch

from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
- from trinity.algorithm.utils import masked_loss, masked_mean
+ from trinity.algorithm.utils import aggregate_loss, masked_mean


@POLICY_LOSS_FN.register_module("sppo")
@@ -41,7 +41,7 @@ def __call__( # type: ignore
is_in_range = (ratio >= (1 / (1 + self.epsilon))) * (ratio <= (1 + self.epsilon))
is_clipped_mask = ~is_in_range
pg_losses = -advantages * (logprob - old_logprob) * is_in_range.float()
- pg_loss = masked_loss(pg_losses, action_mask, loss_agg_mode=self.loss_agg_mode)
+ pg_loss = aggregate_loss(pg_losses, action_mask, loss_agg_mode=self.loss_agg_mode)
pg_clipfrac = masked_mean(is_clipped_mask.float(), action_mask)
metrics = {
"pg_clipfrac": pg_clipfrac.item(),