Commit e31f815

Loss updates

1 parent: 1a6d6df

File tree (2 files changed, +24 additions, -28 deletions):

  apps/grpo/main.py
  src/forge/actors/policy.py
apps/grpo/main.py

Lines changed: 18 additions & 21 deletions
```diff
@@ -49,24 +49,21 @@ def compute_logprobs(
 
 class SimpleGRPOLoss(nn.Module):
     """Simplified GRPO Loss for simplified single step updates
-    Copied from https://github.com/pytorch/torchtune/blob/main/torchtune/dev/grpo/loss.py.
+    Inspired by the Hugging Face TRL implementation:
+    https://github.com/huggingface/trl/blob/417915a3e4d3e3bc8d7b196594308b8eabf928be/trl/trainer/grpo_trainer.py#L1624.
     """
 
     def __init__(self, beta: float = 0.1):
         super().__init__()
         self.beta = beta
 
     def forward(self, logprobs, ref_logprobs, advantages, padding_mask):
-        per_token_kl = (
-            torch.exp(ref_logprobs.detach() - logprobs)
-            - (ref_logprobs.detach() - logprobs)
-            - 1
-        )
+        kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
         per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages
-        per_token_loss = -(per_token_policy_loss - self.beta * per_token_kl)
+        per_token_loss = -(per_token_policy_loss - self.beta * kl)
         loss = (
-            (per_token_loss * padding_mask).sum(dim=1)
-            / (padding_mask.sum(dim=1) + 1e-8)
+            ((per_token_loss * padding_mask).sum(dim=1))
+            / (padding_mask.sum(dim=1).clamp(min=1.0))
         ).mean()
         return loss
```
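The new `forward` drops the `.detach()` on `ref_logprobs` (reference log-probs are presumably produced without gradients upstream, making the detach a no-op), keeps the k3-style KL estimator, and swaps the `+ 1e-8` denominator fudge for `clamp(min=1.0)`, which divides by the exact token count whenever a row has any valid tokens. A self-contained sketch of the post-change loss with a smoke test (shapes and dummy values are illustrative only, not from the repo):

```python
import torch
import torch.nn as nn


class SimpleGRPOLoss(nn.Module):
    """Post-change loss, reproduced standalone for experimentation."""

    def __init__(self, beta: float = 0.1):
        super().__init__()
        self.beta = beta

    def forward(self, logprobs, ref_logprobs, advantages, padding_mask):
        # k3 KL estimator: exp(d) - d - 1 with d = ref_logprobs - logprobs.
        # Non-negative, and zero exactly when the two policies agree.
        kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
        # exp(lp - lp.detach()) equals 1 in value but has derivative 1 w.r.t.
        # lp, so gradients flow through the policy term.
        per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages
        per_token_loss = -(per_token_policy_loss - self.beta * kl)
        # Masked per-sequence mean; clamp keeps the denominator >= 1 even
        # for a fully padded row, replacing the old `+ 1e-8` fudge.
        return (
            (per_token_loss * padding_mask).sum(dim=1)
            / padding_mask.sum(dim=1).clamp(min=1.0)
        ).mean()


# Smoke test with dummy [b=2, s=4] tensors (illustrative values only):
logprobs = torch.randn(2, 4, requires_grad=True)
ref_logprobs = torch.randn(2, 4)
advantages = torch.randn(2, 1)
padding_mask = torch.tensor([[1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 0.0, 0.0]])
loss = SimpleGRPOLoss(beta=0.1)(logprobs, ref_logprobs, advantages, padding_mask)
loss.backward()
print(loss.item())
```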

```diff
@@ -211,21 +208,21 @@ def _qwen3_hf_to_vllm(self, saved_sd):
         return load_sd
 
     @endpoint
-    async def train_step(self, batch: list[Episode]):
-        batch = batch[self.dp_rank]
-        pad_id = batch[0].pad_id
+    async def train_step(self, batch: list[list[Episode]]):
+        microbatch = batch[self.dp_rank]
+        pad_id = microbatch[0].pad_id
 
         # prepare batch
-        request = [e.request_tensor for e in batch]
+        request = [e.request_tensor for e in microbatch]
         request = torch.stack(request).to(self.device)  # [b x s]
 
-        response = [e.response_tensor for e in batch]
+        response = [e.response_tensor for e in microbatch]
         response = torch.stack(response).to(self.device)  # [b x s]
 
-        ref_logprobs = [e.ref_logprobs for e in batch]
+        ref_logprobs = [e.ref_logprobs for e in microbatch]
         ref_logprobs = torch.stack(ref_logprobs).to(self.device).squeeze()  # [b x s]
 
-        advantages = [e.advantage for e in batch]
+        advantages = [e.advantage for e in microbatch]
         advantages = torch.tensor(advantages).to(self.device).unsqueeze(-1)  # [b x 1]
         del batch
```
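`train_step` now receives the full list of per-rank microbatches and selects its own slice via `dp_rank`; renaming the local to `microbatch` also means the trailing `del batch` releases the whole nested structure rather than just one slice. A minimal sketch of the caller side under that reading (`shard_episodes` and `dp_size` are hypothetical names, not from this commit):

```python
from typing import List


def shard_episodes(episodes: List["Episode"], dp_size: int) -> List[List["Episode"]]:
    """Hypothetical helper: split a global batch into one microbatch per
    data-parallel rank, matching the list[list[Episode]] layout that
    train_step now expects. Assumes len(episodes) is divisible by dp_size."""
    per_rank = len(episodes) // dp_size
    return [episodes[r * per_rank : (r + 1) * per_rank] for r in range(dp_size)]


# Every rank receives the same nested list and picks its own slice:
#     await trainer.train_step.call(shard_episodes(episodes, dp_size))
# and inside train_step: microbatch = batch[self.dp_rank]
```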

```diff
@@ -522,10 +519,10 @@ async def continuous_training():
         )
 
 
-@parse
-def recipe_main(cfg: DictConfig) -> None:
-    asyncio.run(main(cfg))
+if __name__ == "__main__":
 
+    @parse
+    def _main(cfg):
+        asyncio.run(main(cfg))
 
-if __name__ == "__main__":
-    recipe_main()
+    _main()  # @parse grabs the cfg from CLI
```
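One plausible motivation for nesting the decorated entrypoint under the guard: if `parse` inspects the CLI at decoration time (an assumption; decorators like this vary), importing the module no longer touches `sys.argv`. A sketch of the pattern with a stand-in decorator (this `parse` is hypothetical; forge's real one presumably builds a richer config object from the command line):

```python
import asyncio
import sys


def parse(fn):
    """Stand-in for forge's @parse (hypothetical implementation): read the
    CLI when the wrapped entrypoint is invoked and pass a config through."""

    def wrapper():
        cfg = {"cli_args": sys.argv[1:]}  # real decorator builds a full config
        return fn(cfg)

    return wrapper


async def main(cfg):
    print("running with", cfg)


if __name__ == "__main__":
    # Entrypoint lives inside the guard, mirroring the commit: nothing
    # CLI-related executes when this module is merely imported.
    @parse
    def _main(cfg):
        asyncio.run(main(cfg))

    _main()  # @parse grabs the cfg from the CLI
```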

src/forge/actors/policy.py

Lines changed: 6 additions & 7 deletions
```diff
@@ -11,15 +11,8 @@
 from collections.abc import Mapping
 from copy import copy
 from dataclasses import asdict, dataclass, field, fields
-from typing import Dict, List
 
 import torch
-
-from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh
-
-from forge.data.sharding import VLLMSharding
-from forge.interfaces import Policy as PolicyInterface
-from forge.types import ProcessConfig
 from monarch.actor import current_rank, endpoint, ProcMesh
 from torchstore import MultiProcessStore
 from torchstore._state_dict_utils import DELIM
@@ -44,6 +37,12 @@
 from vllm.v1.structured_output import StructuredOutputManager
 from vllm.worker.worker_base import WorkerWrapperBase
 
+from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh
+
+from forge.data.sharding import VLLMSharding
+from forge.interfaces import Policy as PolicyInterface
+from forge.types import ProcessConfig
+
 
 @dataclass
 class SamplingConfig:
```
