
Commit 2657324

Merge branch 'main' into weight-loading
2 parents 17e0c05 + 1223473


50 files changed: +2194 −1371 lines

.github/workflows/unit_test.yaml

Lines changed: 4 additions & 0 deletions
@@ -32,6 +32,10 @@ jobs:
           eval "$(ssh-agent -s)"
           ssh-add - <<< '${{ secrets.FORGE_GITHUB_CI_FOR_TORCHSTORE }}'
           python -m pip install git+ssh://git@github.com/meta-pytorch/torchstore.git
+      - name: Install torchtitan
+        run: |
+          pip install --pre torchtitan==0.1.0.dev20250826+cpu --extra-index-url https://download.pytorch.org/whl/nightly/cpu
+          pip install tyro
       - name: Install dependencies
         run: python -m pip install --no-build-isolation -e ".[dev]"
       - name: Run unit tests with coverage
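
The new step pins an exact torchtitan CPU nightly (plus tyro) so CI resolves the same build on every run. As a quick local sanity check that the pin took effect, one could run something like the following after installing (a hypothetical check, not part of this commit):

    # Hypothetical sanity check: confirm the pinned torchtitan nightly is installed.
    import importlib.metadata

    version = importlib.metadata.version("torchtitan")
    assert version.startswith("0.1.0.dev20250826"), f"unexpected torchtitan build: {version}"
    print(f"torchtitan {version} looks right")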

.gitignore

Lines changed: 3 additions & 0 deletions
@@ -193,3 +193,6 @@ cover/
 wandb/
 
 assets/wheels/vllm*.whl
+
+# DCP artifacts
+model_state_dict/
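
The newly ignored model_state_dict/ directory matches the artifact layout of PyTorch Distributed Checkpointing (DCP): dcp.save writes a directory of sharded files rather than a single .pt file, which is why a directory pattern is ignored here. A minimal sketch of how such a directory gets created (assuming a recent PyTorch where dcp.save works in a single process; the tiny model is a stand-in):

    import torch
    import torch.distributed.checkpoint as dcp

    model = torch.nn.Linear(8, 8)  # stand-in model
    # Creates a "model_state_dict/" directory of sharded checkpoint files.
    dcp.save(model.state_dict(), checkpoint_id="model_state_dict")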

apps/grpo/main.py

Lines changed: 80 additions & 116 deletions
@@ -7,7 +7,6 @@
 # Usage: python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml
 
 import asyncio
-import time
 import uuid
 from dataclasses import dataclass
 from typing import Any, Callable
@@ -17,39 +16,20 @@
 import torchstore as ts
 from datasets import load_dataset
 from forge.actors.policy import Policy
-from forge.actors.reference_model import ReferenceModel  # noqa: F401
+from forge.actors.reference_model import ReferenceModel
 from forge.actors.replay_buffer import ReplayBuffer
 from forge.actors.torchstore_utils import get_param_key
-from forge.actors.trainer import _qwen3_hf_to_vllm
+from forge.actors.trainer import RLTrainer
 from forge.cli.config import parse
 from forge.controller.actor import ForgeActor
 from forge.controller.provisioner import shutdown
 from forge.data.rewards import MathReward, ThinkingReward
-from forge.losses.grpo_loss import SimpleGRPOLoss
 from forge.util.metric_logging import get_metric_logger
 from monarch.actor import endpoint
 from omegaconf import DictConfig
-from torchstore.state_dict_utils import DELIM
-from torchtitan.config.job_config import Model as TitanJobModelConfig
-from transformers import AutoModelForCausalLM
 from vllm.transformers_utils.tokenizer import get_tokenizer
 
 
-def compute_logprobs(
-    logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0
-) -> torch.Tensor:
-    context_length = logits.shape[1] - input_ids.shape[1]
-
-    # Truncate request logits and drop last
-    logits = logits[:, context_length - 1 : -1]
-
-    # Compute logprobs
-    logprobs = torch.log_softmax(logits / temperature, dim=-1)
-    logprobs = torch.gather(logprobs, 2, input_ids.unsqueeze(-1)).squeeze(-1)
-
-    return logprobs
-
-
 @dataclass
 class Episode:
     # TODO: add adtional layer for multi-turn
@@ -118,64 +98,64 @@ def new_group(
         return cls(str(group_id), episodes)
 
 
-@dataclass
-class Trainer(ForgeActor):
-    """GRPO Trainer implementation for policy optimization."""
-
-    model_name: str
-    learning_rate: float = 1e-5
-    beta: float = 0.1
-    device: torch.device | None = None
-    state_dict_key: str = "model_state_dict"
-    dp_rank: int = 0  # TODO: support data parallelism, hard code it for now
-
-    @endpoint
-    async def setup(self):
-        if self.device is None:
-            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-        self.model = AutoModelForCausalLM.from_pretrained(
-            self.model_name,
-            dtype=torch.bfloat16,
-            trust_remote_code=True,
-        ).to(self.device)
-        self.model.train()
-
-        self.optimizer = torch.optim.AdamW(
-            self.model.parameters(), lr=self.learning_rate
-        )
-        self.optimizer.zero_grad()
+def collate(batches: list[list[Episode]]):
+    inputs = []
+    targets = []
+    for batch in batches:
+        request = [e.request_tensor for e in batch]
+        request = torch.stack(request)  # [b x s]
 
-        self.loss = SimpleGRPOLoss(self.beta)
+        response = [e.response_tensor for e in batch]
+        response = torch.stack(response)  # [b x s]
 
-        self.logger.info(f"Trainer model initialized on {self.device}")
+        ref_logprobs = [e.ref_logprobs for e in batch]
+        ref_logprobs = torch.stack(ref_logprobs).squeeze()  # [b x s]
 
-    @endpoint
-    async def train_step(self, batch: list[list[Episode]]):
-        microbatch = batch[self.dp_rank]
-        pad_id = microbatch[0].pad_id
+        advantages = [e.advantage for e in batch]
+        advantages = torch.tensor(advantages).unsqueeze(-1)  # [b x 1]
 
-        # prepare batch
-        request = [e.request_tensor for e in microbatch]
-        request = torch.stack(request).to(self.device)  # [b x s]
+        pad_id = batch[0].pad_id
+        mask = response != pad_id
 
-        response = [e.response_tensor for e in microbatch]
-        response = torch.stack(response).to(self.device)  # [b x s]
+        input = {"tokens": torch.cat([request, response], dim=1)}
+        target = {
+            "response": response,
+            "ref_logprobs": ref_logprobs,
+            "advantages": advantages,
+            "padding_mask": mask,
+        }
+        inputs.append(input)
+        targets.append(target)
+    return inputs, targets
 
-        ref_logprobs = [e.ref_logprobs for e in microbatch]
-        ref_logprobs = torch.stack(ref_logprobs).to(self.device).squeeze()  # [b x s]
 
-        advantages = [e.advantage for e in microbatch]
-        advantages = torch.tensor(advantages).to(self.device).unsqueeze(-1)  # [b x 1]
-        del batch
+def compute_logprobs(
+    logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0
+) -> torch.Tensor:
+    context_length = logits.shape[1] - input_ids.shape[1]
+    logits = logits[:, context_length - 1 : -1]
+    logprobs = torch.log_softmax(logits / temperature, dim=-1).to(input_ids.device)
+    logprobs = torch.gather(logprobs, 2, input_ids.unsqueeze(-1)).squeeze(-1)
+    return logprobs
 
-        input_ids = torch.cat([request, response], dim=1)
-        mask = input_ids != pad_id
-        logits = self.model(input_ids=input_ids, attention_mask=mask).logits
-        logprobs = compute_logprobs(logits, response)
-        del logits
 
-        mask = response != pad_id
+def simple_grpo_loss(
+    logits: torch.Tensor,
+    response: torch.Tensor,
+    ref_logprobs: torch.Tensor,
+    advantages: torch.Tensor,
+    padding_mask: torch.Tensor,
+    beta: float = 0.1,
+) -> torch.Tensor:
+    logprobs = compute_logprobs(logits, response)
+    kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
+    per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages
+    per_token_loss = -(per_token_policy_loss - beta * kl)
+    loss = (
+        ((per_token_loss * padding_mask).sum(dim=1))
+        / (padding_mask.sum(dim=1).clamp(min=1.0))
+    ).mean()
+    return loss
         loss = self.loss(logprobs, ref_logprobs, advantages, mask)
         loss.backward()
         self.optimizer.step()
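
With this hunk, the app-local Trainer actor is gone: model and optimizer plumbing move into the shared RLTrainer, and only pure functions (collate, compute_logprobs, simple_grpo_loss) stay in the app. A throwaway shape check of simple_grpo_loss as defined above (dummy tensors and illustrative sizes only, not real model outputs):

    import torch

    b, s_req, s_res, vocab = 2, 4, 3, 11
    logits = torch.randn(b, s_req + s_res, vocab)    # trainer forward over request+response
    response = torch.randint(0, vocab, (b, s_res))   # sampled response token ids
    ref_logprobs = torch.randn(b, s_res)             # from the reference model
    advantages = torch.randn(b, 1)                   # broadcasts over the response length
    padding_mask = torch.ones(b, s_res, dtype=torch.bool)

    loss = simple_grpo_loss(logits, response, ref_logprobs, advantages, padding_mask)
    print(loss.shape)  # torch.Size([]) -- a scalar, ready for backward()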
@@ -223,38 +203,6 @@ async def compute(self, group: Group) -> list[float]:
         return advantages.squeeze(0).tolist()
 
 
-class RefModel(ForgeActor):
-    def __init__(self, model_name, device: torch.device | None = None):
-        super().__init__()
-        self.model_name = model_name
-
-        if device is None:
-            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        else:
-            self.device = device
-
-        self.model = AutoModelForCausalLM.from_pretrained(
-            model_name,
-            dtype=torch.bfloat16,
-            trust_remote_code=True,
-        ).to(self.device)
-        self.model.eval()
-
-        self.logger.info(f"Model initialized on {self.device}")
-
-    @endpoint
-    async def forward(self, episode: Episode) -> torch.Tensor:
-        req, res = episode.request_tensor, episode.response_tensor
-        input_ids = torch.cat([req, res]).to(self.device).unsqueeze(0)
-        mask = input_ids != episode.pad_id
-
-        with torch.inference_mode():
-            logits = self.model(input_ids=input_ids, attention_mask=mask).logits
-
-        input_ids = input_ids[:, len(req) :]
-        return compute_logprobs(logits, input_ids)
-
-
 @dataclass
 class DatasetActor(ForgeActor):
     """Actor wrapper for HuggingFace dataset to provide async interface."""
@@ -309,10 +257,7 @@ async def pad_token(self):
 
 async def main(cfg: DictConfig):
     """Main GRPO training loop with rollout and training processes."""
-    titan_model = TitanJobModelConfig(name="qwen3", flavor="1.7B")
-    # Get parameters from config with fallbacks
     group_size = cfg.group_size
-    model = cfg.model
     max_req_tokens = cfg.max_req_tokens
     max_res_tokens = cfg.max_res_tokens
     mlogger = get_metric_logger(
@@ -322,7 +267,7 @@ async def main(cfg: DictConfig):
     )
 
     # ---- Setup services ---- #
-    await ts.initialize()
+    await ts.initialize(strategy=ts.ControllerStorageVolumes())
     (
         dataloader,
         policy,
@@ -334,17 +279,18 @@ async def main(cfg: DictConfig):
     ) = await asyncio.gather(
         DatasetActor.options(**cfg.services.dataset).as_service(**cfg.dataset),
         Policy.options(**cfg.services.policy).as_service(**cfg.policy),
-        Trainer.options(**cfg.services.trainer).as_service(**cfg.trainer),
+        RLTrainer.options(**cfg.services.trainer).as_service(
+            **cfg.trainer, loss=simple_grpo_loss
+        ),
         ReplayBuffer.options(**cfg.services.replay_buffer).as_service(
-            **cfg.replay_buffer
+            **cfg.replay_buffer, collate=collate
         ),
         ComputeAdvantages.options(**cfg.services.compute_advantages).as_service(),
-        RefModel.options(**cfg.services.ref_model).as_service(**cfg.ref_model),
+        ReferenceModel.options(**cfg.services.ref_model).as_service(**cfg.ref_model),
         RewardActor.options(**cfg.services.reward_actor).as_service(
             reward_functions=[MathReward(), ThinkingReward()]
         ),
     )
-
     print("All services initialized successfully!")
 
     # ---- Core RL loops ---- #
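
Behavior is now injected as plain callables: RLTrainer receives loss=simple_grpo_loss and ReplayBuffer receives collate=collate, keeping the shared actors app-agnostic. A toy stand-in showing how collate's output and the injected loss fit together (illustrative only; the real RLTrainer lives in forge.actors.trainer and handles devices, parallelism, and optimizer state):

    # Toy, single-process sketch of the injected-loss contract.
    def toy_train_step(model, optimizer, loss_fn, inputs, targets):
        # collate's contract: inputs[i]["tokens"] feeds the model forward,
        # and targets[i] holds exactly the keyword arguments loss_fn expects.
        logits = model(inputs[0]["tokens"])   # [b, s, vocab]
        loss = loss_fn(logits, **targets[0])  # e.g. simple_grpo_loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        return loss.item()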
@@ -358,6 +304,7 @@ async def continuous_rollouts():
             return
         prompt, target = sample["request"], sample["target"]
         responses = await policy.generate.choose(prompt)
+        # TODO: this shall be part of the responses metadata instead of a separate call
         version = await policy.get_version.choose()
         group = Group.new_group(
             group_id=rollout_count,
@@ -370,20 +317,36 @@ async def main(cfg: DictConfig):
             target=target,
         )
 
-        # TODO: Parallelize the following calculation
-        for episode, response in zip(group.episodes, responses.outputs):
-            episode.request_tokens = responses.prompt_token_ids
+        input_ids = torch.ones(
+            (group_size, max_req_tokens + max_req_tokens),
+            dtype=torch.long,
+            device="cuda",
+        )
+        # Populate episode info and calculate rewards
+        for i, (episode, response) in enumerate(zip(group.episodes, responses)):
+            episode.request_tokens = response.prompt_ids
             episode.response_tokens = response.token_ids
             episode.response = response.text
-            episode.ref_logprobs = await ref_model.forward.choose(episode)
+            input_ids[i, :max_req_tokens] = episode.request_tensor
+            input_ids[i, max_req_tokens:] = episode.response_tensor
             episode.reward = await reward_actor.evaluate_response.choose(
                 prompt=prompt, response=response.text, target=target
             )
+
+        # Calculate reference logprobs
+        ref_logits = await ref_model.forward.choose(input_ids)
+        ref_logprobs = compute_logprobs(ref_logits, input_ids[:, max_req_tokens:])
+        for i, episode in enumerate(group.episodes):
+            episode.ref_logprobs = ref_logprobs[i]
+        del ref_logits, ref_logprobs, input_ids
+
+        # Calculate advantages and add to replay buffer
         advantages = await compute_advantages.compute.choose(group)
         for episode, advantage in zip(group.episodes, advantages):
             episode.advantage = advantage
             await replay_buffer.add.choose(episode)
 
+        # Log metrics
         avg_response_len = (
             sum(len(e.response_tokens) for e in group.episodes) / group_size
         )
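
Rollouts now make one batched ReferenceModel call for the whole group instead of one call per episode: requests and responses are packed side by side into a fixed-width tensor, and the response span is sliced back out for compute_logprobs. A toy walk-through of that slicing with dummy shapes (assuming compute_logprobs as defined in this file):

    import torch

    group_size, max_req_tokens, max_res_tokens, vocab = 2, 4, 3, 11
    seq_len = max_req_tokens + max_res_tokens
    input_ids = torch.randint(0, vocab, (group_size, seq_len))  # packed request|response
    ref_logits = torch.randn(group_size, seq_len, vocab)        # ref_model.forward output

    # Logprobs align with the response span (the right-hand max_res_tokens columns).
    ref_logprobs = compute_logprobs(ref_logits, input_ids[:, max_req_tokens:])
    print(ref_logprobs.shape)  # torch.Size([2, 3]) -> one row per episode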
@@ -402,7 +365,8 @@ async def continuous_training():
             if batch is None:
                 await asyncio.sleep(0.1)
             else:
-                loss = await trainer.train_step.choose(batch)
+                inputs, targets = batch
+                loss = await trainer.train_step.choose(inputs, targets)
                 training_step += 1
                 mlogger.log("loss/training_step", loss, training_step)
                 start_time = time.perf_counter()
