Commit d9d2135

Add loss_agg_mode for kl and entropy_loss (#388)

1 parent bf5c134 · commit d9d2135

File tree

17 files changed: +107 −50 lines

.github/workflows/sphinx-doc.yaml

Lines changed: 5 additions & 2 deletions
@@ -21,14 +21,17 @@ jobs:
       OS: ${{ matrix.os }}
       PYTHON: '3.10'
     steps:
+      - name: Free up disk space
+        run: |
+          sudo rm -rf /usr/share/dotnet /opt/ghc /usr/local/lib/android
+          docker system prune -af
       - name: Checkout PR branch
         if: github.event_name == 'pull_request'
         uses: actions/checkout@v4
         with:
           repository: ${{ github.event.pull_request.head.repo.full_name }}
           ref: ${{ github.event.pull_request.head.ref }}
           fetch-depth: 0
-
       - name: Checkout main branch
         if: github.event_name != 'pull_request'
         uses: actions/checkout@v4
@@ -41,7 +44,7 @@ jobs:
           python-version: ${{ matrix.python-version }}
       - name: Install Dependencies
         run: |
-          pip install -q -e .[doc]
+          pip install -e .[doc]
       - id: build
         name: Build Documentation
         run: |

tests/algorithm/policy_loss_test.py

Lines changed: 11 additions & 4 deletions
@@ -35,13 +35,17 @@ def test_ppo_policy_loss(self):
         policy_loss_fn_args = policy_loss_fn_cls.default_args()
         policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
         loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch)
-        ppo_loss = torch.tensor(0.28560468554496765)
+        ppo_loss = torch.tensor(0.26889559626579285)
         pg_clipfrac = torch.tensor(0.3541666567325592)
         ppo_kl = torch.tensor(-0.21663446724414825)
+        pg_clipfrac_lower = torch.tensor(0.0625)
         self.assertTrue(torch.allclose(loss, ppo_loss))
         self.assertTrue(torch.allclose(torch.tensor(metrics["pg_clipfrac"]), pg_clipfrac))
         self.assertTrue(torch.allclose(torch.tensor(metrics["ppo_kl"]), ppo_kl))
         self.assertTrue(torch.allclose(torch.tensor(metrics["pg_loss"]), ppo_loss))
+        self.assertTrue(
+            torch.allclose(torch.tensor(metrics["pg_clipfrac_lower"]), pg_clipfrac_lower)
+        )

     def test_gspo_policy_loss(self):
         policy_loss_fn_cls = POLICY_LOSS_FN.get("gspo")
@@ -52,7 +56,6 @@ def test_gspo_policy_loss(self):
         pg_clipfrac_expected = torch.tensor(0.375)
         ppo_kl_seq_expected = torch.tensor(-0.21027061343193054)
         ppo_kl_expected = torch.tensor(-0.21663446724414825)
-        print(f"{loss.item()=}, {metrics=}")
         self.assertTrue(torch.allclose(loss, gspo_loss_expected))
         self.assertTrue(torch.allclose(torch.tensor(metrics["pg_clipfrac"]), pg_clipfrac_expected))
         self.assertTrue(torch.allclose(torch.tensor(metrics["ppo_kl_seq"]), ppo_kl_seq_expected))
@@ -97,14 +100,18 @@ def test_mix_policy_loss(self):
         policy_loss_fn_args = policy_loss_fn_cls.default_args()
         policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
         loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch)
-        mix_loss = torch.tensor(0.6581965088844299)
+        mix_loss = torch.tensor(0.6298247575759888)
         pg_clipfrac = torch.tensor(0.7777777910232544)
         ppo_kl = torch.tensor(-1.0737695693969727)
-        pg_loss = torch.tensor(0.7236452102661133)
+        pg_loss = torch.tensor(0.6921210885047913)
         sft_loss = torch.tensor(0.06915830634534359)
+        pg_clipfrac_lower = torch.tensor(0.2222222238779068)
         self.assertTrue(torch.allclose(loss, mix_loss))
         self.assertTrue(torch.allclose(torch.tensor(metrics["usual/pg_clipfrac"]), pg_clipfrac))
         self.assertTrue(torch.allclose(torch.tensor(metrics["usual/ppo_kl"]), ppo_kl))
         self.assertTrue(torch.allclose(torch.tensor(metrics["usual/pg_loss"]), pg_loss))
+        self.assertTrue(
+            torch.allclose(torch.tensor(metrics["usual/pg_clipfrac_lower"]), pg_clipfrac_lower)
+        )
         self.assertTrue(torch.allclose(torch.tensor(metrics["expert/sft_loss"]), sft_loss))
         self.assertTrue(torch.allclose(torch.tensor(metrics["loss"]), mix_loss))

trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py

Lines changed: 5 additions & 3 deletions
@@ -3,7 +3,7 @@

 import torch

-from trinity.algorithm.utils import masked_mean
+from trinity.algorithm.utils import aggregate_loss
 from trinity.utils.registry import Registry

 ENTROPY_LOSS_FN = Registry("entropy_loss_fn")
@@ -53,9 +53,10 @@ def __call__(
         self,
         entropy: torch.Tensor,
         action_mask: torch.Tensor,
+        loss_agg_mode: str = "token-mean",
         **kwargs,
     ) -> Tuple[torch.Tensor, Dict]:
-        entropy_loss = masked_mean(entropy, action_mask)
+        entropy_loss = aggregate_loss(entropy, action_mask, loss_agg_mode=loss_agg_mode)
         return entropy_loss * self.entropy_coef, {"entropy_loss": entropy_loss.detach().item()}


@@ -73,6 +74,7 @@ def __call__(
         entropy: torch.Tensor,
         action_mask: torch.Tensor,
         expert_mask: torch.Tensor = None,
+        loss_agg_mode: str = "token-mean",
         **kwargs,
     ) -> Tuple[torch.Tensor, Dict]:
         if expert_mask is None:
@@ -82,7 +84,7 @@ def __call__(
         ), f"Error: {len(expert_mask)=} != {entropy.shape[0]=}"
         entropy = entropy[~expert_mask]
         action_mask = action_mask[~expert_mask]
-        entropy_loss = masked_mean(entropy, action_mask)
+        entropy_loss = aggregate_loss(entropy, action_mask, loss_agg_mode=loss_agg_mode)
         return entropy_loss * self.entropy_coef, {"entropy_loss": entropy_loss.detach().item()}
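This file, and the ones below, replace masked_mean/masked_loss with a single aggregate_loss helper that takes a loss_agg_mode argument. The helper's body is not part of this commit's visible hunks; the following is a minimal sketch of how such a mode switch is commonly implemented (mode names other than "token-mean" are assumptions, not confirmed members of trinity.algorithm.utils):

import torch

def aggregate_loss(
    values: torch.Tensor, mask: torch.Tensor, loss_agg_mode: str = "token-mean"
) -> torch.Tensor:
    """Aggregate a [batch, seq_len] tensor of per-token losses under a 0/1 mask.

    Sketch only; the real trinity.algorithm.utils.aggregate_loss may differ.
    """
    if loss_agg_mode == "token-mean":
        # Mean over all unmasked tokens in the batch.
        return (values * mask).sum() / mask.sum().clamp(min=1)
    elif loss_agg_mode == "seq-mean-token-sum":
        # Sum tokens within each sequence, then average over sequences.
        return (values * mask).sum(dim=-1).mean()
    elif loss_agg_mode == "seq-mean-token-mean":
        # Average tokens within each sequence, then average over sequences.
        return ((values * mask).sum(dim=-1) / mask.sum(dim=-1).clamp(min=1)).mean()
    else:
        raise ValueError(f"Unknown loss_agg_mode: {loss_agg_mode}")

The choice matters for batches of mixed-length responses: "token-mean" weights sequences by their token count, while per-sequence modes give every sequence equal weight regardless of length.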

trinity/algorithm/kl_fn/kl_fn.py

Lines changed: 18 additions & 2 deletions
@@ -11,7 +11,7 @@

 import torch

-from trinity.algorithm.utils import masked_mean
+from trinity.algorithm.utils import aggregate_loss, masked_mean
 from trinity.utils.registry import Registry

 KL_FN = Registry("kl_fn")
@@ -81,10 +81,11 @@ def calculate_kl_loss(
         logprob: torch.Tensor,
         ref_logprob: torch.Tensor,
         response_mask: torch.Tensor,
+        loss_agg_mode: str,
     ) -> Tuple[torch.Tensor, Dict]:
         """Compute KL loss."""
         kl = self.calculate_kl(logprob, ref_logprob)
-        kl_loss = masked_mean(kl, response_mask)
+        kl_loss = aggregate_loss(kl, response_mask, loss_agg_mode=loss_agg_mode)
         metrics = {
             "kl_loss": kl_loss.detach().item(),
             "kl_coef": self.kl_coef,
@@ -119,6 +120,7 @@ def calculate_kl_loss(
         logprob: torch.Tensor,
         ref_logprob: torch.Tensor,
         response_mask: torch.Tensor,
+        loss_agg_mode: str,
     ) -> Tuple[torch.Tensor, Dict]:
         # return a zero tensor
         return torch.tensor(0.0), {}
@@ -155,6 +157,20 @@ def calculate_kl(self, logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
         return logr.exp() - 1 - logr


+@KL_FN.register_module("low_var_kl")
+class LowVarKLFn(KLFn):
+    """
+    Low Variance KL function.
+    """
+
+    def calculate_kl(self, logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
+        kl = ref_logprob - logprob
+        kl = torch.clamp(kl, min=-20, max=20)
+        ratio = torch.exp(kl)
+        kld = (ratio - kl - 1).contiguous()
+        return torch.clamp(kld, min=-10, max=10)
+
+
 @KL_FN.register_module("abs")
 class AbsFn(KLFn):
     """

trinity/algorithm/policy_loss_fn/chord_policy_loss.py

Lines changed: 10 additions & 8 deletions
@@ -8,7 +8,7 @@
 from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
 from trinity.algorithm.policy_loss_fn.ppo_policy_loss import PPOPolicyLossFn
 from trinity.algorithm.policy_loss_fn.sft_loss import SFTLossFn
-from trinity.algorithm.utils import masked_loss
+from trinity.algorithm.utils import aggregate_loss


 def mu_schedule_function(
@@ -48,7 +48,7 @@ def __call__(  # type: ignore
         **kwargs,
     ) -> Tuple[torch.Tensor, Dict]:
         token_prob = torch.exp(logprob)
-        sft_loss = masked_loss(
+        sft_loss = aggregate_loss(
             -logprob * token_prob.detach(), action_mask, loss_agg_mode=self.loss_agg_mode
         )
         return sft_loss, {"sft_is_loss": sft_loss.detach().item()}
@@ -94,7 +94,7 @@ def __call__(  # type: ignore

         weighted_phi = phi_function(token_prob)

-        sft_loss = masked_loss(
+        sft_loss = aggregate_loss(
             -logprob * weighted_phi.detach(), action_mask, loss_agg_mode=self.loss_agg_mode
         )
         return sft_loss, {"sft_phi_loss": sft_loss.detach().item()}
@@ -141,8 +141,9 @@ def __init__(
         ngpus_trainer: int = 1,
         train_batch_size_usual: int = 1,
         train_batch_size_expert: int = 1,
-        sft_loss_agg_mode: str = "token-mean",
-        grpo_loss_agg_mode: str = "token-mean",
+        loss_agg_mode: str = "token-mean",
+        sft_loss_agg_mode: Optional[str] = None,
+        grpo_loss_agg_mode: Optional[str] = None,
     ) -> None:
         super().__init__(backend=backend)
         self.mu_warmup_steps = mu_warmup_steps
@@ -159,12 +160,12 @@ def __init__(
             clip_range=clip_range,
             clip_range_low=clip_range_low,
             clip_range_high=clip_range_high,
-            loss_agg_mode=grpo_loss_agg_mode,
+            loss_agg_mode=grpo_loss_agg_mode or loss_agg_mode,
         )
         if enable_phi_function:
-            self.sft_loss_fn = SFTPhiLossFn(loss_agg_mode=sft_loss_agg_mode)
+            self.sft_loss_fn = SFTPhiLossFn(loss_agg_mode=sft_loss_agg_mode or loss_agg_mode)
         else:
-            self.sft_loss_fn = SFTLossFn(loss_agg_mode=sft_loss_agg_mode)
+            self.sft_loss_fn = SFTLossFn(loss_agg_mode=sft_loss_agg_mode or loss_agg_mode)

     def __call__(  # type: ignore
         self,
@@ -255,4 +256,5 @@ def default_args(cls) -> Dict:
             "mu_valley": 0.1,
             "clip_range": 0.2,
             "enable_phi_function": True,
+            "loss_agg_mode": "token-mean",
         }
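In both this CHORD loss and the MIX loss below, loss_agg_mode now acts as a shared default, with sft_loss_agg_mode and grpo_loss_agg_mode as optional per-component overrides resolved via `override or loss_agg_mode`. A hypothetical call illustrating the resolution (the constructor name CHORDPolicyLossFn and the non-default mode string are assumptions for illustration; the class name is not visible in these hunks, and other arguments are elided):

# Shared default only: both the GRPO term and the SFT term aggregate "token-mean".
fn = CHORDPolicyLossFn(loss_agg_mode="token-mean")

# Override one component: the SFT term uses "seq-mean-token-sum" (assumed mode
# name), while the GRPO term falls back to the shared "token-mean" default.
fn = CHORDPolicyLossFn(
    loss_agg_mode="token-mean",
    sft_loss_agg_mode="seq-mean-token-sum",
)

Note that because the fallback uses `or` rather than an explicit `is None` check, an empty-string override also falls back to the shared mode.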

trinity/algorithm/policy_loss_fn/cispo_policy_loss.py

Lines changed: 2 additions & 2 deletions
@@ -7,7 +7,7 @@
 import torch

 from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
-from trinity.algorithm.utils import masked_loss, masked_mean
+from trinity.algorithm.utils import aggregate_loss, masked_mean


 @POLICY_LOSS_FN.register_module("cispo")
@@ -63,7 +63,7 @@ def __call__(  # type: ignore

         cispo_loss = -advantages * ratio_clamped.detach() * mask.detach() * logprob

-        loss = masked_loss(cispo_loss, action_mask, loss_agg_mode=self.loss_agg_mode)
+        loss = aggregate_loss(cispo_loss, action_mask, loss_agg_mode=self.loss_agg_mode)
         unmasked_frac = masked_mean(mask, action_mask)

         metrics = {

trinity/algorithm/policy_loss_fn/gspo_policy_loss.py

Lines changed: 2 additions & 2 deletions
@@ -8,7 +8,7 @@
 import torch

 from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
-from trinity.algorithm.utils import masked_loss, masked_mean
+from trinity.algorithm.utils import aggregate_loss, masked_mean


 @POLICY_LOSS_FN.register_module("gspo")
@@ -54,7 +54,7 @@ def __call__(  # type: ignore
             ratio, 1.0 - self.clip_range_low, 1.0 + self.clip_range_high
         )  # [batch_size, seq_len]

-        pg_loss = masked_loss(
+        pg_loss = aggregate_loss(
             values=torch.max(pg_losses, pg_losses_clipped),
             mask=action_mask,
             loss_agg_mode=self.loss_agg_mode,

trinity/algorithm/policy_loss_fn/mix_policy_loss.py

Lines changed: 6 additions & 4 deletions
@@ -37,8 +37,9 @@ def __init__(
         ngpus_trainer: int = 1,
         train_batch_size_usual: int = 1,
         train_batch_size_expert: int = 1,
-        sft_loss_agg_mode: str = "token-mean",
-        grpo_loss_agg_mode: str = "token-mean",
+        loss_agg_mode: str = "token-mean",
+        sft_loss_agg_mode: Optional[str] = None,
+        grpo_loss_agg_mode: Optional[str] = None,
     ) -> None:
         super().__init__(backend=backend)
         self.mu = mu
@@ -51,9 +52,9 @@ def __init__(
             clip_range=clip_range,
             clip_range_low=clip_range_low,
             clip_range_high=clip_range_high,
-            loss_agg_mode=grpo_loss_agg_mode,
+            loss_agg_mode=grpo_loss_agg_mode or loss_agg_mode,
         )
-        self.sft_loss_fn = SFTLossFn(loss_agg_mode=sft_loss_agg_mode)
+        self.sft_loss_fn = SFTLossFn(loss_agg_mode=sft_loss_agg_mode or loss_agg_mode)

     def __call__(  # type: ignore
         self,
@@ -125,4 +126,5 @@ def default_args(cls) -> Dict:
         return {
             "mu": 0.1,
             "clip_range": 0.2,
+            "loss_agg_mode": "token-mean",
         }

trinity/algorithm/policy_loss_fn/opmd_policy_loss.py

Lines changed: 2 additions & 2 deletions
@@ -5,7 +5,7 @@
 import torch

 from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
-from trinity.algorithm.utils import masked_loss
+from trinity.algorithm.utils import aggregate_loss


 @POLICY_LOSS_FN.register_module("opmd")
@@ -25,7 +25,7 @@ def __call__(  # type: ignore
         **kwargs,
     ) -> Tuple[torch.Tensor, Dict]:
         pg_losses = -advantages * logprob
-        opmd_loss = masked_loss(pg_losses, action_mask, loss_agg_mode=self.loss_agg_mode)
+        opmd_loss = aggregate_loss(pg_losses, action_mask, loss_agg_mode=self.loss_agg_mode)
         opmd_loss = opmd_loss / (1.0 + self.tau)  # for regularization (w.r.t. current pi_theta)
         return opmd_loss, {"opmd_loss": opmd_loss.detach().item()}

trinity/algorithm/policy_loss_fn/ppo_policy_loss.py

Lines changed: 20 additions & 6 deletions
@@ -8,7 +8,7 @@
 import torch

 from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
-from trinity.algorithm.utils import masked_loss, masked_mean
+from trinity.algorithm.utils import aggregate_loss, masked_mean


 @POLICY_LOSS_FN.register_module("ppo")
@@ -19,6 +19,7 @@ def __init__(
         clip_range: Optional[float] = None,
         clip_range_low: Optional[float] = None,
         clip_range_high: Optional[float] = None,
+        clip_ratio_c: float = 3.0,
         loss_agg_mode: Optional[str] = "token-mean",
     ) -> None:
         super().__init__(backend=backend)
@@ -30,6 +31,8 @@ def __init__(
             self.clip_range_high = clip_range
         else:
             self.clip_range_high = clip_range_high
+        self.clip_ratio_c = clip_ratio_c
+        assert clip_ratio_c > 1.0, "clip_ratio_c must be greater than 1.0."
         assert self.clip_range_low is not None, "clip_range_low must be specified."
         assert self.clip_range_high is not None, "clip_range_high must be specified."
         self.loss_agg_mode = loss_agg_mode
@@ -43,20 +46,30 @@ def __call__(  # type: ignore
         **kwargs,
     ) -> Tuple[torch.Tensor, Dict]:
         negative_approx_kl = logprob - old_logprob
+        # Clamp negative_approx_kl for stability
+        negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)
         ratio = torch.exp(negative_approx_kl)
         ppo_kl = masked_mean(-negative_approx_kl, action_mask)

-        pg_losses = -advantages * ratio
+        pg_losses1 = -advantages * ratio
         pg_losses2 = -advantages * torch.clamp(
             ratio, 1.0 - self.clip_range_low, 1.0 + self.clip_range_high  # type: ignore
         )

-        pg_loss = masked_loss(
-            torch.max(pg_losses, pg_losses2), action_mask, loss_agg_mode=self.loss_agg_mode
+        clip_pg_losses1 = torch.maximum(pg_losses1, pg_losses2)
+
+        pg_clip_frac = masked_mean(torch.gt(pg_losses2, pg_losses1).float(), action_mask)
+
+        pg_losses3 = -advantages * self.clip_ratio_c
+        clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
+        pg_clipfrac_lower = masked_mean(
+            torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), action_mask
         )
-        pg_clipfrac = masked_mean(torch.gt(pg_losses2, pg_losses).float(), action_mask)
+        pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
+        pg_loss = aggregate_loss(pg_losses, action_mask, loss_agg_mode=self.loss_agg_mode)
         metrics = {
-            "pg_clipfrac": pg_clipfrac.detach().item(),
+            "pg_clipfrac": pg_clip_frac.detach().item(),
+            "pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
             "ppo_kl": ppo_kl.detach().item(),
             "pg_loss": pg_loss.detach().item(),
         }
@@ -66,5 +79,6 @@ def default_args(cls) -> Dict:
     def default_args(cls) -> Dict:
         return {
             "clip_range": 0.2,
+            "clip_ratio_c": 3.0,
             "loss_agg_mode": "token-mean",
         }
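The rewritten __call__ adds the dual-clip PPO mechanism (introduced for MOBA-game RL by Ye et al., 2020): where the advantage is negative, the standard clipped loss max(pg_losses1, pg_losses2) is additionally capped at -advantages * clip_ratio_c, so a very large ratio on a negative-advantage token cannot dominate the batch, and pg_clipfrac_lower reports how often the cap fires. This default change is also why the expected ppo_loss and mix_loss values in policy_loss_test.py above were updated. A standalone numeric sketch of the mechanism, mirroring the hunk without masking or aggregation (values chosen for illustration):

import torch

clip_range, clip_ratio_c = 0.2, 3.0
ratio = torch.tensor([0.5, 1.0, 5.0, 5.0])        # pi_theta / pi_old per token
advantages = torch.tensor([1.0, -1.0, 1.0, -1.0])

# Standard PPO clipping.
pg_losses1 = -advantages * ratio
pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - clip_range, 1.0 + clip_range)
clip_pg_losses1 = torch.maximum(pg_losses1, pg_losses2)

# Dual clip: only where advantage < 0, cap the loss at -advantages * clip_ratio_c.
# Without the cap, the last token (ratio 5.0, advantage -1.0) contributes +5.0;
# the cap limits it to -(-1.0) * 3.0 = 3.0.
pg_losses3 = -advantages * clip_ratio_c
pg_losses = torch.where(
    advantages < 0, torch.min(pg_losses3, clip_pg_losses1), clip_pg_losses1
)
print(pg_losses)  # tensor([-0.5000,  1.0000, -1.2000,  3.0000])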
